Skip to content

train

Autoencoder(base_channel_size, latent_dim, encoder_class=Encoder, decoder_class=Decoder, num_input_channels=3, width=32, height=32)

Bases: LightningModule

Source code in unsat/train.py
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
def __init__(
    self,
    base_channel_size: int,
    latent_dim: int,
    encoder_class: object = Encoder,
    decoder_class: object = Decoder,
    num_input_channels: int = 3,
    width: int = 32,
    height: int = 32,
):
    """Build an autoencoder from an encoder/decoder pair.

    Args:
        base_channel_size: Base number of channels, passed to both the encoder
            and the decoder constructors.
        latent_dim: Size of the latent representation, passed to both the
            encoder and the decoder constructors.
        encoder_class: Callable building the encoder as
            ``encoder_class(num_input_channels, base_channel_size, latent_dim)``.
        decoder_class: Callable building the decoder with the same arguments.
        num_input_channels: Number of channels of the input images.
        width: Input image width; used here only to shape ``example_input_array``.
        height: Input image height; used here only to shape ``example_input_array``.
    """
    super().__init__()
    # Record all constructor arguments in ``self.hparams``.
    self.save_hyperparameters()
    # Creating encoder and decoder
    self.encoder = encoder_class(num_input_channels, base_channel_size, latent_dim)
    self.decoder = decoder_class(num_input_channels, base_channel_size, latent_dim)
    # Example batch of two images, used for visualizing the graph of the network.
    # PyTorch image tensors are laid out (N, C, H, W): height comes before width,
    # otherwise non-square inputs would be transposed.
    self.example_input_array = torch.zeros(2, num_input_channels, height, width)

forward(x)

The forward function takes in an image and returns the reconstructed image.

Source code in unsat/train.py
187
188
189
190
191
192
193
def forward(self, x):
    """Encode ``x`` into the latent space and decode it back.

    Args:
        x: Input batch for ``self.encoder`` — presumably (N, C, H, W) images;
            TODO confirm against the encoder implementation.

    Returns:
        The decoder's reconstruction of ``x``.
    """
    z = self.encoder(x)
    return self.decoder(z)

LightningTrainer(network, class_names, dimension, input_channels, optimizer, **kwargs)

Bases: LightningModule

Args: network (nn.Module): the network to train; class_names (List[str]): the names of the classes; dimension (int): the number of spatial dimensions; input_channels (int): the number of input channels.

Source code in unsat/train.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def __init__(
    self,
    network,
    class_names,
    dimension,
    input_channels,
    optimizer: OptimizerCallable,
    **kwargs,
):
    """
    Lightning module bundling a network, its optimizer factory and its metrics.

    Args:
        network (nn.Module):
            The network to train. Its ``dimension``, ``num_classes`` and
            ``input_channels`` attributes are set here, then ``build()`` is called.
        class_names (List[str]):
            The names of the classes; their count defines ``num_classes``.
        dimension (int):
            The number of spatial dimensions.
        input_channels (int):
            The number of input channels.
        optimizer (OptimizerCallable):
            Stored unchanged as ``self.optimizer``.
        **kwargs:
            Accepted but not used by this constructor.
    """
    # NOTE(review): this turns on autograd anomaly detection process-wide, which is
    # a debugging aid and slows training noticeably; consider gating it behind a flag.
    torch.autograd.set_detect_anomaly(True)
    super().__init__()
    self.optimizer = optimizer

    n_classes = len(class_names)
    self.class_names = class_names
    self.num_classes = n_classes

    # Attach the network first, then push the trainer's settings into it and
    # let it construct its layers.
    self.network = network
    net = self.network
    net.dimension = dimension
    net.num_classes = n_classes
    net.input_channels = input_channels
    net.build()

    common = {"task": "multiclass", "num_classes": n_classes, "ignore_index": -1}

    def per_split(make_metric):
        # A separate metric instance for each split so their states never mix.
        return torch.nn.ModuleDict({split: make_metric() for split in ("train_", "val_")})

    self.metrics = torch.nn.ModuleDict()
    self.metrics["acc"] = per_split(lambda: Accuracy(**common, average="macro"))
    self.metrics["f1"] = per_split(lambda: F1Score(**common, average="macro"))
    self.metrics["acc_per_class"] = per_split(
        lambda: ClasswiseWrapper(Accuracy(**common, average=None), labels=self.class_names)
    )
    self.metrics["f1_per_class"] = per_split(
        lambda: ClasswiseWrapper(F1Score(**common, average=None), labels=self.class_names)
    )
    # Same shared arguments plus normalize="true" (torchmetrics: normalize over
    # the true-label axis).
    self.metrics["confusion"] = per_split(lambda: ConfusionMatrix(**common, normalize="true"))

    # Uniform weights by default; the ClassWeightsCallback can override these
    # to represent class frequencies.
    self.class_weights = torch.ones(n_classes)

WandbSaveConfigCallback

Bases: SaveConfigCallback

Custom callback that saves the Lightning config to Weights & Biases (wandb).