utils.custom_classes.LossCurveCallback

PyTorch Lightning Callback for Loss and Accuracy Curve Visualization.

This module provides a custom callback that tracks and visualizes training and validation metrics during model training, saving plots and raw data to disk.

  1#!/usr/bin/env python3
  2# -*- coding: utf-8 -*-
  3"""
  4PyTorch Lightning Callback for Loss and Accuracy Curve Visualization.
  5
  6This module provides a custom callback that tracks and visualizes training
  7and validation metrics during model training, saving plots and raw data to
  8disk.
  9"""
 10__docformat__ = "numpy"
 11
 12import os
 13import json
 14import matplotlib.pyplot as plt
 15from pytorch_lightning.callbacks import Callback
 16from utils import config as cfg
 17
 18
 19class LossCurveCallback(Callback):
 20    """
 21    PyTorch Lightning callback for tracking and plotting loss curves.
 22
 23    This callback monitors training loss, validation loss, and validation
 24    accuracy throughout the training process. At the end of training, it
 25    generates and saves visualization plots and raw metric data.
 26
 27    Attributes
 28    ----------
 29    save_dir : str
 30        Directory path where plots and metrics will be saved.
 31    train_losses : list of float
 32        Training loss values collected at the end of each training epoch.
 33    val_losses : list of float
 34        Validation loss values collected at the end of each validation epoch.
 35    val_accs : list of float
 36        Validation accuracy values collected at the end of each validation
 37        epoch.
 38
 39    Examples
 40    --------
 41    >>> from pytorch_lightning import Trainer
 42    >>> callback = LossCurveCallback(save_dir="./outputs/curves")
 43    >>> trainer = Trainer(callbacks=[callback])
 44    >>> trainer.fit(model, datamodule=data_module)
 45
 46    Notes
 47    -----
 48    The callback creates three output files in the save_dir:
 49    - loss_curve.png: Plot of training and validation losses
 50    - val_acc_curve.png: Plot of validation accuracy
 51    - metrics.json: Raw metric data in JSON format
 52    """
 53
 54    def __init__(self, save_dir=cfg.LOSS_CURVES_PATH):
 55        """
 56        Initialize the LossCurveCallback.
 57
 58        Parameters
 59        ----------
 60        save_dir : str, optional
 61            Directory path where plots and metrics will be saved
 62            (default is cfg.LOSS_CURVES_PATH).
 63
 64        Notes
 65        -----
 66        The save directory is created automatically if it does not exist.
 67        """
 68
 69        super().__init__()
 70        self.save_dir = save_dir
 71        os.makedirs(self.save_dir, exist_ok=True)
 72        self.train_losses = []
 73        self.val_losses = []
 74        self.val_accs = []
 75
 76    # ---------- Train loss per epoch ----------
 77    def on_train_epoch_end(self, trainer, pl_module):
 78        """
 79        Called at the end of each training epoch to collect training loss.
 80
 81        Parameters
 82        ----------
 83        trainer : pytorch_lightning.Trainer
 84            The PyTorch Lightning trainer instance.
 85        pl_module : pytorch_lightning.LightningModule
 86            The LightningModule being trained.
 87
 88        Notes
 89        -----
 90        Extracts the 'train_loss' metric from trainer.callback_metrics and
 91        appends it to the train_losses list.
 92        """
 93
 94        metrics = trainer.callback_metrics
 95        if "train_loss" in metrics:
 96            self.train_losses.append(metrics["train_loss"].item())
 97
 98    # ---------- Val loss and acc per epoch ----------
 99    def on_validation_epoch_end(self, trainer, pl_module):
