source.utils.custom_classes.GarbageClassifier
Garbage Classification Model Module.
This module implements a PyTorch Lightning module for garbage classification using a pretrained ResNet18 model. The classifier is fine-tuned for a 6-class garbage classification problem (cardboard, glass, metal, paper, plastic, trash).
The model uses transfer learning by freezing the pretrained ResNet18 feature extraction layers and training only the final classification layer.
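A minimal sketch of the freezing pattern described above. To run without downloading ImageNet weights, it uses a small hypothetical stand-in backbone (`TinyBackbone`, not part of this module) whose head, like ResNet18's `fc`, takes 512 features. After freezing and swapping in a 6-class head, only 512 * 6 + 6 = 3078 parameters remain trainable.

```python
import torch
from torch import nn


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for ResNet18: a feature extractor plus an `fc` head."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.fc = nn.Linear(512, 1000)  # ImageNet-sized head, replaced below

    def forward(self, x):
        return self.fc(self.features(x))


def build_frozen_classifier(num_classes: int) -> nn.Module:
    model = TinyBackbone()
    # Freeze every pretrained parameter (same loop as GarbageClassifier.__init__).
    for param in model.parameters():
        param.requires_grad = False
    # The replacement head is created after freezing, so it is trainable by default.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


def count_trainable(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


model = build_frozen_classifier(num_classes=6)
print(count_trainable(model))  # 3078
```

The order matters: the new head must be attached after the freezing loop, otherwise it would be frozen too.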
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Garbage Classification Model Module.

This module implements a PyTorch Lightning module for garbage classification
using a pretrained ResNet18 model. The classifier is fine-tuned for a 6-class
garbage classification problem (cardboard, glass, metal, paper, plastic,
trash).

The model uses transfer learning by freezing the pretrained ResNet18 feature
extraction layers and training only the final classification layer.
"""
__docformat__ = "numpy"

import pytorch_lightning as pl
import torch
from torch import nn
from torchvision import models


# ========================
# LightningModule
# ========================
class GarbageClassifier(pl.LightningModule):
    """
    Pretrained (ImageNet) ResNet18 adapted to the Garbage Dataset
    classification problem.
    It considers 6 classes: cardboard, glass, metal, paper, plastic and trash.

    Attributes
    ----------
    model : torchvision.models.ResNet
        Pretrained ResNet18 model whose final fully connected layer (fc) is
        replaced by a new linear layer with num_classes outputs.
    loss_fn : torch.nn.CrossEntropyLoss
        Cross entropy loss function.

    Examples
    --------
    >>> model = GarbageClassifier(num_classes=6, lr=1e-3)
    >>> trainer = pl.Trainer(max_epochs=10, accelerator="auto")
    >>> trainer.fit(model, datamodule=data_module)
    """

    def __init__(self, num_classes, lr=1e-3):
        """
        Initialize the GarbageClassifier model.

        Parameters
        ----------
        num_classes : int
            Number of output classes for classification.
        lr : float, optional
            Learning rate for the optimizer (default is 1e-3).
        """
        super().__init__()
        self.save_hyperparameters()
        self.model = models.resnet18(
            weights=models.ResNet18_Weights.IMAGENET1K_V1
        )

        # Freeze all pretrained layers (feature extractor and original fc)
        for param in self.model.parameters():
            param.requires_grad = False

        # New classifier layer; created after freezing, so it stays trainable
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        """
        Forward pass through the model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of images with shape (batch_size, channels, height,
            width).

        Returns
        -------
        torch.Tensor
            Output logits with shape (batch_size, num_classes).
        """
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """
        Compute the cross-entropy loss and accuracy of a batch of training
        images and log both as epoch-level metrics. Lightning uses the
        returned loss to update the trainable parameters, which here are
        only those of the final classification layer.

        Parameters
        ----------
        batch : Tuple[torch.Tensor, torch.Tensor]
            Subset (batch) of images coming from the training dataloader.
            Contains input images and corresponding labels.
        batch_idx : int
            Identifier of the batch within the current epoch.

        Returns
        -------
        torch.Tensor
            Cross-entropy loss of the training batch.
        """
        xb, yb = batch
        out = self(xb)
        loss = self.loss_fn(out, yb)

        # Accuracy
        preds = out.argmax(dim=1)
        acc = (preds == yb).float().mean()

        self.log(
            "train_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True
        )

        self.log("train_acc", acc, on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """
        Compute validation loss and accuracy for a batch of validation images.

        Parameters
        ----------
        batch : Tuple[torch.Tensor, torch.Tensor]
            Subset (batch) of images coming from the validation dataloader.
            Contains input images and corresponding labels.
        batch_idx : int
            Identifier of the batch within the current validation epoch.

        Returns
        -------
        torch.Tensor
            Validation accuracy for the current batch.
        """
        xb, yb = batch
        out = self(xb)
        loss = self.loss_fn(out, yb)
        preds = out.argmax(dim=1)
        acc = (preds == yb).float().mean()
        self.log(
            "val_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=False
        )
        self.log(
            "val_acc",
            acc,
            on_step=False,
            on_epoch=True,
            prog_bar=True
        )
        return acc

    def configure_optimizers(self):
        """
        Configure the optimizer for training.

        Returns
        -------
        torch.optim.Adam
            Adam optimizer over all model parameters, using the learning
            rate stored in the hyperparameters. Frozen backbone parameters
            receive no gradients, so only the final classification layer
            is updated.
        """
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
```
class GarbageClassifier(pl.LightningModule):
Pretrained (ImageNet) ResNet18 adapted to the Garbage Dataset classification problem. It considers 6 classes: cardboard, glass, metal, paper, plastic and trash.
Attributes
- model (torchvision.models.ResNet): Pretrained ResNet18 model whose final fully connected layer (fc) is replaced by a new linear layer with num_classes outputs.
- loss_fn (torch.nn.CrossEntropyLoss): Cross entropy loss function.
Examples
>>> model = GarbageClassifier(num_classes=6, lr=1e-3)
>>> trainer = pl.Trainer(max_epochs=10, accelerator="auto")
>>> trainer.fit(model, datamodule=data_module)
def __init__(self, num_classes, lr=1e-3):
Initialize the GarbageClassifier model.
Parameters
- num_classes (int): Number of output classes for classification.
- lr (float, optional): Learning rate for the optimizer (default is 1e-3).
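Setting `requires_grad = False` stops gradient updates, but it does not stop BatchNorm layers (ResNet18 contains many) from updating their running mean and variance while the module is in training mode. The sketch below uses a single standalone `nn.BatchNorm2d` rather than the real network to show the effect, plus a common remedy: putting BatchNorm layers in eval mode. The helper `freeze_batchnorm_stats` is hypothetical and is not something the original module does; where to call it within Lightning's training loop is left open here.

```python
import torch
from torch import nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def freeze_batchnorm_stats(module: nn.Module) -> None:
    """Hypothetical helper: put every BatchNorm layer in eval mode so its running stats stop changing."""
    for m in module.modules():
        if isinstance(m, BN_TYPES):
            m.eval()


bn = nn.BatchNorm2d(4)
for p in bn.parameters():
    p.requires_grad = False  # "frozen" like the ResNet18 backbone

# Batch whose mean (about 3) is far from the initial running mean (0).
x = torch.randn(8, 4, 5, 5) + 3.0

bn.train()
before = bn.running_mean.clone()
bn(x)
changed_in_train = not torch.equal(before, bn.running_mean)

freeze_batchnorm_stats(bn)
before = bn.running_mean.clone()
bn(x)
changed_in_eval = not torch.equal(before, bn.running_mean)

print(changed_in_train)  # True: stats moved despite requires_grad=False
print(changed_in_eval)   # False: eval mode leaves the stats untouched
```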
def forward(self, x):
Forward pass through the model.
Parameters
- x (torch.Tensor): Input tensor of images with shape (batch_size, channels, height, width).
Returns
- torch.Tensor: Output logits with shape (batch_size, num_classes).
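A short sketch of turning the logits returned by `forward` into class probabilities and names. It assumes the label indices follow the alphabetical order listed in the class description (the order torchvision's `ImageFolder` would assign from sorted folder names); `CLASS_NAMES` and `decode_logits` are hypothetical names, and hand-written logits stand in for real model output.

```python
import torch

# Assumed index order (alphabetical, matching the class list in these docs).
CLASS_NAMES = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]


def decode_logits(logits: torch.Tensor) -> tuple[torch.Tensor, list[str]]:
    """Convert (batch_size, num_classes) logits to probabilities and label names."""
    probs = torch.softmax(logits, dim=1)  # each row sums to 1
    preds = probs.argmax(dim=1)  # same indices as logits.argmax(dim=1)
    return probs, [CLASS_NAMES[i] for i in preds.tolist()]


logits = torch.tensor([[0.1, 2.5, 0.0, -1.0, 0.3, 0.2],
                       [1.0, 0.0, 0.0, 0.0, 0.0, 4.0]])
probs, names = decode_logits(logits)
print(names)  # ['glass', 'trash']
```

With the real model you would call `model.eval()` and run the forward pass inside `torch.no_grad()` before decoding.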
def training_step(self, batch, batch_idx):
Compute the cross-entropy loss and accuracy of a batch of training images and log both as epoch-level metrics. Lightning uses the returned loss to update the trainable parameters, which here are only those of the final classification layer.
Parameters
- batch (Tuple[torch.Tensor, torch.Tensor]): Subset (batch) of images coming from the training dataloader. Contains input images and corresponding labels.
- batch_idx (int): Identifier of the batch within the current epoch.
Returns
- torch.Tensor: Cross-entropy loss of the training batch.
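To make the loss and accuracy in training_step concrete, here is a dependency-free worked example for a toy batch of two samples and three classes. It reproduces what `nn.CrossEntropyLoss()` computes with its defaults (mean reduction, no class weights, no label smoothing) and what the `argmax` accuracy computes; the helper names are hypothetical.

```python
import math


def cross_entropy(logits: list[list[float]], labels: list[int]) -> float:
    """Mean of -log(softmax(row)[label]) over the batch, like nn.CrossEntropyLoss() defaults."""
    total = 0.0
    for row, label in zip(logits, labels):
        m = max(row)  # subtract the max for numerical stability (log-sum-exp trick)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[label]
    return total / len(labels)


def accuracy(logits: list[list[float]], labels: list[int]) -> float:
    """Fraction of rows whose argmax equals the label, like (preds == yb).float().mean()."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


logits = [[2.0, 0.0, 0.0],
          [0.0, 1.0, 0.0]]
labels = [0, 2]
print(round(cross_entropy(logits, labels), 4))  # 0.8955
print(accuracy(logits, labels))  # 0.5
```

The first sample is classified correctly and confidently (loss about 0.24), while the second is wrong (loss about 1.55), so the batch accuracy is 0.5.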
def validation_step(self, batch, batch_idx):
Compute validation loss and accuracy for a batch of validation images.
Parameters
- batch (Tuple[torch.Tensor, torch.Tensor]): Subset (batch) of images coming from the validation dataloader. Contains input images and corresponding labels.
- batch_idx (int): Identifier of the batch within the current validation epoch.
Returns
- torch.Tensor: Validation accuracy for the current batch.
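Because val_acc is logged per batch with on_epoch=True, Lightning reduces the batch values to one epoch value. As far as we understand its default behaviour, the mean reduction weights each batch by its batch size (inferred from the batch unless `batch_size=` is passed to `self.log`), so the logged value matches dataset-level accuracy even when the last batch is smaller. The dependency-free sketch below shows why the weighting matters; the numbers are hypothetical.

```python
def epoch_accuracy(batches: list[tuple[int, int]]) -> tuple[float, float]:
    """Given (num_correct, batch_size) pairs, return (weighted, unweighted) mean accuracy."""
    accs = [correct / size for correct, size in batches]
    total = sum(size for _, size in batches)
    weighted = sum(a * size for a, (_, size) in zip(accs, batches)) / total
    unweighted = sum(accs) / len(accs)
    return weighted, unweighted


# Hypothetical validation set of 40 images: one full batch of 32 and a last batch of 8.
weighted, unweighted = epoch_accuracy([(24, 32), (8, 8)])
print(weighted)    # 0.8 (32 correct out of 40 images)
print(unweighted)  # 0.875
```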
def configure_optimizers(self):
Configure the optimizer for training.
Returns
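- torch.optim.Adam: Adam optimizer over all model parameters, using the learning rate stored in the hyperparameters (self.hparams.lr). Frozen backbone parameters receive no gradients, so only the final classification layer is updated.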
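configure_optimizers passes every parameter to Adam, including the frozen backbone. This works because PyTorch optimizers skip parameters whose `.grad` is `None`. A common alternative is to pass only the trainable parameters, which in this module would read `torch.optim.Adam((p for p in self.parameters() if p.requires_grad), lr=self.hparams.lr)`. The minimal sketch below compares both variants on a hypothetical two-layer stand-in model rather than ResNet18.

```python
import torch
from torch import nn

# Hypothetical stand-in: a frozen "backbone" layer followed by a trainable head.
model = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 6))
for p in model[0].parameters():
    p.requires_grad = False

# Variant 1: all parameters, as in configure_optimizers. Frozen parameters never
# get a .grad, so Adam skips them, but they still appear in the parameter list.
opt_all = torch.optim.Adam(model.parameters(), lr=1e-3)

# Variant 2: only parameters that require gradients.
opt_trainable = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)

print(len(opt_all.param_groups[0]["params"]))        # 4
print(len(opt_trainable.param_groups[0]["params"]))  # 2

frozen_before = model[0].weight.clone()
head_before = model[1].weight.clone()
loss = model(torch.randn(3, 512)).pow(2).mean()
loss.backward()
opt_all.step()
print(torch.equal(frozen_before, model[0].weight))  # True: frozen layer unchanged
```

Both variants update the same weights; filtering mainly keeps the optimizer's parameter list, and any per-parameter state, limited to the layer that is actually trained.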