source.utils.custom_classes.LossCurveCallback

PyTorch Lightning Callback for Loss and Accuracy Curve Visualization.

This module provides a custom callback that tracks and visualizes training and validation metrics during model training, saving plots and raw data to disk.

  1#!/usr/bin/env python3
  2# -*- coding: utf-8 -*-
  3"""
  4PyTorch Lightning Callback for Loss and Accuracy Curve Visualization.
  5
  6This module provides a custom callback that tracks and visualizes training
  7and validation metrics during model training, saving plots and raw data to
  8disk.
  9"""
 10__docformat__ = "numpy"
 11
 12import os
 13import json
 14import matplotlib.pyplot as plt
 15from pytorch_lightning.callbacks import Callback
 16from source.utils import config as cfg
 17
 18
 19class LossCurveCallback(Callback):
 20    """
 21    PyTorch Lightning callback for tracking and plotting loss curves.
 22
 23    This callback monitors training loss, validation loss, and validation
 24    accuracy throughout the training process. At the end of training, it
 25    generates and saves visualization plots and raw metric data.
 26
 27    Attributes
 28    ----------
 29    save_dir : str
 30        Directory path where plots and metrics will be saved.
 31    train_losses : list of float
 32        Training loss values collected at the end of each training epoch.
 33    val_losses : list of float
 34        Validation loss values collected at the end of each validation epoch.
 35    val_accs : list of float
 36        Validation accuracy values collected at the end of each validation
 37        epoch.
 38
 39    Examples
 40    --------
 41    >>> from pytorch_lightning import Trainer
 42    >>> callback = LossCurveCallback(save_dir="./outputs/curves")
 43    >>> trainer = Trainer(callbacks=[callback])
 44    >>> trainer.fit(model, datamodule=data_module)
 45
 46    Notes
 47    -----
 48    The callback creates three output files in the save_dir:
 49    - loss_curve.png: Plot of training and validation losses
 50    - val_acc_curve.png: Plot of validation accuracy
 51    - metrics.json: Raw metric data in JSON format
 52    """
 53
 54    def __init__(self, save_dir=cfg.LOSS_CURVES_PATH):
 55        """
 56        Initialize the LossCurveCallback.
 57
 58        Parameters
 59        ----------
 60        save_dir : str, optional
 61            Directory path where plots and metrics will be saved
 62            (default is cfg.LOSS_CURVES_PATH).
 63
 64        Notes
 65        -----
 66        The save directory is created automatically if it does not exist.
 67        """
 68
 69        super().__init__()
 70        self.save_dir = save_dir
 71        os.makedirs(self.save_dir, exist_ok=True)
 72        self.train_losses = []
 73        self.train_accs = []
 74        self.val_losses = []
 75        self.val_accs = []
 76
 77    # ---------- Train loss per epoch ----------
 78    def on_train_epoch_end(self, trainer, pl_module):
 79        """
 80        Called at the end of each training epoch to collect training loss.
 81
 82        Parameters
 83        ----------
 84        trainer : pytorch_lightning.Trainer
 85            The PyTorch Lightning trainer instance.
 86        pl_module : pytorch_lightning.LightningModule
 87            The LightningModule being trained.
 88
 89        Notes
 90        -----
 91        Extracts the 'train_loss' metric from trainer.callback_metrics and
 92        appends it to the train_losses list.
 93        """
 94
 95        metrics = trainer.callback_metrics
 96        if "train_loss" in metrics:
 97            self.train_losses.append(metrics["train_loss"].item())
 98        if "train_acc" in metrics:
 99            self.train_accs.append(metrics["train_acc"].item())
100
101    # ---------- Val loss and acc per epoch ----------
102    def on_validation_epoch_end(self, trainer, pl_module):
103        """
104        Called at the end of each validation epoch to collect validation
105        metrics.
106
107        Parameters
108        ----------
109        trainer : pytorch_lightning.Trainer
110            The PyTorch Lightning trainer instance.
111        pl_module : pytorch_lightning.LightningModule
112            The LightningModule being validated.
113
114        Notes
115        -----
116        Extracts 'val_loss' and 'val_acc' metrics from trainer.callback_metrics
117        and appends them to their respective lists.
118        """
119
120        metrics = trainer.callback_metrics
121        if "val_loss" in metrics:
122            self.val_losses.append(metrics["val_loss"].item())
123        if "val_acc" in metrics:
124            self.val_accs.append(metrics["val_acc"].item())
125
126    def on_train_end(self, trainer, pl_module):
127        """
128        Called at the end of training to generate and save plots and metrics.
129
130        Parameters
131        ----------
132        trainer : pytorch_lightning.Trainer
133            The PyTorch Lightning trainer instance.
134        pl_module : pytorch_lightning.LightningModule
135            The trained LightningModule.
136
137        Notes
138        -----
139        This method performs three main tasks:
140        1. Generates and saves a loss curve plot (loss_curve.png) showing
141           training loss and validation loss over epochs.
142        2. Generates and saves a validation accuracy plot (val_acc_curve.png)
143           if validation accuracy was tracked.
144        3. Saves all raw metric data to a JSON file (metrics.json) for later
145           analysis or reproduction.
146
147        All output files are saved to the directory specified in save_dir.
148        """
149
150        # ---------- Save curves as a PNG ----------
151        plt.figure()
152        plt.plot(self.train_losses, label="Train Loss")
153        if len(self.val_losses) > 0:
154            plt.plot(self.val_losses, label="Val Loss")
155        plt.legend()
156        plt.title("Loss Curves")
157        plt.xlabel("Epochs")
158        plt.ylabel("Loss")
159        plt.savefig(os.path.join(self.save_dir, "loss_curve.png"))
160        plt.close()
161
162        plt.figure()
163        plt.plot(self.train_accs, label="Train Accuracy")
164        if len(self.val_accs) > 0:
165            plt.plot(self.val_accs, label="Val Accuracy")
166        plt.legend()
167        plt.title("Accuracy Curves")
168        plt.xlabel("Epochs")
169        plt.ylabel("Accuracy")
170        plt.savefig(os.path.join(self.save_dir, "val_acc_curve.png"))
171        plt.close()
172
173        # ---------- Save raw data ----------
174        data = {
175            "train_losses": self.train_losses,
176            "val_losses": self.val_losses,
177            "train_accs": self.train_accs,
178            "val_accs": self.val_accs,
179        }
180
181        with open(os.path.join(self.save_dir, "metrics.json"), "w") as f:
182            json.dump(data, f)
class LossCurveCallback(pytorch_lightning.callbacks.callback.Callback):
 20class LossCurveCallback(Callback):
 21    """
 22    PyTorch Lightning callback for tracking and plotting loss curves.
 23
 24    This callback monitors training loss, validation loss, and validation
 25    accuracy throughout the training process. At the end of training, it
 26    generates and saves visualization plots and raw metric data.
 27
 28    Attributes
 29    ----------
 30    save_dir : str
 31        Directory path where plots and metrics will be saved.
 32    train_losses : list of float
 33        Training loss values collected at the end of each training epoch.
 34    val_losses : list of float
 35        Validation loss values collected at the end of each validation epoch.
 36    val_accs : list of float
 37        Validation accuracy values collected at the end of each validation
 38        epoch.
 39
 40    Examples
 41    --------
 42    >>> from pytorch_lightning import Trainer
 43    >>> callback = LossCurveCallback(save_dir="./outputs/curves")
 44    >>> trainer = Trainer(callbacks=[callback])
 45    >>> trainer.fit(model, datamodule=data_module)
 46
 47    Notes
 48    -----
 49    The callback creates three output files in the save_dir:
 50    - loss_curve.png: Plot of training and validation losses
 51    - val_acc_curve.png: Plot of validation accuracy
 52    - metrics.json: Raw metric data in JSON format
 53    """
 54
 55    def __init__(self, save_dir=cfg.LOSS_CURVES_PATH):
 56        """
 57        Initialize the LossCurveCallback.
 58
 59        Parameters
 60        ----------
 61        save_dir : str, optional
 62            Directory path where plots and metrics will be saved
 63            (default is cfg.LOSS_CURVES_PATH).
 64
 65        Notes
 66        -----
 67        The save directory is created automatically if it does not exist.
 68        """
 69
 70        super().__init__()
 71        self.save_dir = save_dir
 72        os.makedirs(self.save_dir, exist_ok=True)
 73        self.train_losses = []
 74        self.train_accs = []
 75        self.val_losses = []
 76        self.val_accs = []
 77
 78    # ---------- Train loss per epoch ----------
 79    def on_train_epoch_end(self, trainer, pl_module):
 80        """
 81        Called at the end of each training epoch to collect training loss.
 82
 83        Parameters
 84        ----------
 85        trainer : pytorch_lightning.Trainer
 86            The PyTorch Lightning trainer instance.
 87        pl_module : pytorch_lightning.LightningModule
 88            The LightningModule being trained.
 89
 90        Notes
 91        -----
 92        Extracts the 'train_loss' metric from trainer.callback_metrics and
 93        appends it to the train_losses list.
 94        """
 95
 96        metrics = trainer.callback_metrics
 97        if "train_loss" in metrics:
 98            self.train_losses.append(metrics["train_loss"].item())
 99        if "train_acc" in metrics:
100            self.train_accs.append(metrics["train_acc"].item())
101
102    # ---------- Val loss and acc per epoch ----------
103    def on_validation_epoch_end(self, trainer, pl_module):
104        """
105        Called at the end of each validation epoch to collect validation
106        metrics.
107
108        Parameters
109        ----------
110        trainer : pytorch_lightning.Trainer
111            The PyTorch Lightning trainer instance.
112        pl_module : pytorch_lightning.LightningModule
113            The LightningModule being validated.
114
115        Notes
116        -----
117        Extracts 'val_loss' and 'val_acc' metrics from trainer.callback_metrics
118        and appends them to their respective lists.
119        """
120
121        metrics = trainer.callback_metrics
122        if "val_loss" in metrics:
123            self.val_losses.append(metrics["val_loss"].item())
124        if "val_acc" in metrics:
125            self.val_accs.append(metrics["val_acc"].item())
126
127    def on_train_end(self, trainer, pl_module):
128        """
129        Called at the end of training to generate and save plots and metrics.
130
131        Parameters
132        ----------
133        trainer : pytorch_lightning.Trainer
134            The PyTorch Lightning trainer instance.
135        pl_module : pytorch_lightning.LightningModule
136            The trained LightningModule.
137
138        Notes
139        -----
140        This method performs three main tasks:
141        1. Generates and saves a loss curve plot (loss_curve.png) showing
142           training loss and validation loss over epochs.
143        2. Generates and saves a validation accuracy plot (val_acc_curve.png)
144           if validation accuracy was tracked.
145        3. Saves all raw metric data to a JSON file (metrics.json) for later
146           analysis or reproduction.
147
148        All output files are saved to the directory specified in save_dir.
149        """
150
151        # ---------- Save curves as a PNG ----------
152        plt.figure()
153        plt.plot(self.train_losses, label="Train Loss")
154        if len(self.val_losses) > 0:
155            plt.plot(self.val_losses, label="Val Loss")
156        plt.legend()
157        plt.title("Loss Curves")
158        plt.xlabel("Epochs")
159        plt.ylabel("Loss")
160        plt.savefig(os.path.join(self.save_dir, "loss_curve.png"))
161        plt.close()
162
163        plt.figure()
164        plt.plot(self.train_accs, label="Train Accuracy")
165        if len(self.val_accs) > 0:
166            plt.plot(self.val_accs, label="Val Accuracy")
167        plt.legend()
168        plt.title("Accuracy Curves")
169        plt.xlabel("Epochs")
170        plt.ylabel("Accuracy")
171        plt.savefig(os.path.join(self.save_dir, "val_acc_curve.png"))
172        plt.close()
173
174        # ---------- Save raw data ----------
175        data = {
176            "train_losses": self.train_losses,
177            "val_losses": self.val_losses,
178            "train_accs": self.train_accs,
179            "val_accs": self.val_accs,
180        }
181
182        with open(os.path.join(self.save_dir, "metrics.json"), "w") as f:
183            json.dump(data, f)

PyTorch Lightning callback for tracking and plotting loss curves.

This callback monitors training loss, training accuracy, validation loss, and validation accuracy throughout the training process. At the end of training, it generates and saves visualization plots and raw metric data.

Attributes
  • save_dir (str): Directory path where plots and metrics will be saved.
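  • train_losses (list of float): Training loss values collected at the end of each training epoch.
  • train_accs (list of float): Training accuracy values collected at the end of each training epoch.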
  • val_losses (list of float): Validation loss values collected at the end of each validation epoch.
  • val_accs (list of float): Validation accuracy values collected at the end of each validation epoch.
Examples
>>> from pytorch_lightning import Trainer
>>> callback = LossCurveCallback(save_dir="./outputs/curves")
>>> trainer = Trainer(callbacks=[callback])
>>> trainer.fit(model, datamodule=data_module)
Notes

The callback creates three output files in the save_dir:

  • loss_curve.png: Plot of training and validation losses
  • val_acc_curve.png: Plot of training and validation accuracy
  • metrics.json: Raw metric data in JSON format
LossCurveCallback(save_dir='/home/alumno/Desktop/datos/SDOML/garbage_classifier/models/performance/loss_curves/')
55    def __init__(self, save_dir=cfg.LOSS_CURVES_PATH):
56        """
57        Initialize the LossCurveCallback.
58
59        Parameters
60        ----------
61        save_dir : str, optional
62            Directory path where plots and metrics will be saved
63            (default is cfg.LOSS_CURVES_PATH).
64
65        Notes
66        -----
67        The save directory is created automatically if it does not exist.
68        """
69
70        super().__init__()
71        self.save_dir = save_dir
72        os.makedirs(self.save_dir, exist_ok=True)
73        self.train_losses = []
74        self.train_accs = []
75        self.val_losses = []
76        self.val_accs = []

Initialize the LossCurveCallback.

Parameters
  • save_dir (str, optional): Directory path where plots and metrics will be saved (default is cfg.LOSS_CURVES_PATH).
Notes

The save directory is created automatically if it does not exist.

save_dir
train_losses
train_accs
val_losses
val_accs
def on_train_epoch_end(self, trainer, pl_module):
 79    def on_train_epoch_end(self, trainer, pl_module):
 80        """
 81        Called at the end of each training epoch to collect training loss.
 82
 83        Parameters
 84        ----------
 85        trainer : pytorch_lightning.Trainer
 86            The PyTorch Lightning trainer instance.
 87        pl_module : pytorch_lightning.LightningModule
 88            The LightningModule being trained.
 89
 90        Notes
 91        -----
 92        Extracts the 'train_loss' metric from trainer.callback_metrics and
 93        appends it to the train_losses list.
 94        """
 95
 96        metrics = trainer.callback_metrics
 97        if "train_loss" in metrics:
 98            self.train_losses.append(metrics["train_loss"].item())
 99        if "train_acc" in metrics:
100            self.train_accs.append(metrics["train_acc"].item())

Called at the end of each training epoch to collect training loss and training accuracy.

Parameters
  • trainer (pytorch_lightning.Trainer): The PyTorch Lightning trainer instance.
  • pl_module (pytorch_lightning.LightningModule): The LightningModule being trained.
Notes

Extracts the 'train_loss' and 'train_acc' metrics from trainer.callback_metrics and appends them to the train_losses and train_accs lists, respectively.

def on_validation_epoch_end(self, trainer, pl_module):
103    def on_validation_epoch_end(self, trainer, pl_module):
104        """
105        Called at the end of each validation epoch to collect validation
106        metrics.
107
108        Parameters
109        ----------
110        trainer : pytorch_lightning.Trainer
111            The PyTorch Lightning trainer instance.
112        pl_module : pytorch_lightning.LightningModule
113            The LightningModule being validated.
114
115        Notes
116        -----
117        Extracts 'val_loss' and 'val_acc' metrics from trainer.callback_metrics
118        and appends them to their respective lists.
119        """
120
121        metrics = trainer.callback_metrics
122        if "val_loss" in metrics:
123            self.val_losses.append(metrics["val_loss"].item())
124        if "val_acc" in metrics:
125            self.val_accs.append(metrics["val_acc"].item())

Called at the end of each validation epoch to collect validation metrics.

Parameters
  • trainer (pytorch_lightning.Trainer): The PyTorch Lightning trainer instance.
  • pl_module (pytorch_lightning.LightningModule): The LightningModule being validated.
Notes

Extracts 'val_loss' and 'val_acc' metrics from trainer.callback_metrics and appends them to their respective lists.

def on_train_end(self, trainer, pl_module):
127    def on_train_end(self, trainer, pl_module):
128        """
129        Called at the end of training to generate and save plots and metrics.
130
131        Parameters
132        ----------
133        trainer : pytorch_lightning.Trainer
134            The PyTorch Lightning trainer instance.
135        pl_module : pytorch_lightning.LightningModule
136            The trained LightningModule.
137
138        Notes
139        -----
140        This method performs three main tasks:
141        1. Generates and saves a loss curve plot (loss_curve.png) showing
142           training loss and validation loss over epochs.
143        2. Generates and saves a validation accuracy plot (val_acc_curve.png)
144           if validation accuracy was tracked.
145        3. Saves all raw metric data to a JSON file (metrics.json) for later
146           analysis or reproduction.
147
148        All output files are saved to the directory specified in save_dir.
149        """
150
151        # ---------- Save curves as a PNG ----------
152        plt.figure()
153        plt.plot(self.train_losses, label="Train Loss")
154        if len(self.val_losses) > 0:
155            plt.plot(self.val_losses, label="Val Loss")
156        plt.legend()
157        plt.title("Loss Curves")
158        plt.xlabel("Epochs")
159        plt.ylabel("Loss")
160        plt.savefig(os.path.join(self.save_dir, "loss_curve.png"))
161        plt.close()
162
163        plt.figure()
164        plt.plot(self.train_accs, label="Train Accuracy")
165        if len(self.val_accs) > 0:
166            plt.plot(self.val_accs, label="Val Accuracy")
167        plt.legend()
168        plt.title("Accuracy Curves")
169        plt.xlabel("Epochs")
170        plt.ylabel("Accuracy")
171        plt.savefig(os.path.join(self.save_dir, "val_acc_curve.png"))
172        plt.close()
173
174        # ---------- Save raw data ----------
175        data = {
176            "train_losses": self.train_losses,
177            "val_losses": self.val_losses,
178            "train_accs": self.train_accs,
179            "val_accs": self.val_accs,
180        }
181
182        with open(os.path.join(self.save_dir, "metrics.json"), "w") as f:
183            json.dump(data, f)

Called at the end of training to generate and save plots and metrics.

Parameters
  • trainer (pytorch_lightning.Trainer): The PyTorch Lightning trainer instance.
  • pl_module (pytorch_lightning.LightningModule): The trained LightningModule.
Notes

This method performs three main tasks:

  1. Generates and saves a loss curve plot (loss_curve.png) showing training loss and validation loss over epochs.
  2. Generates and saves an accuracy plot (val_acc_curve.png) showing training accuracy and, if it was tracked, validation accuracy.
  3. Saves all raw metric data to a JSON file (metrics.json) for later analysis or reproduction.

All output files are saved to the directory specified in save_dir.