utils.custom_classes.GarbageClassifier
Garbage Classification Model Module.
This module implements a PyTorch Lightning module for garbage classification using a pretrained ResNet18 model. The classifier is fine-tuned for a 6-class garbage classification problem (cardboard, glass, metal, paper, plastic, trash).
The model uses transfer learning by freezing the pretrained ResNet18 feature extraction layers and training only the final classification layer.
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Garbage Classification Model Module.

This module implements a PyTorch Lightning module for garbage classification
using a pretrained ResNet18 model. The classifier is fine-tuned for a 6-class
garbage classification problem (cardboard, glass, metal, paper, plastic,
trash).

The model uses transfer learning by freezing the pretrained ResNet18 feature
extraction layers and training only the final classification layer.
"""
__docformat__ = "numpy"

import pytorch_lightning as pl
import torch
from torch import nn
from torchvision import models


# ========================
# LightningModule
# ========================
class GarbageClassifier(pl.LightningModule):
    """
    Pretrained (ImageNet) ResNet18 adapted to Garbage Dataset Classification
    problem.
    It considers 6 classes: cardboard, glass, metal, paper, plastic and trash.

    Attributes
    ----------
    model : torchvision.models.resnet18
        Pretrained ResNet18 model.
    loss_fn : torch.nn.CrossEntropyLoss
        Cross entropy loss function.

    Examples
    --------
    >>> model = GarbageClassifier(num_classes=6, lr=1e-3)
    >>> trainer = pl.Trainer(max_epochs=10, accelerator="auto")
    >>> trainer.fit(model, datamodule=data_module)
    """

    def __init__(self, num_classes, lr=1e-3):
        """
        Initialize the GarbageClassifier model.

        Parameters
        ----------
        num_classes : int
            Number of output classes for classification.
        lr : float, optional
            Learning rate for the optimizer (default is 1e-3).
        """
        super().__init__()
        self.save_hyperparameters()
        self.model = models.resnet18(
            weights=models.ResNet18_Weights.IMAGENET1K_V1
        )

        # Freeze feature layers
        for param in self.model.parameters():
            param.requires_grad = False

        # New classifier layer
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        """
        Forward pass through the model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of images with shape (batch_size, channels, height,
            width).

        Returns
        -------
        torch.Tensor
            Output logits with shape (batch_size, num_classes).
        """
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """
        Model parameters are updated according to the classification error
        of a subset of train images.

        Parameters
        ----------
        batch : Tuple[torch.Tensor, torch.Tensor]
            Subset (batch) of images coming from the train dataloader.
            Contains input images and corresponding labels.
        batch_idx : int
            Identifier of the batch within the current epoch.

        Returns
        -------
        torch.Tensor
            Classification error (cross entropy loss) of trained image batch.
        """
        xb, yb = batch
        out = self(xb)
        loss = self.loss_fn(out, yb)
        self.log('train_loss',
                 loss,
                 on_step=False,
                 on_epoch=True,
                 prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Compute validation loss and accuracy for a batch of validation images.

        Parameters
        ----------
        batch : Tuple[torch.Tensor, torch.Tensor]
            Subset (batch) of images coming from the validation dataloader.
            Contains input images and corresponding labels.
        batch_idx : int
            Identifier of the batch within the current validation epoch.

        Returns
        -------
        torch.Tensor
            Validation accuracy for the current batch.
        """
        xb, yb = batch
        out = self(xb)
        loss = self.loss_fn(out, yb)
        preds = out.argmax(dim=1)
        acc = (preds == yb).float().mean()
        self.log('val_loss',
                 loss,
                 on_step=False,
                 on_epoch=True,
                 prog_bar=False)
        self.log('val_acc',
                 acc,
                 on_step=False,
                 on_epoch=True,
                 prog_bar=True)
        return acc

    def configure_optimizers(self):
        """
        Configure the optimizer for training.

        Returns
        -------
        torch.optim.Adam
            Adam optimizer configured with model parameters and learning rate
            from hyperparameters.
        """
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
```
class GarbageClassifier(pl.LightningModule)
ImageNet-pretrained ResNet18 adapted to the Garbage Dataset classification problem, which has 6 classes: cardboard, glass, metal, paper, plastic and trash.
Attributes
- model (torchvision.models.ResNet): ImageNet-pretrained ResNet18 with frozen weights and a new, trainable fc layer sized to num_classes.
- loss_fn (torch.nn.CrossEntropyLoss): Cross entropy loss function.
Examples
>>> model = GarbageClassifier(num_classes=6, lr=1e-3)
>>> trainer = pl.Trainer(max_epochs=10, accelerator="auto")
>>> trainer.fit(model, datamodule=data_module)
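The example above assumes a `data_module` that is defined elsewhere. All that `GarbageClassifier` needs from it is dataloaders that yield `(images, labels)` tuples. The sketch below builds synthetic loaders for a quick smoke test. The helper `make_loader` and its random data are hypothetical and are not part of the module. Because the images are random, the smoke test only exercises the training loop; it does not train a meaningful model. The loaders could be passed as `trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)` in place of a datamodule.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loader(n_images: int, num_classes: int = 6, batch_size: int = 8,
                shuffle: bool = False, seed: int = 0) -> DataLoader:
    """Hypothetical helper: random 3x224x224 images with random labels."""
    gen = torch.Generator().manual_seed(seed)
    images = torch.randn(n_images, 3, 224, 224, generator=gen)
    labels = torch.randint(0, num_classes, (n_images,), generator=gen)
    return DataLoader(TensorDataset(images, labels),
                      batch_size=batch_size, shuffle=shuffle)


train_loader = make_loader(32, shuffle=True)
val_loader = make_loader(16, seed=1)

# Each batch unpacks the same way as in training_step / validation_step.
xb, yb = next(iter(val_loader))
print(xb.shape, yb.shape, yb.dtype)
```

For real training, these loaders would be replaced with loaders over labelled garbage images. How those images are stored on disk is not described in this module.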
def __init__(self, num_classes, lr=1e-3)
Initialize the GarbageClassifier model.
Parameters
- num_classes (int): Number of output classes for classification.
- lr (float, optional): Learning rate for the optimizer (default is 1e-3).
def forward(self, x)
Forward pass through the model.
Parameters
- x (torch.Tensor): Batch of input images with shape (batch_size, 3, height, width); the pretrained ResNet18 expects 3-channel (RGB) input.
Returns
- torch.Tensor: Output logits with shape (batch_size, num_classes).
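`forward` returns raw logits, so a prediction step has to convert them into class indices and, if wanted, probabilities. The minimal sketch below uses hand-written logits as a stand-in for `model(x)`. The order of `CLASS_NAMES` is an assumption. It is alphabetical, which is the order torchvision's `ImageFolder` would assign to folders named after the six classes. The list must match whatever label encoding the training data actually used.

```python
from typing import List, Tuple

import torch

# Assumed label order; verify against the dataset's class-to-index mapping.
CLASS_NAMES = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]


def decode_predictions(logits: torch.Tensor) -> List[Tuple[str, float]]:
    """Map (batch_size, num_classes) logits to (class name, probability)."""
    probs = torch.softmax(logits, dim=1)
    conf, idx = probs.max(dim=1)
    return [(CLASS_NAMES[i], c) for i, c in zip(idx.tolist(), conf.tolist())]


# Stand-in for `model(x)` on a batch of two images.
logits = torch.tensor([
    [0.1, 2.5, 0.3, -1.0, 0.0, 0.2],
    [1.0, 0.0, 0.0, 0.0, 3.0, 0.5],
])
for name, p in decode_predictions(logits):
    print(f"{name}: {p:.2f}")
```

For real inference, the forward pass would normally run after `model.eval()` and inside `torch.no_grad()`.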
def training_step(self, batch, batch_idx)
Compute the cross-entropy loss on a batch of training images and log it as train_loss, averaged over the epoch and shown in the progress bar. Lightning backpropagates the returned loss and the optimizer then updates the trainable parameters. Here, the only trainable parameters are those of the final fc layer.
Parameters
- batch (Tuple[torch.Tensor, torch.Tensor]): Batch from the training dataloader, given as a tuple of input images and their integer class labels.
- batch_idx (int): Index of the batch within the current epoch.
Returns
- torch.Tensor: Cross-entropy loss of the training batch.
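`nn.CrossEntropyLoss` takes raw logits and integer class labels. With its default `reduction='mean'`, it averages the negative log-softmax probability of the true class over the batch. This is why `training_step` passes `out` to the loss without applying a softmax first. The sketch below checks that equivalence on a small hand-made batch; the values are illustrative only.

```python
import torch
from torch import nn

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.0, 1.0, 3.0]])
labels = torch.tensor([0, 2])  # true class index per image

loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, labels)

# Same value by hand: mean of -log softmax(logits)[true class].
log_probs = torch.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(len(labels)), labels].mean()

print(loss.item(), manual.item())
```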
def validation_step(self, batch, batch_idx)
Compute validation loss and accuracy for a batch of validation images.
Parameters
- batch (Tuple[torch.Tensor, torch.Tensor]): Batch from the validation dataloader, given as a tuple of input images and their integer class labels.
- batch_idx (int): Index of the batch within the current validation epoch.
Returns
- torch.Tensor: Validation accuracy for the current batch.
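The accuracy returned by `validation_step` is the fraction of images whose highest logit falls on the true label. The worked example below uses four hand-made logit rows with illustrative values. Three of the four predictions match their labels, so the batch accuracy is 0.75.

```python
import torch

logits = torch.tensor([
    [2.0, 0.1, 0.3],   # argmax 0
    [0.2, 1.5, 0.1],   # argmax 1
    [0.9, 0.4, 0.2],   # argmax 0
    [0.1, 0.2, 2.2],   # argmax 2
])
labels = torch.tensor([0, 1, 2, 2])

preds = logits.argmax(dim=1)             # same expression as validation_step
acc = (preds == labels).float().mean()   # fraction of correct predictions
print(preds.tolist(), acc.item())
```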
def configure_optimizers(self)
Configure the optimizer for training.
Returns
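- torch.optim.Adam: Adam optimizer over all model parameters, using the learning rate stored in self.hparams.lr. Frozen backbone parameters never receive gradients, so in practice only the final fc layer is updated.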