source.utils.custom_classes

Custom PyTorch Lightning classes for Garbage Classification.

This module contains custom implementations of:

  • GarbageClassifier: ResNet18-based classifier
  • GarbageDataModule: Data loading and preprocessing
  • LossCurveCallback: Training metrics visualization
 1"""
 2Custom PyTorch Lightning classes for Garbage Classification.
 3
 4This module contains custom implementations of:
 5- GarbageClassifier: ResNet18-based classifier
 6- GarbageDataModule: Data loading and preprocessing
 7- LossCurveCallback: Training metrics visualization
 8"""
 9
10from source.utils.custom_classes.GarbageClassifier import GarbageClassifier
11from source.utils.custom_classes.GarbageDataModule import GarbageDataModule
12from source.utils.custom_classes.LossCurveCallback import LossCurveCallback
13
14__all__ = [
15    "GarbageClassifier",
16    "GarbageDataModule",
17    "LossCurveCallback",
18]
class GarbageClassifier(pytorch_lightning.core.module.LightningModule):
# Imports assumed from the module header (not shown in this excerpt):
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torchvision import models


class GarbageClassifier(pl.LightningModule):
    """
    Pretrained (ImageNet) ResNet18 adapted to the Garbage Dataset
    Classification problem.
    It considers 6 classes: cardboard, glass, metal, paper, plastic,
    and trash.

    Attributes
    ----------
    model : torchvision.models.resnet18
        Pretrained ResNet18 model.
    loss_fn : torch.nn.CrossEntropyLoss
        Cross entropy loss function.

    Examples
    --------
    >>> model = GarbageClassifier(num_classes=6, lr=1e-3)
    >>> trainer = pl.Trainer(max_epochs=10, accelerator="auto")
    >>> trainer.fit(model, datamodule=data_module)
    """

    def __init__(self, num_classes, lr=1e-3):
        """
        Initialize the GarbageClassifier model.

        Parameters
        ----------
        num_classes : int
            Number of output classes for classification.
        lr : float, optional
            Learning rate for the optimizer (default is 1e-3).
        """
        super().__init__()
        self.save_hyperparameters()
        self.model = models.resnet18(
            weights=models.ResNet18_Weights.IMAGENET1K_V1
        )

        # Freeze feature layers
        for param in self.model.parameters():
            param.requires_grad = False

        # New classifier layer (replaces the 1000-class ImageNet head)
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        """
        Forward pass through the model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of images with shape (batch_size, channels, height,
            width).

        Returns
        -------
        torch.Tensor
            Output logits with shape (batch_size, num_classes).
        """
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """
        Perform one training step: model parameters are updated according
        to the classification error of a batch of training images.

        Parameters
        ----------
        batch : Tuple[torch.Tensor, torch.Tensor]
            Subset (batch) of images coming from the train dataloader.
            Contains input images and corresponding labels.
        batch_idx : int
            Identifier of the batch within the current epoch.

        Returns
        -------
        torch.Tensor
            Classification error (cross entropy loss) of the training batch.
        """
        xb, yb = batch
        out = self(xb)
        loss = self.loss_fn(out, yb)

        # Accuracy
        preds = out.argmax(dim=1)
        acc = (preds == yb).float().mean()

        self.log(
            "train_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True
        )

        self.log("train_acc", acc, on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """
        Compute validation loss and accuracy for a batch of validation images.

        Parameters
        ----------
        batch : Tuple[torch.Tensor, torch.Tensor]
            Subset (batch) of images coming from the validation dataloader.
            Contains input images and corresponding labels.
        batch_idx : int
            Identifier of the batch within the current validation epoch.

        Returns
        -------
        torch.Tensor
            Validation accuracy for the current batch.
        """
        xb, yb = batch
        out = self(xb)
        loss = self.loss_fn(out, yb)
        preds = out.argmax(dim=1)
        acc = (preds == yb).float().mean()
        self.log(
            "val_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=False
        )
        self.log(
            "val_acc",
            acc,
            on_step=False,
            on_epoch=True,
            prog_bar=True
        )
        return acc

    def configure_optimizers(self):
        """
        Configure the optimizer for training.

        Returns
        -------
        torch.optim.Adam
            Adam optimizer configured with model parameters and learning rate
            from hyperparameters.
        """
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

Pretrained (ImageNet) ResNet18 adapted to the Garbage Dataset Classification problem. It considers 6 classes: cardboard, glass, metal, paper, plastic, and trash.

Attributes
  • model (torchvision.models.resnet18): Pretrained ResNet18 model.
  • loss_fn (torch.nn.CrossEntropyLoss): Cross entropy loss function.
Examples
>>> model = GarbageClassifier(num_classes=6, lr=1e-3)
>>> trainer = pl.Trainer(max_epochs=10, accelerator="auto")
>>> trainer.fit(model, datamodule=data_module)
GarbageClassifier(num_classes, lr=0.001)

Initialize the GarbageClassifier model.

Parameters
  • num_classes (int): Number of output classes for classification.
  • lr (float, optional): Learning rate for the optimizer (default is 1e-3).
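
Because the backbone is frozen in __init__ and only the replacement head is newly created, just the fc layer should report trainable parameters. A quick sanity check (a sketch, not part of the class):

>>> model = GarbageClassifier(num_classes=6)
>>> [n for n, p in model.named_parameters() if p.requires_grad]
['model.fc.weight', 'model.fc.bias']
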
model
loss_fn
def forward(self, x):

Forward pass through the model.

Parameters
  • x (torch.Tensor): Input tensor of images with shape (batch_size, channels, height, width).
Returns
  • torch.Tensor: Output logits with shape (batch_size, num_classes).
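
A shape check with a dummy batch illustrates the contract (224x224 is the crop size used by the ResNet18 ImageNet preprocessing):

>>> import torch
>>> model = GarbageClassifier(num_classes=6)
>>> logits = model(torch.randn(4, 3, 224, 224))
>>> logits.shape
torch.Size([4, 6])
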
def training_step(self, batch, batch_idx):

Perform one training step: model parameters are updated according to the classification error of a batch of training images.

Parameters
  • batch (Tuple[torch.Tensor, torch.Tensor]): Subset (batch) of images coming from the train dataloader. Contains input images and corresponding labels.
  • batch_idx (int): Identifier of the batch within the current epoch.
Returns
  • torch.Tensor: Classification error (cross entropy loss) of the training batch.
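
The loss/accuracy arithmetic is plain cross entropy plus argmax agreement. On toy logits:

>>> import torch
>>> import torch.nn as nn
>>> out = torch.tensor([[2.0, 0.1], [0.2, 1.5]])  # logits for two samples
>>> yb = torch.tensor([0, 0])                     # ground-truth labels
>>> loss = nn.CrossEntropyLoss()(out, yb)         # mean cross entropy, ~0.84 here
>>> (out.argmax(dim=1) == yb).float().mean()      # one of two predictions correct
tensor(0.5000)
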
def validation_step(self, batch, batch_idx):

Compute validation loss and accuracy for a batch of validation images.

Parameters
  • batch (Tuple[torch.Tensor, torch.Tensor]): Subset (batch) of images coming from the validation dataloader. Contains input images and corresponding labels.
  • batch_idx (int): Identifier of the batch within the current validation epoch.
Returns
  • torch.Tensor: Validation accuracy for the current batch.
def configure_optimizers(self):

Configure the optimizer for training.

Returns
  • torch.optim.Adam: Adam optimizer configured with model parameters and learning rate from hyperparameters.
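
Since the backbone is frozen, handing all of self.parameters() to Adam is harmless: parameters that never receive a gradient are simply skipped by the optimizer. An equivalent variant (a sketch, not the module's code) passes only the trainable head:

    def configure_optimizers(self):
        # Only parameters with requires_grad=True (the new fc head) reach
        # the optimizer; training behavior matches the original.
        trainable = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.Adam(trainable, lr=self.hparams.lr)
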
class GarbageDataModule(pytorch_lightning.core.datamodule.LightningDataModule):
# Imports assumed from the module header (not shown in this excerpt);
# `cfg` is the project's configuration module, providing DATASET_PATH
# and NUM_CLASSES.
import numpy as np
import pytorch_lightning as pl
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, models


class GarbageDataModule(pl.LightningDataModule):
    """
    PyTorch Lightning DataModule for Garbage Classification Dataset.

    This DataModule handles loading, splitting, and creating dataloaders for
    the garbage classification dataset. It performs a stratified 90/10
    train/test split and applies ResNet18 ImageNet preprocessing transforms.

    Attributes
    ----------
    batch_size : int
        Number of samples per batch for training.
    num_workers : int
        Number of subprocesses to use for data loading.
    transform : torchvision.transforms.Compose
        Image preprocessing transforms from ResNet18 ImageNet weights.
    train_dataset : torch.utils.data.Subset
        Training dataset subset.
    test_dataset : torch.utils.data.Subset
        Test/validation dataset subset.
    train_idx : numpy.ndarray
        Indices of samples in the training set.
    test_idx : numpy.ndarray
        Indices of samples in the test set.
    num_classes : int
        Number of classes in the dataset.

    Examples
    --------
    >>> data_module = GarbageDataModule(batch_size=32, num_workers=4)
    >>> data_module.setup()
    >>> train_loader = data_module.train_dataloader()
    >>> val_loader = data_module.val_dataloader()
    """

    def __init__(self, batch_size=32, num_workers=4):
        """
        Initialize the GarbageDataModule.

        Parameters
        ----------
        batch_size : int, optional
            Number of samples per batch for training (default is 32).
        num_workers : int, optional
            Number of subprocesses to use for data loading (default is 4).
        """

        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = models.ResNet18_Weights.IMAGENET1K_V1.transforms()

    def setup(self, stage=None):
        """
        Prepare the dataset by loading and splitting into train and test sets.

        Loads the full dataset from the configured path, performs a stratified
        90/10 train/test split to ensure balanced class distribution, and
        creates dataset subsets for training and validation.

        Parameters
        ----------
        stage : str, optional
            Current stage ('fit', 'validate', 'test', or 'predict').
            Not used in this implementation (default is None).

        Notes
        -----
        The dataset is split using a stratified approach with random_state=42
        for reproducibility. The split ratio is 90% training and 10% testing.
        """

        # Load full dataset
        full_dataset = datasets.ImageFolder(
            cfg.DATASET_PATH, transform=self.transform
        )
        # Note: iterating the dataset decodes every image just to read its
        # label; ImageFolder also exposes the labels directly as
        # `full_dataset.targets`.
        targets = [label for _, label in full_dataset]
        self.num_classes = cfg.NUM_CLASSES

        # Stratified split 90/10
        train_idx, test_idx = train_test_split(
            np.arange(len(targets)), test_size=0.1,
            stratify=targets, random_state=42
        )

        self.train_dataset = Subset(full_dataset, train_idx)
        self.test_dataset = Subset(full_dataset, test_idx)
        self.train_idx = train_idx
        self.test_idx = test_idx

    def train_dataloader(self):
        """
        Create and return the training dataloader.

        Returns
        -------
        torch.utils.data.DataLoader
            DataLoader for the training dataset with shuffling enabled.

        Notes
        -----
        The dataloader uses the configured batch_size and num_workers,
        and shuffles the data at each epoch.
        """

        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """
        Create and return the validation dataloader.

        Returns
        -------
        torch.utils.data.DataLoader
            DataLoader for the validation/test dataset without shuffling.

        Notes
        -----
        The dataloader uses a fixed batch_size of 1000 for faster validation,
        with the configured num_workers. Shuffling is disabled to ensure
        consistent validation metrics.
        """

        return DataLoader(
            self.test_dataset,
            batch_size=1000,
            shuffle=False,
            num_workers=self.num_workers,
        )

PyTorch Lightning DataModule for Garbage Classification Dataset.

This DataModule handles loading, splitting, and creating dataloaders for the garbage classification dataset. It performs a stratified 90/10 train/test split and applies ResNet18 ImageNet preprocessing transforms.

Attributes
  • batch_size (int): Number of samples per batch for training.
  • num_workers (int): Number of subprocesses to use for data loading.
  • transform (torchvision.transforms.Compose): Image preprocessing transforms from ResNet18 ImageNet weights.
  • train_dataset (torch.utils.data.Subset): Training dataset subset.
  • test_dataset (torch.utils.data.Subset): Test/validation dataset subset.
  • train_idx (numpy.ndarray): Indices of samples in the training set.
  • test_idx (numpy.ndarray): Indices of samples in the test set.
  • num_classes (int): Number of classes in the dataset.
Examples
>>> data_module = GarbageDataModule(batch_size=32, num_workers=4)
>>> data_module.setup()
>>> train_loader = data_module.train_dataloader()
>>> val_loader = data_module.val_dataloader()
GarbageDataModule(batch_size=32, num_workers=4)

Initialize the GarbageDataModule.

Parameters
  • batch_size (int, optional): Number of samples per batch for training (default is 32).
  • num_workers (int, optional): Number of subprocesses to use for data loading (default is 4).
batch_size
num_workers
transform
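
The transform attribute holds torchvision's bundled inference preset for these weights. Printing it shows the preprocessing applied to every image; the exact repr varies across torchvision versions, but the values below are the standard ImageNet preset:

>>> from torchvision import models
>>> models.ResNet18_Weights.IMAGENET1K_V1.transforms()
ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)
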
def setup(self, stage=None):

Prepare the dataset by loading and splitting into train and test sets.

Loads the full dataset from the configured path, performs a stratified 90/10 train/test split to ensure balanced class distribution, and creates dataset subsets for training and validation.

Parameters
  • stage (str, optional): Current stage ('fit', 'validate', 'test', or 'predict'). Not used in this implementation (default is None).
Notes

The dataset is split using a stratified approach with random_state=42 for reproducibility. The split ratio is 90% training and 10% testing.
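
The effect of stratify is easy to see on toy labels: class proportions survive the split. A small sketch:

>>> import numpy as np
>>> from sklearn.model_selection import train_test_split
>>> targets = np.array([0] * 90 + [1] * 10)   # imbalanced toy labels
>>> train_idx, test_idx = train_test_split(
...     np.arange(len(targets)), test_size=0.1,
...     stratify=targets, random_state=42
... )
>>> np.bincount(targets[train_idx])   # 90% of each class
array([81,  9])
>>> np.bincount(targets[test_idx])    # 10% of each class
array([9, 1])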

def train_dataloader(self):

Create and return the training dataloader.

Returns
  • torch.utils.data.DataLoader: DataLoader for the training dataset with shuffling enabled.
Notes

The dataloader uses the configured batch_size and num_workers, and shuffles the data at each epoch.
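
A quick smoke test of the loader (a sketch, assuming the dataset path is configured; num_workers=0 keeps it single-process):

>>> dm = GarbageDataModule(batch_size=32, num_workers=0)
>>> dm.setup()
>>> xb, yb = next(iter(dm.train_dataloader()))
>>> xb.shape, yb.shape
(torch.Size([32, 3, 224, 224]), torch.Size([32]))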

def val_dataloader(self):

Create and return the validation dataloader.

Returns
  • torch.utils.data.DataLoader: DataLoader for the validation/test dataset without shuffling.
Notes

The dataloader uses a fixed batch_size of 1000 for faster validation, with the configured num_workers. Shuffling is disabled to ensure consistent validation metrics.

class LossCurveCallback(pytorch_lightning.callbacks.callback.Callback):
# Imports assumed from the module header (not shown in this excerpt);
# `cfg` is the project's configuration module, providing LOSS_CURVES_PATH.
import json
import os

import matplotlib.pyplot as plt
from pytorch_lightning.callbacks import Callback


class LossCurveCallback(Callback):
    """
    PyTorch Lightning callback for tracking and plotting loss curves.

    This callback monitors training loss, training accuracy, validation
    loss, and validation accuracy throughout the training process. At the
    end of training, it generates and saves visualization plots and raw
    metric data.

    Attributes
    ----------
    save_dir : str
        Directory path where plots and metrics will be saved.
    train_losses : list of float
        Training loss values collected at the end of each training epoch.
    train_accs : list of float
        Training accuracy values collected at the end of each training epoch.
    val_losses : list of float
        Validation loss values collected at the end of each validation epoch.
    val_accs : list of float
        Validation accuracy values collected at the end of each validation
        epoch.

    Examples
    --------
    >>> from pytorch_lightning import Trainer
    >>> callback = LossCurveCallback(save_dir="./outputs/curves")
    >>> trainer = Trainer(callbacks=[callback])
    >>> trainer.fit(model, datamodule=data_module)

    Notes
    -----
    The callback creates three output files in the save_dir:
    - loss_curve.png: Plot of training and validation losses
    - val_acc_curve.png: Plot of training and validation accuracy
    - metrics.json: Raw metric data in JSON format
    """

    def __init__(self, save_dir=cfg.LOSS_CURVES_PATH):
        """
        Initialize the LossCurveCallback.

        Parameters
        ----------
        save_dir : str, optional
            Directory path where plots and metrics will be saved
            (default is cfg.LOSS_CURVES_PATH).

        Notes
        -----
        The save directory is created automatically if it does not exist.
        """

        super().__init__()
        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)
        self.train_losses = []
        self.train_accs = []
        self.val_losses = []
        self.val_accs = []

    # ---------- Train loss per epoch ----------
    def on_train_epoch_end(self, trainer, pl_module):
        """
        Called at the end of each training epoch to collect training loss
        and accuracy.

        Parameters
        ----------
        trainer : pytorch_lightning.Trainer
            The PyTorch Lightning trainer instance.
        pl_module : pytorch_lightning.LightningModule
            The LightningModule being trained.

        Notes
        -----
        Extracts the 'train_loss' and 'train_acc' metrics from
        trainer.callback_metrics and appends them to the train_losses and
        train_accs lists.
        """

        metrics = trainer.callback_metrics
        if "train_loss" in metrics:
            self.train_losses.append(metrics["train_loss"].item())
        if "train_acc" in metrics:
            self.train_accs.append(metrics["train_acc"].item())

    # ---------- Val loss and acc per epoch ----------
    def on_validation_epoch_end(self, trainer, pl_module):
        """
        Called at the end of each validation epoch to collect validation
        metrics.

        Parameters
        ----------
        trainer : pytorch_lightning.Trainer
            The PyTorch Lightning trainer instance.
        pl_module : pytorch_lightning.LightningModule
            The LightningModule being validated.

        Notes
        -----
        Extracts 'val_loss' and 'val_acc' metrics from trainer.callback_metrics
        and appends them to their respective lists.
        """

        metrics = trainer.callback_metrics
        if "val_loss" in metrics:
            self.val_losses.append(metrics["val_loss"].item())
        if "val_acc" in metrics:
            self.val_accs.append(metrics["val_acc"].item())

    def on_train_end(self, trainer, pl_module):
        """
        Called at the end of training to generate and save plots and metrics.

        Parameters
        ----------
        trainer : pytorch_lightning.Trainer
            The PyTorch Lightning trainer instance.
        pl_module : pytorch_lightning.LightningModule
            The trained LightningModule.

        Notes
        -----
        This method performs three main tasks:
        1. Generates and saves a loss curve plot (loss_curve.png) showing
           training loss and validation loss over epochs.
        2. Generates and saves an accuracy plot (val_acc_curve.png) showing
           training accuracy and, when tracked, validation accuracy.
        3. Saves all raw metric data to a JSON file (metrics.json) for later
           analysis or reproduction.

        All output files are saved to the directory specified in save_dir.
        """

        # ---------- Save curves as a PNG ----------
        plt.figure()
        plt.plot(self.train_losses, label="Train Loss")
        if len(self.val_losses) > 0:
            plt.plot(self.val_losses, label="Val Loss")
        plt.legend()
        plt.title("Loss Curves")
        plt.xlabel("Epochs")
        plt.ylabel("Loss")
        plt.savefig(os.path.join(self.save_dir, "loss_curve.png"))
        plt.close()

        plt.figure()
        plt.plot(self.train_accs, label="Train Accuracy")
        if len(self.val_accs) > 0:
            plt.plot(self.val_accs, label="Val Accuracy")
        plt.legend()
        plt.title("Accuracy Curves")
        plt.xlabel("Epochs")
        plt.ylabel("Accuracy")
        plt.savefig(os.path.join(self.save_dir, "val_acc_curve.png"))
        plt.close()

        # ---------- Save raw data ----------
        data = {
            "train_losses": self.train_losses,
            "val_losses": self.val_losses,
            "train_accs": self.train_accs,
            "val_accs": self.val_accs,
        }

        with open(os.path.join(self.save_dir, "metrics.json"), "w") as f:
            json.dump(data, f)

PyTorch Lightning callback for tracking and plotting loss curves.

This callback monitors training loss, training accuracy, validation loss, and validation accuracy throughout the training process. At the end of training, it generates and saves visualization plots and raw metric data.

Attributes
  • save_dir (str): Directory path where plots and metrics will be saved.
  • train_losses (list of float): Training loss values collected at the end of each training epoch.
  • train_accs (list of float): Training accuracy values collected at the end of each training epoch.
  • val_losses (list of float): Validation loss values collected at the end of each validation epoch.
  • val_accs (list of float): Validation accuracy values collected at the end of each validation epoch.
Examples
>>> from pytorch_lightning import Trainer
>>> callback = LossCurveCallback(save_dir="./outputs/curves")
>>> trainer = Trainer(callbacks=[callback])
>>> trainer.fit(model, datamodule=data_module)
Notes

The callback creates three output files in the save_dir:

  • loss_curve.png: Plot of training and validation losses
  • val_acc_curve.png: Plot of training and validation accuracy
  • metrics.json: Raw metric data in JSON format
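
Since metrics.json holds plain lists, the curves can be rebuilt offline. A sketch, assuming the save_dir used in the example above:

>>> import json, os
>>> with open(os.path.join("./outputs/curves", "metrics.json")) as f:
...     data = json.load(f)
>>> sorted(data)
['train_accs', 'train_losses', 'val_accs', 'val_losses']
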
LossCurveCallback(save_dir='/home/alumno/Desktop/datos/SDOML/garbage_classifier/models/performance/loss_curves/')

Initialize the LossCurveCallback.

Parameters
  • save_dir (str, optional): Directory path where plots and metrics will be saved (default is cfg.LOSS_CURVES_PATH).
Notes

The save directory is created automatically if it does not exist.

save_dir
train_losses
train_accs
val_losses
val_accs
def on_train_epoch_end(self, trainer, pl_module):

Called at the end of each training epoch to collect training loss and accuracy.

Parameters
  • trainer (pytorch_lightning.Trainer): The PyTorch Lightning trainer instance.
  • pl_module (pytorch_lightning.LightningModule): The LightningModule being trained.
Notes

Extracts the 'train_loss' and 'train_acc' metrics from trainer.callback_metrics and appends them to the train_losses and train_accs lists.

def on_validation_epoch_end(self, trainer, pl_module):

Called at the end of each validation epoch to collect validation metrics.

Parameters
  • trainer (pytorch_lightning.Trainer): The PyTorch Lightning trainer instance.
  • pl_module (pytorch_lightning.LightningModule): The LightningModule being validated.
Notes

Extracts 'val_loss' and 'val_acc' metrics from trainer.callback_metrics and appends them to their respective lists.
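
One caveat not handled in the code above: Lightning runs a short validation sanity check before the first training epoch, and this hook fires during it too, so the validation lists can end up one entry longer than the training lists. A minimal guard, assuming a Lightning version that exposes trainer.sanity_checking:

    def on_validation_epoch_end(self, trainer, pl_module):
        # Skip the pre-training sanity check so the validation curves
        # stay aligned with the training curves.
        if trainer.sanity_checking:
            return
        metrics = trainer.callback_metrics
        if "val_loss" in metrics:
            self.val_losses.append(metrics["val_loss"].item())
        if "val_acc" in metrics:
            self.val_accs.append(metrics["val_acc"].item())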

def on_train_end(self, trainer, pl_module):

Called at the end of training to generate and save plots and metrics.

Parameters
  • trainer (pytorch_lightning.Trainer): The PyTorch Lightning trainer instance.
  • pl_module (pytorch_lightning.LightningModule): The trained LightningModule.
Notes

This method performs three main tasks:

  1. Generates and saves a loss curve plot (loss_curve.png) showing training loss and validation loss over epochs.
  2. Generates and saves an accuracy plot (val_acc_curve.png) showing training accuracy and, when tracked, validation accuracy.
  3. Saves all raw metric data to a JSON file (metrics.json) for later analysis or reproduction.

All output files are saved to the directory specified in save_dir.