100        """
101        Called at the end of each validation epoch to collect validation
102        metrics.
103
104        Parameters
105        ----------
106        trainer : pytorch_lightning.Trainer
107            The PyTorch Lightning trainer instance.
108        pl_module : pytorch_lightning.LightningModule
109            The LightningModule being validated.
110
111        Notes
112        -----
113        Extracts 'val_loss' and 'val_acc' metrics from trainer.callback_metrics
114        and appends them to their respective lists.
115        """
116
117        metrics = trainer.callback_metrics
118        if "val_loss" in metrics:
119            self.val_losses.append(metrics["val_loss"].item())
120        if "val_acc" in metrics:
121            self.val_accs.append(metrics["val_acc"].item())
122
123    def on_train_end(self, trainer, pl_module):
124        """
125        Called at the end of training to generate and save plots and metrics.
126
127        Parameters
128        ----------
129        trainer : pytorch_lightning.Trainer
130            The PyTorch Lightning trainer instance.
131        pl_module : pytorch_lightning.LightningModule
132            The trained LightningModule.
133
134        Notes
135        -----
136        This method performs three main tasks:
137        1. Generates and saves a loss curve plot (loss_curve.png) showing
138           training loss and validation loss over epochs.
139        2. Generates and saves a validation accuracy plot (val_acc_curve.png)
140           if validation accuracy was tracked.
141        3. Saves all raw metric data to a JSON file (metrics.json) for later
142           analysis or reproduction.
143
144        All output files are saved to the directory specified in save_dir.
145        """
146
147        # ---------- Save curves as a PNG ----------
148        plt.figure()
149        plt.plot(self.train_losses, label="Train Loss")
150        if len(self.val_losses) > 0:
151            plt.plot(self.val_losses, label="Val Loss")
152        plt.legend()
153        plt.title("Loss Curves")
154        plt.xlabel("Steps / Epochs")
155        plt.ylabel("Loss")
156        plt.savefig(os.path.join(self.save_dir, "loss_curve.png"))
157        plt.close()
158
159        if len(self.val_accs) > 0:
160            plt.figure()
161            plt.plot(self.val_accs, label="Val Accuracy")
162            plt.legend()
163            plt.title("Validation Accuracy")
164            plt.xlabel("Epochs")
165            plt.ylabel("Accuracy")
166            plt.savefig(os.path.join(self.save_dir, "val_acc_curve.png"))
167            plt.close()
168
169        # ---------- Save raw data ----------
170        data = {
171            "train_losses": self.train_losses,
172            "val_losses": self.val_losses,
173            "val_accs": self.val_accs,
174        }
175
176        with open(os.path.join(self.save_dir, "metrics.json"), "w") as f:
177            json.dump(data, f)
class LossCurveCallback(pytorch_lightning.callbacks.callback.Callback):
 20class LossCurveCallback(Callback):
 21    """
 22    PyTorch Lightning callback for tracking and plotting loss curves.
 23
 24    This callback monitors training loss, validation loss, and validation
 25    accuracy throughout the training process. At the end of training, it
 26    generates and saves visualization plots and raw metric data.
 27
 28    Attributes
 29    ----------
 30    save_dir : str
 31        Directory path where plots and metrics will be saved.
 32    train_losses : list of float
 33        Training loss values collected at the end of each training epoch.
 34    val_losses : list of float
 35        Validation loss values collected at the end of each validation epoch.
 36    val_accs : list of float
 37        Validation accuracy values collected at the end of each validation
 38        epoch.
 39
 40    Examples
 41    --------
 42    >>> from pytorch_lightning import Trainer
 43    >>> callback = LossCurveCallback(save_dir="./outputs/curves")
 44    >>> trainer = Trainer(callbacks=[callback])
 45    >>> trainer.fit(model, datamodule=data_module)
 46
 47    Notes
 48    -----
 49    The callback creates three output files in the save_dir:
 50    - loss_curve.png: Plot of training and validation losses
 51    - val_acc_curve.png: Plot of validation accuracy
 52    - metrics.json: Raw metric data in JSON format
 53    """
 54
 55    def __init__(self, save_dir=cfg.LOSS_CURVES_PATH):
 56        """
 57        Initialize the LossCurveCallback.
 58
 59        Parameters
 60        ----------
 61        save_dir : str, optional
 62            Directory path where plots and metrics will be saved
 63            (default is cfg.LOSS_CURVES_PATH).
 64
 65        Notes
 66        -----
 67        The save directory is created automatically if it does not exist.
 68        """
 69
 70        super().__init__()
 71        self.save_dir = save_dir
 72        os.makedirs(self.save_dir, exist_ok=True)
 73        self.train_losses = []
 74        self.val_losses = []
 75        self.val_accs = []
 76
 77    # ---------- Train loss per epoch ----------
 78    def on_train_epoch_end(self, trainer, pl_module):
 79        """
 80        Called at the end of each training epoch to collect training loss.
 81
 82        Parameters
 83        ----------
 84        trainer : pytorch_lightning.Trainer
 85            The PyTorch Lightning trainer instance.
 86        pl_module : pytorch_lightning.LightningModule
 87            The LightningModule being trained.
 88
 89        Notes
 90        -----
 91        Extracts the 'train_loss' metric from trainer.callback_metrics and
 92        appends it to the train_losses list.
 93        """
 94
 95        metrics = trainer.callback_metrics
 96        if "train_loss" in metrics:
 97            self.train_losses.append(metrics["train_loss"].item())
 98
 99    # ---------- Val loss and acc per epoch ----------
