utils.custom_classes.LossCurveCallback
PyTorch Lightning Callback for Loss and Accuracy Curve Visualization.
This module provides a custom callback that tracks and visualizes training and validation metrics during model training, saving plots and raw data to disk.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
PyTorch Lightning Callback for Loss and Accuracy Curve Visualization.

This module provides a custom callback that tracks and visualizes training
and validation metrics during model training, saving plots and raw data to
disk.
"""
__docformat__ = "numpy"

import os
import json

import matplotlib.pyplot as plt
from pytorch_lightning.callbacks import Callback

from utils import config as cfg


class LossCurveCallback(Callback):
    """
    PyTorch Lightning callback for tracking and plotting loss curves.

    This callback monitors training loss, validation loss, and validation
    accuracy throughout the training process. At the end of training, it
    generates and saves visualization plots and raw metric data.

    Attributes
    ----------
    save_dir : str
        Directory path where plots and metrics will be saved.
    train_losses : list of float
        Training loss values collected at the end of each training epoch.
    val_losses : list of float
        Validation loss values collected at the end of each validation epoch.
    val_accs : list of float
        Validation accuracy values collected at the end of each validation
        epoch.

    Examples
    --------
    >>> from pytorch_lightning import Trainer
    >>> callback = LossCurveCallback(save_dir="./outputs/curves")
    >>> trainer = Trainer(callbacks=[callback])
    >>> trainer.fit(model, datamodule=data_module)

    Notes
    -----
    The callback writes up to three output files to save_dir:
        - loss_curve.png: Plot of training and validation losses
        - val_acc_curve.png: Plot of validation accuracy (only when at least
          one 'val_acc' value was recorded)
        - metrics.json: Raw metric data in JSON format

    Metrics are read from ``trainer.callback_metrics`` under the keys
    'train_loss', 'val_loss' and 'val_acc'; a key that is absent is simply
    not recorded.
    """

    def __init__(self, save_dir=cfg.LOSS_CURVES_PATH):
        """
        Initialize the LossCurveCallback.

        Parameters
        ----------
        save_dir : str, optional
            Directory path where plots and metrics will be saved
            (default is cfg.LOSS_CURVES_PATH).

        Notes
        -----
        The save directory is created automatically if it does not exist.
        """

        super().__init__()
        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)
        self.train_losses = []
        self.val_losses = []
        self.val_accs = []

    # ---------- Train loss per epoch ----------
    def on_train_epoch_end(self, trainer, pl_module):
        """
        Called at the end of each training epoch to collect training loss.

        Parameters
        ----------
        trainer : pytorch_lightning.Trainer
            The PyTorch Lightning trainer instance.
        pl_module : pytorch_lightning.LightningModule
            The LightningModule being trained (unused).

        Notes
        -----
        Appends the 'train_loss' entry of trainer.callback_metrics to
        train_losses; does nothing when that key is absent.
        """

        metrics = trainer.callback_metrics
        if "train_loss" in metrics:
            self.train_losses.append(metrics["train_loss"].item())

    # ---------- Val loss and acc per epoch ----------
    def on_validation_epoch_end(self, trainer, pl_module):
        """
        Called at the end of each validation epoch to collect validation
        metrics.

        Parameters
        ----------
        trainer : pytorch_lightning.Trainer
            The PyTorch Lightning trainer instance.
        pl_module : pytorch_lightning.LightningModule
            The LightningModule being validated (unused).

        Notes
        -----
        Appends the 'val_loss' and 'val_acc' entries of
        trainer.callback_metrics to their respective lists when present.
        The validation pass of Lightning's pre-training sanity check is
        skipped so the validation curves stay aligned with train_losses.
        """

        # The sanity check validates on a few batches before any training;
        # recording it would add a spurious leading point to the curves.
        if getattr(trainer, "sanity_checking", False):
            return

        metrics = trainer.callback_metrics
        if "val_loss" in metrics:
            self.val_losses.append(metrics["val_loss"].item())
        if "val_acc" in metrics:
            self.val_accs.append(metrics["val_acc"].item())

    def on_train_end(self, trainer, pl_module):
        """
        Called at the end of training to generate and save plots and metrics.

        Parameters
        ----------
        trainer : pytorch_lightning.Trainer
            The PyTorch Lightning trainer instance.
        pl_module : pytorch_lightning.LightningModule
            The trained LightningModule (unused).

        Notes
        -----
        This method performs three main tasks:
            1. Saves a loss curve plot (loss_curve.png) with the training
               loss and, if any was recorded, the validation loss.
            2. Saves a validation accuracy plot (val_acc_curve.png) if
               validation accuracy was tracked.
            3. Saves all raw metric data to a JSON file (metrics.json).

        All output files are saved to the directory specified in save_dir.
        In distributed runs only the global-zero rank writes, so processes
        do not overwrite each other's files.
        """

        # A trainer object without the attribute is treated as rank zero.
        if not getattr(trainer, "is_global_zero", True):
            return

        # The directory was created in __init__, but it may have been removed
        # since then; re-creating it is a no-op when it still exists.
        os.makedirs(self.save_dir, exist_ok=True)

        # ---------- Save curves as a PNG ----------
        loss_series = [("Train Loss", self.train_losses)]
        if self.val_losses:
            loss_series.append(("Val Loss", self.val_losses))
        self._save_plot(
            loss_series, "Loss Curves", "Steps / Epochs", "Loss",
            "loss_curve.png",
        )

        if self.val_accs:
            self._save_plot(
                [("Val Accuracy", self.val_accs)], "Validation Accuracy",
                "Epochs", "Accuracy", "val_acc_curve.png",
            )

        # ---------- Save raw data ----------
        data = {
            "train_losses": self.train_losses,
            "val_losses": self.val_losses,
            "val_accs": self.val_accs,
        }

        metrics_path = os.path.join(self.save_dir, "metrics.json")
        with open(metrics_path, "w", encoding="utf-8") as f:
            json.dump(data, f)

    def _save_plot(self, series, title, xlabel, ylabel, filename):
        """Draw each (label, values) pair as a line and save it to save_dir."""
        fig, ax = plt.subplots()
        try:
            for label, values in series:
                ax.plot(values, label=label)
            ax.legend()
            ax.set_title(title)
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
            fig.savefig(os.path.join(self.save_dir, filename))
        finally:
            # Close even when saving fails so the figure is not left
            # registered with pyplot.
            plt.close(fig)
class LossCurveCallback(Callback):
    """
    PyTorch Lightning callback for tracking and plotting loss curves.

    This callback monitors training loss, validation loss, and validation
    accuracy throughout the training process. At the end of training, it
    generates and saves visualization plots and raw metric data.

    Attributes
    ----------
    save_dir : str
        Directory path where plots and metrics will be saved.
    train_losses : list of float
        Training loss values collected at the end of each training epoch.
    val_losses : list of float
        Validation loss values collected at the end of each validation epoch.
    val_accs : list of float
        Validation accuracy values collected at the end of each validation
        epoch.

    Examples
    --------
    >>> from pytorch_lightning import Trainer
    >>> callback = LossCurveCallback(save_dir="./outputs/curves")
    >>> trainer = Trainer(callbacks=[callback])
    >>> trainer.fit(model, datamodule=data_module)

    Notes
    -----
    The callback writes up to three output files to save_dir:
        - loss_curve.png: Plot of training and validation losses
        - val_acc_curve.png: Plot of validation accuracy (only when at least
          one 'val_acc' value was recorded)
        - metrics.json: Raw metric data in JSON format

    Metrics are read from ``trainer.callback_metrics`` under the keys
    'train_loss', 'val_loss' and 'val_acc'; a key that is absent is simply
    not recorded.
    """

    def __init__(self, save_dir=cfg.LOSS_CURVES_PATH):
        """
        Initialize the LossCurveCallback.

        Parameters
        ----------
        save_dir : str, optional
            Directory path where plots and metrics will be saved
            (default is cfg.LOSS_CURVES_PATH).

        Notes
        -----
        The save directory is created automatically if it does not exist.
        """

        super().__init__()
        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)
        self.train_losses = []
        self.val_losses = []
        self.val_accs = []

    # ---------- Train loss per epoch ----------
    def on_train_epoch_end(self, trainer, pl_module):
        """
        Called at the end of each training epoch to collect training loss.

        Parameters
        ----------
        trainer : pytorch_lightning.Trainer
            The PyTorch Lightning trainer instance.
        pl_module : pytorch_lightning.LightningModule
            The LightningModule being trained (unused).

        Notes
        -----
        Appends the 'train_loss' entry of trainer.callback_metrics to
        train_losses; does nothing when that key is absent.
        """

        metrics = trainer.callback_metrics
        if "train_loss" in metrics:
            self.train_losses.append(metrics["train_loss"].item())

    # ---------- Val loss and acc per epoch ----------
    def on_validation_epoch_end(self, trainer, pl_module):
        """
        Called at the end of each validation epoch to collect validation
        metrics.

        Parameters
        ----------
        trainer : pytorch_lightning.Trainer
            The PyTorch Lightning trainer instance.
        pl_module : pytorch_lightning.LightningModule
            The LightningModule being validated (unused).

        Notes
        -----
        Appends the 'val_loss' and 'val_acc' entries of
        trainer.callback_metrics to their respective lists when present.
        The validation pass of Lightning's pre-training sanity check is
        skipped so the validation curves stay aligned with train_losses.
        """

        # The sanity check validates on a few batches before any training;
        # recording it would add a spurious leading point to the curves.
        if getattr(trainer, "sanity_checking", False):
            return

        metrics = trainer.callback_metrics
        if "val_loss" in metrics:
            self.val_losses.append(metrics["val_loss"].item())
        if "val_acc" in metrics:
            self.val_accs.append(metrics["val_acc"].item())

    def on_train_end(self, trainer, pl_module):
        """
        Called at the end of training to generate and save plots and metrics.

        Parameters
        ----------
        trainer : pytorch_lightning.Trainer
            The PyTorch Lightning trainer instance.
        pl_module : pytorch_lightning.LightningModule
            The trained LightningModule (unused).

        Notes
        -----
        This method performs three main tasks:
            1. Saves a loss curve plot (loss_curve.png) with the training
               loss and, if any was recorded, the validation loss.
            2. Saves a validation accuracy plot (val_acc_curve.png) if
               validation accuracy was tracked.
            3. Saves all raw metric data to a JSON file (metrics.json).

        All output files are saved to the directory specified in save_dir.
        In distributed runs only the global-zero rank writes, so processes
        do not overwrite each other's files.
        """

        # A trainer object without the attribute is treated as rank zero.
        if not getattr(trainer, "is_global_zero", True):
            return

        # The directory was created in __init__, but it may have been removed
        # since then; re-creating it is a no-op when it still exists.
        os.makedirs(self.save_dir, exist_ok=True)

        # ---------- Save curves as a PNG ----------
        loss_series = [("Train Loss", self.train_losses)]
        if self.val_losses:
            loss_series.append(("Val Loss", self.val_losses))
        self._save_plot(
            loss_series, "Loss Curves", "Steps / Epochs", "Loss",
            "loss_curve.png",
        )

        if self.val_accs:
            self._save_plot(
                [("Val Accuracy", self.val_accs)], "Validation Accuracy",
                "Epochs", "Accuracy", "val_acc_curve.png",
            )

        # ---------- Save raw data ----------
        data = {
            "train_losses": self.train_losses,
            "val_losses": self.val_losses,
            "val_accs": self.val_accs,
        }

        metrics_path = os.path.join(self.save_dir, "metrics.json")
        with open(metrics_path, "w", encoding="utf-8") as f:
            json.dump(data, f)

    def _save_plot(self, series, title, xlabel, ylabel, filename):
        """Draw each (label, values) pair as a line and save it to save_dir."""
        fig, ax = plt.subplots()
        try:
            for label, values in series:
                ax.plot(values, label=label)
            ax.legend()
            ax.set_title(title)
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
            fig.savefig(os.path.join(self.save_dir, filename))
        finally:
            # Close even when saving fails so the figure is not left
            # registered with pyplot.
            plt.close(fig)
PyTorch Lightning callback for tracking and plotting loss curves.
This callback monitors training loss, validation loss, and validation accuracy throughout the training process. At the end of training, it generates and saves visualization plots and raw metric data.
Attributes
- save_dir (str): Directory path where plots and metrics will be saved.
- train_losses (list of float): Training loss values collected at the end of each training epoch.
- val_losses (list of float): Validation loss values collected at the end of each validation epoch.
- val_accs (list of float): Validation accuracy values collected at the end of each validation epoch.
Examples
>>> from pytorch_lightning import Trainer
>>> callback = LossCurveCallback(save_dir="./outputs/curves")
>>> trainer = Trainer(callbacks=[callback])
>>> trainer.fit(model, datamodule=data_module)
Notes
The callback writes up to three output files to save_dir when training ends:
- loss_curve.png: Plot of training and validation losses
- val_acc_curve.png: Plot of validation accuracy (written only if validation accuracy was recorded)
- metrics.json: Raw metric data in JSON format
def __init__(self, save_dir=cfg.LOSS_CURVES_PATH):
    """
    Set up the output directory and the empty metric histories.

    Parameters
    ----------
    save_dir : str, optional
        Directory that will receive the plots and the metrics file
        (default is cfg.LOSS_CURVES_PATH).

    Notes
    -----
    The directory is created here if it does not already exist.
    """

    super().__init__()

    # Output location; creating it is a no-op when it already exists.
    self.save_dir = save_dir
    os.makedirs(save_dir, exist_ok=True)

    # One independent history list per tracked metric.
    self.train_losses, self.val_losses, self.val_accs = [], [], []
Initialize the LossCurveCallback.
Parameters
- save_dir (str, optional): Directory path where plots and metrics will be saved (default is cfg.LOSS_CURVES_PATH).
Notes
The save directory is created automatically if it does not exist.
def on_train_epoch_end(self, trainer, pl_module):
    """
    Record the training loss once a training epoch has finished.

    Parameters
    ----------
    trainer : pytorch_lightning.Trainer
        Trainer whose callback_metrics mapping is read.
    pl_module : pytorch_lightning.LightningModule
        The LightningModule being trained (unused).

    Notes
    -----
    The 'train_loss' entry of trainer.callback_metrics is converted with
    ``.item()`` and appended to train_losses. Nothing is recorded when the
    key is missing.
    """

    try:
        epoch_loss = trainer.callback_metrics["train_loss"]
    except KeyError:
        # No 'train_loss' entry is available; leave the history untouched.
        return
    self.train_losses.append(epoch_loss.item())
Called at the end of each training epoch to collect training loss.
Parameters
- trainer (pytorch_lightning.Trainer): The PyTorch Lightning trainer instance.
- pl_module (pytorch_lightning.LightningModule): The LightningModule being trained.
Notes
Extracts the 'train_loss' metric from trainer.callback_metrics and appends it to the train_losses list.
def on_validation_epoch_end(self, trainer, pl_module):
    """
    Called at the end of each validation epoch to collect validation
    metrics.

    Parameters
    ----------
    trainer : pytorch_lightning.Trainer
        The PyTorch Lightning trainer instance.
    pl_module : pytorch_lightning.LightningModule
        The LightningModule being validated (unused).

    Notes
    -----
    Appends the 'val_loss' and 'val_acc' entries of trainer.callback_metrics
    to val_losses and val_accs respectively; a missing key is not recorded.
    The validation pass of Lightning's pre-training sanity check is skipped,
    so the validation curves do not gain a spurious leading point and stay
    aligned with the per-epoch training losses.
    """

    # A trainer object without the attribute is treated as not sanity checking.
    if getattr(trainer, "sanity_checking", False):
        return

    metrics = trainer.callback_metrics
    if "val_loss" in metrics:
        self.val_losses.append(metrics["val_loss"].item())
    if "val_acc" in metrics:
        self.val_accs.append(metrics["val_acc"].item())
Called at the end of each validation epoch to collect validation metrics.
Parameters
- trainer (pytorch_lightning.Trainer): The PyTorch Lightning trainer instance.
- pl_module (pytorch_lightning.LightningModule): The LightningModule being validated.
Notes
Extracts 'val_loss' and 'val_acc' metrics from trainer.callback_metrics and appends them to their respective lists.
def on_train_end(self, trainer, pl_module):
    """
    Called at the end of training to generate and save plots and metrics.

    Parameters
    ----------
    trainer : pytorch_lightning.Trainer
        The PyTorch Lightning trainer instance.
    pl_module : pytorch_lightning.LightningModule
        The trained LightningModule (unused).

    Notes
    -----
    This method performs three main tasks:
        1. Saves a loss curve plot (loss_curve.png) with the training loss
           and, if any was recorded, the validation loss.
        2. Saves a validation accuracy plot (val_acc_curve.png) if
           validation accuracy was tracked.
        3. Saves all raw metric data to a JSON file (metrics.json) for later
           analysis or reproduction.

    All output files are saved to the directory specified in save_dir.
    In distributed runs only the global-zero rank writes, so processes do
    not overwrite each other's files.
    """

    # A trainer object without the attribute is treated as rank zero.
    if not getattr(trainer, "is_global_zero", True):
        return

    out_dir = self.save_dir
    # The directory may have been removed since construction; re-creating
    # it is a no-op when it still exists.
    os.makedirs(out_dir, exist_ok=True)

    def _save_plot(series, title, xlabel, ylabel, filename):
        """Draw each (label, values) pair as a line and save it to out_dir."""
        fig, ax = plt.subplots()
        try:
            for label, values in series:
                ax.plot(values, label=label)
            ax.legend()
            ax.set_title(title)
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
            fig.savefig(os.path.join(out_dir, filename))
        finally:
            # Close even when saving fails so the figure is not left
            # registered with pyplot.
            plt.close(fig)

    # ---------- Save curves as a PNG ----------
    loss_series = [("Train Loss", self.train_losses)]
    if self.val_losses:
        loss_series.append(("Val Loss", self.val_losses))
    _save_plot(
        loss_series, "Loss Curves", "Steps / Epochs", "Loss", "loss_curve.png"
    )

    if self.val_accs:
        _save_plot(
            [("Val Accuracy", self.val_accs)], "Validation Accuracy",
            "Epochs", "Accuracy", "val_acc_curve.png",
        )

    # ---------- Save raw data ----------
    data = {
        "train_losses": self.train_losses,
        "val_losses": self.val_losses,
        "val_accs": self.val_accs,
    }

    with open(os.path.join(out_dir, "metrics.json"), "w", encoding="utf-8") as f:
        json.dump(data, f)
Called at the end of training to generate and save plots and metrics.
Parameters
- trainer (pytorch_lightning.Trainer): The PyTorch Lightning trainer instance.
- pl_module (pytorch_lightning.LightningModule): The trained LightningModule.
Notes
This method performs three main tasks:
- Generates and saves a loss curve plot (loss_curve.png) showing training loss and validation loss over epochs.
- Generates and saves a validation accuracy plot (val_acc_curve.png) if validation accuracy was tracked.
- Saves all raw metric data to a JSON file (metrics.json) for later analysis or reproduction.
All output files are saved to the directory specified in save_dir.