source.utils.custom_classes
Custom PyTorch Lightning classes for Garbage Classification.
This module contains custom implementations of:
- GarbageClassifier: ResNet18-based classifier
- GarbageDataModule: Data loading and preprocessing
- LossCurveCallback: Training metrics visualization
1""" 2Custom PyTorch Lightning classes for Garbage Classification. 3 4This module contains custom implementations of: 5- GarbageClassifier: ResNet18-based classifier 6- GarbageDataModule: Data loading and preprocessing 7- LossCurveCallback: Training metrics visualization 8""" 9 10from source.utils.custom_classes.GarbageClassifier import GarbageClassifier 11from source.utils.custom_classes.GarbageDataModule import GarbageDataModule 12from source.utils.custom_classes.LossCurveCallback import LossCurveCallback 13 14__all__ = [ 15 "GarbageClassifier", 16 "GarbageDataModule", 17 "LossCurveCallback", 18]
```python
class GarbageClassifier(pl.LightningModule):
```
Pretrained (ImageNet) ResNet18 adapted to the Garbage Dataset classification problem. It considers six classes: cardboard, glass, metal, paper, plastic, and trash.
Attributes
- model (torchvision.models.resnet18): Pretrained ResNet18 model.
- loss_fn (torch.nn.CrossEntropyLoss): Cross-entropy loss function.
Examples
>>> model = GarbageClassifier(num_classes=6, lr=1e-3)
>>> trainer = pl.Trainer(max_epochs=10, accelerator="auto")
>>> trainer.fit(model, datamodule=data_module)
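Because the constructor calls save_hyperparameters(), a trained model can be restored with Lightning's standard checkpoint API. A hedged inference sketch follows; the checkpoint and image paths are hypothetical, and the class-name order assumes ImageFolder's alphabetical folder sorting.

```python
# Sketch: restore a trained model and classify a single image.
import torch
from PIL import Image
from torchvision import models

from source.utils.custom_classes import GarbageClassifier

model = GarbageClassifier.load_from_checkpoint("checkpoints/last.ckpt")  # hypothetical path
model.eval()

transform = models.ResNet18_Weights.IMAGENET1K_V1.transforms()
classes = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]

img = transform(Image.open("bottle.jpg").convert("RGB")).unsqueeze(0)  # hypothetical file
with torch.no_grad():
    pred = model(img).argmax(dim=1).item()
print(classes[pred])
```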
```python
def __init__(self, num_classes, lr=1e-3):
    """
    Initialize the GarbageClassifier model.

    Parameters
    ----------
    num_classes : int
        Number of output classes for classification.
    lr : float, optional
        Learning rate for the optimizer (default is 1e-3).
    """
    super().__init__()
    self.save_hyperparameters()
    self.model = models.resnet18(
        weights=models.ResNet18_Weights.IMAGENET1K_V1
    )

    # Freeze feature layers
    for param in self.model.parameters():
        param.requires_grad = False

    # New classifier layer
    self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
    self.loss_fn = nn.CrossEntropyLoss()
```
Initialize the GarbageClassifier model.
Parameters
- num_classes (int): Number of output classes for classification.
- lr (float, optional): Learning rate for the optimizer (default is 1e-3).
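Since the constructor freezes every backbone parameter and then replaces fc, only the new head is trainable. A quick sketch to confirm this (not part of the module):

```python
# Sketch: confirm that only the new classifier head is trainable.
from source.utils.custom_classes import GarbageClassifier

model = GarbageClassifier(num_classes=6)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable:,} / total: {total:,}")
# For ResNet18 with 6 classes, the new head has 512 * 6 + 6 = 3,078
# parameters, a small fraction of the ~11.7M total.
```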
```python
def forward(self, x):
    """
    Forward pass through the model.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of images with shape (batch_size, channels, height,
        width).

    Returns
    -------
    torch.Tensor
        Output logits with shape (batch_size, num_classes).
    """
    return self.model(x)
```
Forward pass through the model.
Parameters
- x (torch.Tensor): Input tensor of images with shape (batch_size, channels, height, width).
Returns
- torch.Tensor: Output logits with shape (batch_size, num_classes).
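The model returns raw logits, so a softmax (or simply an argmax) is needed to obtain probabilities or class predictions. A small sketch on dummy data:

```python
# Sketch: turn raw logits into class probabilities and predictions.
import torch

from source.utils.custom_classes import GarbageClassifier

model = GarbageClassifier(num_classes=6).eval()
with torch.no_grad():
    logits = model(torch.randn(4, 3, 224, 224))  # dummy batch of 4 images
probs = torch.softmax(logits, dim=1)             # shape (4, num_classes)
preds = probs.argmax(dim=1)                      # shape (4,)
```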
```python
def training_step(self, batch, batch_idx):
    """
    Model parameters are updated according to the classification error
    of a subset of train images.

    Parameters
    ----------
    batch : Tuple[torch.Tensor, torch.Tensor]
        Subset (batch) of images coming from the train dataloader.
        Contains input images and corresponding labels.
    batch_idx : int
        Identifier of the batch within the current epoch.

    Returns
    -------
    torch.Tensor
        Classification error (cross entropy loss) of trained image batch.
    """
    xb, yb = batch
    out = self(xb)
    loss = self.loss_fn(out, yb)

    # Accuracy
    preds = out.argmax(dim=1)
    acc = (preds == yb).float().mean()

    self.log(
        "train_loss",
        loss,
        on_step=False,
        on_epoch=True,
        prog_bar=True
    )
    self.log("train_acc", acc, on_step=False, on_epoch=True, prog_bar=True)

    return loss
```
Run one training step: model parameters are updated according to the classification error on a batch of training images.
Parameters
- batch (Tuple[torch.Tensor, torch.Tensor]): Subset (batch) of images coming from the train dataloader. Contains input images and corresponding labels.
- batch_idx (int): Identifier of the batch within the current epoch.
Returns
- torch.Tensor: Classification error (cross-entropy loss) of the training batch.
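The step's loss and accuracy arithmetic can be reproduced outside Lightning (self.log requires an attached Trainer, so this sketch recomputes the same quantities directly on a dummy batch):

```python
# Sketch: reproduce the step's loss and accuracy on a dummy batch.
import torch
import torch.nn as nn

from source.utils.custom_classes import GarbageClassifier

model = GarbageClassifier(num_classes=6).eval()
xb = torch.randn(8, 3, 224, 224)   # dummy images
yb = torch.randint(0, 6, (8,))     # dummy labels

out = model(xb)
loss = nn.CrossEntropyLoss()(out, yb)
acc = (out.argmax(dim=1) == yb).float().mean()
```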
```python
def validation_step(self, batch, batch_idx):
    """
    Compute validation loss and accuracy for a batch of validation images.

    Parameters
    ----------
    batch : Tuple[torch.Tensor, torch.Tensor]
        Subset (batch) of images coming from the validation dataloader.
        Contains input images and corresponding labels.
    batch_idx : int
        Identifier of the batch within the current validation epoch.

    Returns
    -------
    torch.Tensor
        Validation accuracy for the current batch.
    """
    xb, yb = batch
    out = self(xb)
    loss = self.loss_fn(out, yb)
    preds = out.argmax(dim=1)
    acc = (preds == yb).float().mean()
    self.log(
        "val_loss",
        loss,
        on_step=False,
        on_epoch=True,
        prog_bar=False
    )
    self.log(
        "val_acc",
        acc,
        on_step=False,
        on_epoch=True,
        prog_bar=True
    )
    return acc
```
Compute validation loss and accuracy for a batch of validation images.
Parameters
- batch (Tuple[torch.Tensor, torch.Tensor]): Subset (batch) of images coming from the validation dataloader. Contains input images and corresponding labels.
- batch_idx (int): Identifier of the batch within the current validation epoch.
Returns
- torch.Tensor: Validation accuracy for the current batch.
```python
def configure_optimizers(self):
    """
    Configure the optimizer for training.

    Returns
    -------
    torch.optim.Adam
        Adam optimizer configured with model parameters and learning rate
        from hyperparameters.
    """
    return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
```
Configure the optimizer for training.
Returns
- torch.optim.Adam: Adam optimizer configured with model parameters and learning rate from hyperparameters.
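Because the backbone is frozen, handing Adam every parameter is functionally harmless (parameters that never receive gradients are skipped), but a common variant restricts the optimizer to the trainable parameters. A drop-in alternative sketch, not what this module does:

```python
# Alternative sketch: hand Adam only the unfrozen parameters.
def configure_optimizers(self):
    trainable = filter(lambda p: p.requires_grad, self.parameters())
    return torch.optim.Adam(trainable, lr=self.hparams.lr)
```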
```python
class GarbageDataModule(pl.LightningDataModule):
```
PyTorch Lightning DataModule for Garbage Classification Dataset.
This DataModule handles loading, splitting, and creating dataloaders for the garbage classification dataset. It performs a stratified 90/10 train/test split and applies ResNet18 ImageNet preprocessing transforms.
Attributes
- batch_size (int): Number of samples per batch for training.
- num_workers (int): Number of subprocesses to use for data loading.
- transform (torchvision.transforms.Compose): Image preprocessing transforms from ResNet18 ImageNet weights.
- train_dataset (torch.utils.data.Subset): Training dataset subset.
- test_dataset (torch.utils.data.Subset): Test/validation dataset subset.
- train_idx (numpy.ndarray): Indices of samples in the training set.
- test_idx (numpy.ndarray): Indices of samples in the test set.
- num_classes (int): Number of classes in the dataset.
Examples
>>> data_module = GarbageDataModule(batch_size=32, num_workers=4)
>>> data_module.setup()
>>> train_loader = data_module.train_dataloader()
>>> val_loader = data_module.val_dataloader()
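A quick way to verify the pipeline is to pull one batch and check its shapes; the ImageNet transforms produce 224x224 RGB tensors. A sketch continuing the example above:

```python
# Sketch: inspect one batch from the train_loader created above.
xb, yb = next(iter(train_loader))
print(xb.shape)  # torch.Size([32, 3, 224, 224]) after the ImageNet transforms
print(yb.shape)  # torch.Size([32])
```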
```python
def __init__(self, batch_size=32, num_workers=4):
    """
    Initialize the GarbageDataModule.

    Parameters
    ----------
    batch_size : int, optional
        Number of samples per batch for training (default is 32).
    num_workers : int, optional
        Number of subprocesses to use for data loading (default is 4).
    """
    super().__init__()
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.transform = models.ResNet18_Weights.IMAGENET1K_V1.transforms()
```
Initialize the GarbageDataModule.
Parameters
- batch_size (int, optional): Number of samples per batch for training (default is 32).
- num_workers (int, optional): Number of subprocesses to use for data loading (default is 4).
```python
def setup(self, stage=None):
    """
    Prepare the dataset by loading and splitting into train and test sets.

    Loads the full dataset from the configured path, performs a stratified
    90/10 train/test split to ensure balanced class distribution, and
    creates dataset subsets for training and validation.

    Parameters
    ----------
    stage : str, optional
        Current stage ('fit', 'validate', 'test', or 'predict').
        Not used in this implementation (default is None).

    Notes
    -----
    The dataset is split using a stratified approach with random_state=42
    for reproducibility. The split ratio is 90% training and 10% testing.
    """
    # Load full dataset
    full_dataset = datasets.ImageFolder(
        cfg.DATASET_PATH, transform=self.transform
    )
    # Read labels from ImageFolder's index instead of iterating the
    # dataset, which would load and transform every image just to
    # obtain its label.
    targets = full_dataset.targets
    self.num_classes = cfg.NUM_CLASSES

    # Stratified split 90/10
    train_idx, test_idx = train_test_split(
        np.arange(len(targets)), test_size=0.1,
        stratify=targets, random_state=42
    )

    self.train_dataset = Subset(full_dataset, train_idx)
    self.test_dataset = Subset(full_dataset, test_idx)
    self.train_idx = train_idx
    self.test_idx = test_idx
```
Prepare the dataset by loading and splitting into train and test sets.
Loads the full dataset from the configured path, performs a stratified 90/10 train/test split to ensure balanced class distribution, and creates dataset subsets for training and validation.
Parameters
- stage (str, optional): Current stage ('fit', 'validate', 'test', or 'predict'). Not used in this implementation (default is None).
Notes
The dataset is split using a stratified approach with random_state=42 for reproducibility. The split ratio is 90% training and 10% testing.
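The stored train_idx/test_idx make it easy to check that the split is indeed stratified. A sketch (Subset keeps the underlying ImageFolder in .dataset, whose .targets lists each sample's class index):

```python
# Sketch: verify per-class counts on both sides of the split.
import numpy as np

from source.utils.custom_classes import GarbageDataModule

data_module = GarbageDataModule()
data_module.setup()

targets = np.array(data_module.train_dataset.dataset.targets)
print(np.bincount(targets[data_module.train_idx]))  # train counts per class
print(np.bincount(targets[data_module.test_idx]))   # test counts per class
```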
```python
def train_dataloader(self):
    """
    Create and return the training dataloader.

    Returns
    -------
    torch.utils.data.DataLoader
        DataLoader for the training dataset with shuffling enabled.

    Notes
    -----
    The dataloader uses the configured batch_size and num_workers,
    and shuffles the data at each epoch.
    """
    return DataLoader(
        self.train_dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.num_workers,
    )
```
Create and return the training dataloader.
Returns
- torch.utils.data.DataLoader: DataLoader for the training dataset with shuffling enabled.
Notes
The dataloader uses the configured batch_size and num_workers, and shuffles the data at each epoch.
```python
def val_dataloader(self):
    """
    Create and return the validation dataloader.

    Returns
    -------
    torch.utils.data.DataLoader
        DataLoader for the validation/test dataset without shuffling.

    Notes
    -----
    The dataloader uses a fixed batch_size of 1000 for faster validation,
    with num_workers from configuration. Shuffling is disabled to ensure
    consistent validation metrics.
    """
    return DataLoader(
        self.test_dataset,
        batch_size=1000,
        shuffle=False,
        num_workers=self.num_workers,
    )
```
Create and return the validation dataloader.
Returns
- torch.utils.data.DataLoader: DataLoader for the validation/test dataset without shuffling.
Notes
The dataloader uses a fixed batch_size of 1000 for faster validation and the configured num_workers. Shuffling is disabled to ensure consistent validation metrics.
```python
class LossCurveCallback(Callback):
```
PyTorch Lightning callback for tracking and plotting loss curves.
This callback monitors training loss, validation loss, and validation accuracy throughout the training process. At the end of training, it generates and saves visualization plots and raw metric data.
Attributes
- save_dir (str): Directory path where plots and metrics will be saved.
- train_losses (list of float): Training loss values collected at the end of each training epoch.
- train_accs (list of float): Training accuracy values collected at the end of each training epoch.
- val_losses (list of float): Validation loss values collected at the end of each validation epoch.
- val_accs (list of float): Validation accuracy values collected at the end of each validation epoch.
Examples
>>> from pytorch_lightning import Trainer
>>> callback = LossCurveCallback(save_dir="./outputs/curves")
>>> trainer = Trainer(callbacks=[callback])
>>> trainer.fit(model, datamodule=data_module)
Notes
The callback creates three output files in the save_dir:
- loss_curve.png: Plot of training and validation losses
- val_acc_curve.png: Plot of training and validation accuracy
- metrics.json: Raw metric data in JSON format
```python
def __init__(self, save_dir=cfg.LOSS_CURVES_PATH):
    """
    Initialize the LossCurveCallback.

    Parameters
    ----------
    save_dir : str, optional
        Directory path where plots and metrics will be saved
        (default is cfg.LOSS_CURVES_PATH).

    Notes
    -----
    The save directory is created automatically if it does not exist.
    """
    super().__init__()
    self.save_dir = save_dir
    os.makedirs(self.save_dir, exist_ok=True)
    self.train_losses = []
    self.train_accs = []
    self.val_losses = []
    self.val_accs = []
```
Initialize the LossCurveCallback.
Parameters
- save_dir (str, optional): Directory path where plots and metrics will be saved (default is cfg.LOSS_CURVES_PATH).
Notes
The save directory is created automatically if it does not exist.
```python
def on_train_epoch_end(self, trainer, pl_module):
    """
    Called at the end of each training epoch to collect training loss
    and accuracy.

    Parameters
    ----------
    trainer : pytorch_lightning.Trainer
        The PyTorch Lightning trainer instance.
    pl_module : pytorch_lightning.LightningModule
        The LightningModule being trained.

    Notes
    -----
    Extracts the 'train_loss' and 'train_acc' metrics from
    trainer.callback_metrics and appends them to the train_losses and
    train_accs lists.
    """
    metrics = trainer.callback_metrics
    if "train_loss" in metrics:
        self.train_losses.append(metrics["train_loss"].item())
    if "train_acc" in metrics:
        self.train_accs.append(metrics["train_acc"].item())
```
Called at the end of each training epoch to collect the training loss and accuracy.
Parameters
- trainer (pytorch_lightning.Trainer): The PyTorch Lightning trainer instance.
- pl_module (pytorch_lightning.LightningModule): The LightningModule being trained.
Notes
Extracts the 'train_loss' and 'train_acc' metrics from trainer.callback_metrics and appends them to the train_losses and train_accs lists.
```python
def on_validation_epoch_end(self, trainer, pl_module):
    """
    Called at the end of each validation epoch to collect validation
    metrics.

    Parameters
    ----------
    trainer : pytorch_lightning.Trainer
        The PyTorch Lightning trainer instance.
    pl_module : pytorch_lightning.LightningModule
        The LightningModule being validated.

    Notes
    -----
    Extracts 'val_loss' and 'val_acc' metrics from trainer.callback_metrics
    and appends them to their respective lists.
    """
    metrics = trainer.callback_metrics
    if "val_loss" in metrics:
        self.val_losses.append(metrics["val_loss"].item())
    if "val_acc" in metrics:
        self.val_accs.append(metrics["val_acc"].item())
```
Called at the end of each validation epoch to collect validation metrics.
Parameters
- trainer (pytorch_lightning.Trainer): The PyTorch Lightning trainer instance.
- pl_module (pytorch_lightning.LightningModule): The LightningModule being validated.
Notes
Extracts 'val_loss' and 'val_acc' metrics from trainer.callback_metrics and appends them to their respective lists.
```python
def on_train_end(self, trainer, pl_module):
    """
    Called at the end of training to generate and save plots and metrics.

    Parameters
    ----------
    trainer : pytorch_lightning.Trainer
        The PyTorch Lightning trainer instance.
    pl_module : pytorch_lightning.LightningModule
        The trained LightningModule.

    Notes
    -----
    This method performs three main tasks:
    1. Generates and saves a loss curve plot (loss_curve.png) showing
       training loss and validation loss over epochs.
    2. Generates and saves an accuracy plot (val_acc_curve.png) showing
       training accuracy and, if tracked, validation accuracy.
    3. Saves all raw metric data to a JSON file (metrics.json) for later
       analysis or reproduction.

    All output files are saved to the directory specified in save_dir.
    """
    # ---------- Save curves as a PNG ----------
    plt.figure()
    plt.plot(self.train_losses, label="Train Loss")
    if len(self.val_losses) > 0:
        plt.plot(self.val_losses, label="Val Loss")
    plt.legend()
    plt.title("Loss Curves")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.savefig(os.path.join(self.save_dir, "loss_curve.png"))
    plt.close()

    plt.figure()
    plt.plot(self.train_accs, label="Train Accuracy")
    if len(self.val_accs) > 0:
        plt.plot(self.val_accs, label="Val Accuracy")
    plt.legend()
    plt.title("Accuracy Curves")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.savefig(os.path.join(self.save_dir, "val_acc_curve.png"))
    plt.close()

    # ---------- Save raw data ----------
    data = {
        "train_losses": self.train_losses,
        "val_losses": self.val_losses,
        "train_accs": self.train_accs,
        "val_accs": self.val_accs,
    }

    with open(os.path.join(self.save_dir, "metrics.json"), "w") as f:
        json.dump(data, f)
```
Called at the end of training to generate and save plots and metrics.
Parameters
- trainer (pytorch_lightning.Trainer): The PyTorch Lightning trainer instance.
- pl_module (pytorch_lightning.LightningModule): The trained LightningModule.
Notes
This method performs three main tasks:
- Generates and saves a loss curve plot (loss_curve.png) showing training loss and validation loss over epochs.
- Generates and saves an accuracy plot (val_acc_curve.png) showing training accuracy and, if tracked, validation accuracy.
- Saves all raw metric data to a JSON file (metrics.json) for later analysis or reproduction.
All output files are saved to the directory specified in save_dir.
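Since metrics.json stores the raw lists, the curves can be rebuilt or analyzed later without rerunning training. A minimal sketch, assuming a save directory (the path here is hypothetical):

```python
# Sketch: reload the saved metrics for offline analysis or replotting.
import json
import os

with open(os.path.join("outputs/curves", "metrics.json")) as f:  # hypothetical dir
    data = json.load(f)

print(len(data["train_losses"]), "training epochs recorded")
```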