100    def on_validation_epoch_end(self, trainer, pl_module):
101        """
102        Called at the end of each validation epoch to collect validation
103        metrics.
104
105        Parameters
106        ----------
107        trainer : pytorch_lightning.Trainer
108            The PyTorch Lightning trainer instance.
109        pl_module : pytorch_lightning.LightningModule
110            The LightningModule being validated.
111
112        Notes
113        -----
114        Extracts 'val_loss' and 'val_acc' metrics from trainer.callback_metrics
115        and appends them to their respective lists.
116        """
117
118        metrics = trainer.callback_metrics
119        if "val_loss" in metrics:
120            self.val_losses.append(metrics["val_loss"].item())
121        if "val_acc" in metrics:
122            self.val_accs.append(metrics["val_acc"].item())
123
124    def on_train_end(self, trainer, pl_module):
125        """
126        Called at the end of training to generate and save plots and metrics.
127
128        Parameters
129        ----------
130        trainer : pytorch_lightning.Trainer
131            The PyTorch Lightning trainer instance.
132        pl_module : pytorch_lightning.LightningModule
133            The trained LightningModule.
134
135        Notes
136        -----
137        This method performs three main tasks:
138        1. Generates and saves a loss curve plot (loss_curve.png) showing
139           training loss and validation loss over epochs.
140        2. Generates and saves a validation accuracy plot (val_acc_curve.png)
141           if validation accuracy was tracked.
142        3. Saves all raw metric data to a JSON file (metrics.json) for later
143           analysis or reproduction.
144
145        All output files are saved to the directory specified in save_dir.
146        """
147
148        # ---------- Save curves as a PNG ----------
149        plt.figure()
150        plt.plot(self.train_losses, label="Train Loss")
151        if len(self.val_losses) > 0:
152            plt.plot(self.val_losses, label="Val Loss")
153        plt.legend()
154        plt.title("Loss Curves")
155        plt.xlabel("Steps / Epochs")
156        plt.ylabel("Loss")
157        plt.savefig(os.path.join(self.save_dir, "loss_curve.png"))
158        plt.close()
159
160        if len(self.val_accs) > 0:
161            plt.figure()
162            plt.plot(self.val_accs, label="Val Accuracy")
163            plt.legend()
164            plt.title("Validation Accuracy")
165            plt.xlabel("Epochs")
166            plt.ylabel("Accuracy")
167            plt.savefig(os.path.join(self.save_dir, "val_acc_curve.png"))
168            plt.close()
169
170        # ---------- Save raw data ----------
171        data = {
172            "train_losses": self.train_losses,
173            "val_losses": self.val_losses,
174            "val_accs": self.val_accs,
175        }
176
177        with open(os.path.join(self.save_dir, "metrics.json"), "w") as f:
178            json.dump(data, f)

PyTorch Lightning callback for tracking and plotting loss curves.

This callback monitors training loss, validation loss, and validation accuracy throughout the training process. At the end of training, it generates and saves visualization plots and raw metric data.

Attributes
  • save_dir (str): Directory path where plots and metrics will be saved.
  • train_losses (list of float): Training loss values collected at the end of each training epoch.
  • val_losses (list of float): Validation loss values collected at the end of each validation epoch.
  • val_accs (list of float): Validation accuracy values collected at the end of each validation epoch.
Examples
>>> from pytorch_lightning import Trainer
>>> callback = LossCurveCallback(save_dir="./outputs/curves")
>>> trainer = Trainer(callbacks=[callback])
>>> trainer.fit(model, datamodule=data_module)
Notes

The callback creates three output files in the save_dir:

  • loss_curve.png: Plot of training and validation losses
  • val_acc_curve.png: Plot of validation accuracy
  • metrics.json: Raw metric data in JSON format
LossCurveCallback(save_dir='../models/performance/loss_curves/')
55    def __init__(self, save_dir=cfg.LOSS_CURVES_PATH):
56        """
57        Initialize the LossCurveCallback.
58
59        Parameters
60        ----------
61        save_dir : str, optional
62            Directory path where plots and metrics will be saved
63            (default is cfg.LOSS_CURVES_PATH).
64
65        Notes
66        -----
67        The save directory is created automatically if it does not exist.
68        """
69
70        super().__init__()
71        self.save_dir = save_dir
72        os.makedirs(self.save_dir, exist_ok=True)
73        self.train_losses = []
74        self.val_losses = []
75        self.val_accs = []

Initialize the LossCurveCallback.

Parameters
  • save_dir (str, optional): Directory path where plots and metrics will be saved (default is cfg.LOSS_CURVES_PATH).
Notes

The save directory is created automatically if it does not exist.

def on_train_epoch_end(self, trainer, pl_module):
78    def on_train_epoch_end(self, trainer, pl_module):
79        """
80        Called at the end of each training epoch to collect training loss.
81
82        Parameters
83        ----------
84        trainer : pytorch_lightning.Trainer
85            The PyTorch Lightning trainer instance.
86        pl_module : pytorch_lightning.LightningModule
87            The LightningModule being trained.
88
89        Notes
90        -----
91        Extracts the 'train_loss' metric from trainer.callback_metrics and
92        appends it to the train_losses list.
93        """
94
95        metrics = trainer.callback_metrics
96        if "train_loss" in metrics:
97            self.train_losses.append(metrics["train_loss"].item())

Called at the end of each training epoch to collect training loss.

Parameters
  • trainer (pytorch_lightning.Trainer): The PyTorch Lightning trainer instance.
  • pl_module (pytorch_lightning.LightningModule): The LightningModule being trained.
Notes

Extracts the 'train_loss' metric from trainer.callback_metrics and appends it to the train_losses list.

def on_validation_epoch_end(self, trainer, pl_module):
100    def on_validation_epoch_end(self, trainer, pl_module):
101        """
102        Called at the end of each validation epoch to collect validation
103        metrics.
104
105        Parameters
106        ----------
107        trainer : pytorch_lightning.Trainer
108            The PyTorch Lightning trainer instance.
109        pl_module : pytorch_lightning.LightningModule
110            The LightningModule being validated.
111
112        Notes
113        -----
114        Extracts 'val_loss' and 'val_acc' metrics from trainer.callback_metrics
115        and appends them to their respective lists.
116        """
117
118        metrics = trainer.callback_metrics
119        if "val_loss" in metrics:
120            self.val_losses.append(metrics["val_loss"].item())
121        if "val_acc" in metrics:
122            self.val_accs.append(metrics["val_acc"].item())

Called at the end of each validation epoch to collect validation metrics.

Parameters
  • trainer (pytorch_lightning.Trainer): The PyTorch Lightning trainer instance.
  • pl_module (pytorch_lightning.LightningModule): The LightningModule being validated.
Notes

Extracts 'val_loss' and 'val_acc' metrics from trainer.callback_metrics and appends them to their respective lists.

def on_train_end(self, trainer, pl_module):
124    def on_train_end(self, trainer, pl_module):
125        """
126        Called at the end of training to generate and save plots and metrics.
127
128        Parameters
129        ----------
130        trainer : pytorch_lightning.Trainer
131            The PyTorch Lightning trainer instance.
132        pl_module : pytorch_lightning.LightningModule
133            The trained LightningModule.
134
135        Notes
136        -----
137        This method performs three main tasks:
138        1. Generates and saves a loss curve plot (loss_curve.png) showing
139           training loss and validation loss over epochs.
140        2. Generates and saves a validation accuracy plot (val_acc_curve.png)
141           if validation accuracy was tracked.
142        3. Saves all raw metric data to a JSON file (metrics.json) for later
143           analysis or reproduction.
144
145        All output files are saved to the directory specified in save_dir.
146        """
147
148        # ---------- Save curves as a PNG ----------
149        plt.figure()
150        plt.plot(self.train_losses, label="Train Loss")
151        if len(self.val_losses) > 0:
152            plt.plot(self.val_losses, label="Val Loss")
153        plt.legend()
154        plt.title("Loss Curves")
155        plt.xlabel("Steps / Epochs")
156        plt.ylabel("Loss")
157        plt.savefig(os.path.join(self.save_dir, "loss_curve.png"))
158        plt.close()
159
160        if len(self.val_accs) > 0:
161            plt.figure()
162            plt.plot(self.val_accs, label="Val Accuracy")
163            plt.legend()
164            plt.title("Validation Accuracy")
165            plt.xlabel("Epochs")
166            plt.ylabel("Accuracy")
167            plt.savefig(os.path.join(self.save_dir, "val_acc_curve.png"))
168            plt.close()
169
170        # ---------- Save raw data ----------
171        data = {
172            "train_losses": self.train_losses,
173            "val_losses": self.val_losses,
174            "val_accs": self.val_accs,
175        }
176
177        with open(os.path.join(self.save_dir, "metrics.json"), "w") as f:
178            json.dump(data, f)

Called at the end of training to generate and save plots and metrics.

Parameters
  • trainer (pytorch_lightning.Trainer): The PyTorch Lightning trainer instance.
  • pl_module (pytorch_lightning.LightningModule): The trained LightningModule.
Notes

This method performs three main tasks:

  1. Generates and saves a loss curve plot (loss_curve.png) showing training loss and validation loss over epochs.
  2. Generates and saves a validation accuracy plot (val_acc_curve.png) if validation accuracy was tracked.
  3. Saves all raw metric data to a JSON file (metrics.json) for later analysis or reproduction.

All output files are saved to the directory specified in save_dir.