utils.custom_classes.GarbageClassifier

Garbage Classification Model Module.

This module implements a PyTorch Lightning module for garbage classification using a pretrained ResNet18 model. The classifier is fine-tuned for a 6-class garbage classification problem (cardboard, glass, metal, paper, plastic, trash).

The model uses transfer learning by freezing the pretrained ResNet18 feature extraction layers and training only the final classification layer.

  1#!/usr/bin/env python3
  2# -*- coding: utf-8 -*-
  3"""
  4Garbage Classification Model Module.
  5
  6This module implements a PyTorch Lightning module for garbage classification
  7using a pretrained ResNet18 model. The classifier is fine-tuned for a 6-class
  8garbage classification problem (cardboard, glass, metal, paper, plastic,
  9trash).
 10
 11The model uses transfer learning by freezing the pretrained ResNet18 feature
 12extraction layers and training only the final classification layer.
 13"""
 14__docformat__ = "numpy"
 15
 16import pytorch_lightning as pl
 17import torch
 18from torch import nn
 19from torchvision import models
 20
 21
 22# ========================
 23# LightningModule
 24# ========================
 25class GarbageClassifier(pl.LightningModule):
 26    """
 27    Pretrained (ImageNet) ResNet18 adapted to Garbage Dataset Classification
 28    problem.
 29    It considers 6 classes: cardboard, glass, metal, paper, plastic and trash.
 30
 31    Attributes
 32    ----------
 33    model : torchvision.models.resnet18
 34        Pretrained ResNet18 model.
 35    loss_fn : torch.nn.CrossEntropyLoss
 36        Cross entropy loss function.
 37
 38    Examples
 39    --------
 40    >>> model = GarbageClassifier(num_classes=6, lr=1e-3)
 41    >>> trainer = pl.Trainer(max_epochs=10, accelerator="auto")
 42    >>> trainer.fit(model, datamodule=data_module)
 43    """
 44
 45    def __init__(self, num_classes, lr=1e-3):
 46        """
 47        Initialize the GarbageClassifier model.
 48
 49        Parameters
 50        ----------
 51        num_classes : int
 52            Number of output classes for classification.
 53        lr : float, optional
 54            Learning rate for the optimizer (default is 1e-3).
 55        """
 56        super().__init__()
 57        self.save_hyperparameters()
 58        self.model = models.resnet18(
 59            weights=models.ResNet18_Weights.IMAGENET1K_V1
 60        )
 61
 62        # Freeze feature layers
 63        for param in self.model.parameters():
 64            param.requires_grad = False
 65
 66        # New classifier layer
 67        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
 68        self.loss_fn = nn.CrossEntropyLoss()
 69
 70    def forward(self, x):
 71        """
 72        Forward pass through the model.
 73
 74        Parameters
 75        ----------
 76        x : torch.Tensor
 77            Input tensor of images with shape (batch_size, channels, height,
 78            width).
 79
 80        Returns
 81        -------
 82        torch.Tensor
 83            Output logits with shape (batch_size, num_classes).
 84        """
 85        return self.model(x)
 86
 87    def training_step(self, batch, batch_idx):
 88        """
 89        Model parameters are updated according to the classification error
 90        of a subset of train images.
 91
 92        Parameters
 93        ----------
 94        batch : Tuple[torch.Tensor, torch.Tensor]
 95            Subset (batch) of images coming from the train dataloader.
 96            Contains input images and corresponding labels.
 97        batch_idx : int
 98            Identifier of the batch within the current epoch.
 99
100        Returns
101        -------
102        torch.Tensor
103            Classification error (cross entropy loss) of trained image batch.
104        """
105        xb, yb = batch
106        out = self(xb)
107        loss = self.loss_fn(out, yb)
108        self.log('train_loss',
109                 loss,
110                 on_step=False,
111                 on_epoch=True,
112                 prog_bar=True)
113        return loss
114
115    def validation_step(self, batch, batch_idx):
116        """
117        Compute validation loss and accuracy for a batch of validation images.
118
119        Parameters
120        ----------
121        batch : Tuple[torch.Tensor, torch.Tensor]
122            Subset (batch) of images coming from the validation dataloader.
123            Contains input images and corresponding labels.
124        batch_idx : int
125            Identifier of the batch within the current validation epoch.
126
127        Returns
128        -------
129        torch.Tensor
130            Validation accuracy for the current batch.
131        """
132        xb, yb = batch
133        out = self(xb)
134        loss = self.loss_fn(out, yb)
135        preds = out.argmax(dim=1)
136        acc = (preds == yb).float().mean()
137        self.log('val_loss',
138                 loss,
139                 on_step=False,
140                 on_epoch=True,
141                 prog_bar=False)
142        self.log('val_acc',
143                 acc,
144                 on_step=False,
145                 on_epoch=True,
146                 prog_bar=True)
147        return acc
148
149    def configure_optimizers(self):
150        """
151        Configure the optimizer for training.
152
153        Returns
154        -------
155        torch.optim.Adam
156            Adam optimizer configured with model parameters and learning rate
157            from hyperparameters.
158        """
159        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
class GarbageClassifier(pytorch_lightning.core.module.LightningModule):
 26class GarbageClassifier(pl.LightningModule):
 27    """
 28    Pretrained (ImageNet) ResNet18 adapted to Garbage Dataset Classification
 29    problem.
 30    It considers 6 classes: cardboard, glass, metal, paper, plastic and trash.
 31
 32    Attributes
 33    ----------
 34    model : torchvision.models.resnet18
 35        Pretrained ResNet18 model.
 36    loss_fn : torch.nn.CrossEntropyLoss
 37        Cross entropy loss function.
 38
 39    Examples
 40    --------
 41    >>> model = GarbageClassifier(num_classes=6, lr=1e-3)
 42    >>> trainer = pl.Trainer(max_epochs=10, accelerator="auto")
 43    >>> trainer.fit(model, datamodule=data_module)
 44    """
 45
 46    def __init__(self, num_classes, lr=1e-3):
 47        """
 48        Initialize the GarbageClassifier model.
 49
 50        Parameters
 51        ----------
 52        num_classes : int
 53            Number of output classes for classification.
 54        lr : float, optional
 55            Learning rate for the optimizer (default is 1e-3).
 56        """
 57        super().__init__()
 58        self.save_hyperparameters()
 59        self.model = models.resnet18(
 60            weights=models.ResNet18_Weights.IMAGENET1K_V1
 61        )
 62
 63        # Freeze feature layers
 64        for param in self.model.parameters():
 65            param.requires_grad = False
 66
 67        # New classifier layer
 68        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
 69        self.loss_fn = nn.CrossEntropyLoss()
 70
 71    def forward(self, x):
 72        """
 73        Forward pass through the model.
 74
 75        Parameters
 76        ----------
 77        x : torch.Tensor
 78            Input tensor of images with shape (batch_size, channels, height,
 79            width).
 80
 81        Returns
 82        -------
 83        torch.Tensor
 84            Output logits with shape (batch_size, num_classes).
 85        """
 86        return self.model(x)
 87
 88    def training_step(self, batch, batch_idx):
 89        """
 90        Model parameters are updated according to the classification error
 91        of a subset of train images.
 92
 93        Parameters
 94        ----------
 95        batch : Tuple[torch.Tensor, torch.Tensor]
 96            Subset (batch) of images coming from the train dataloader.
 97            Contains input images and corresponding labels.
 98        batch_idx : int
 99            Identifier of the batch within the current epoch.
100
101        Returns
102        -------
103        torch.Tensor
104            Classification error (cross entropy loss) of trained image batch.
105        """
106        xb, yb = batch
107        out = self(xb)
108        loss = self.loss_fn(out, yb)
109        self.log('train_loss',
110                 loss,
111                 on_step=False,
112                 on_epoch=True,
113                 prog_bar=True)
114        return loss
115
116    def validation_step(self, batch, batch_idx):
117        """
118        Compute validation loss and accuracy for a batch of validation images.
119
120        Parameters
121        ----------
122        batch : Tuple[torch.Tensor, torch.Tensor]
123            Subset (batch) of images coming from the validation dataloader.
124            Contains input images and corresponding labels.
125        batch_idx : int
126            Identifier of the batch within the current validation epoch.
127
128        Returns
129        -------
130        torch.Tensor
131            Validation accuracy for the current batch.
132        """
133        xb, yb = batch
134        out = self(xb)
135        loss = self.loss_fn(out, yb)
136        preds = out.argmax(dim=1)
137        acc = (preds == yb).float().mean()
138        self.log('val_loss',
139                 loss,
140                 on_step=False,
141                 on_epoch=True,
142                 prog_bar=False)
143        self.log('val_acc',
144                 acc,
145                 on_step=False,
146                 on_epoch=True,
147                 prog_bar=True)
148        return acc
149
150    def configure_optimizers(self):
151        """
152        Configure the optimizer for training.
153
154        Returns
155        -------
156        torch.optim.Adam
157            Adam optimizer configured with model parameters and learning rate
158            from hyperparameters.
159        """
160        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

ImageNet-pretrained ResNet18 adapted to the garbage dataset classification problem. It considers 6 classes: cardboard, glass, metal, paper, plastic, and trash.

Attributes
  • model (torchvision.models.ResNet): Pretrained ResNet18 model whose final fully connected layer is replaced by a new classification layer.
  • loss_fn (torch.nn.CrossEntropyLoss): Cross entropy loss function.
Examples
>>> model = GarbageClassifier(num_classes=6, lr=1e-3)
>>> trainer = pl.Trainer(max_epochs=10, accelerator="auto")
>>> trainer.fit(model, datamodule=data_module)
GarbageClassifier(num_classes, lr=0.001)
46    def __init__(self, num_classes, lr=1e-3):
47        """
48        Initialize the GarbageClassifier model.
49
50        Parameters
51        ----------
52        num_classes : int
53            Number of output classes for classification.
54        lr : float, optional
55            Learning rate for the optimizer (default is 1e-3).
56        """
57        super().__init__()
58        self.save_hyperparameters()
59        self.model = models.resnet18(
60            weights=models.ResNet18_Weights.IMAGENET1K_V1
61        )
62
63        # Freeze feature layers
64        for param in self.model.parameters():
65            param.requires_grad = False
66
67        # New classifier layer
68        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
69        self.loss_fn = nn.CrossEntropyLoss()

Initialize the GarbageClassifier model.

Parameters
  • num_classes (int): Number of output classes for classification.
  • lr (float, optional): Learning rate for the optimizer (default is 1e-3).
def forward(self, x):
71    def forward(self, x):
72        """
73        Forward pass through the model.
74
75        Parameters
76        ----------
77        x : torch.Tensor
78            Input tensor of images with shape (batch_size, channels, height,
79            width).
80
81        Returns
82        -------
83        torch.Tensor
84            Output logits with shape (batch_size, num_classes).
85        """
86        return self.model(x)

Forward pass through the model.

Parameters
  • x (torch.Tensor): Input tensor of images with shape (batch_size, channels, height, width).
Returns
  • torch.Tensor: Output logits with shape (batch_size, num_classes).
def training_step(self, batch, batch_idx):
 88    def training_step(self, batch, batch_idx):
 89        """
 90        Model parameters are updated according to the classification error
 91        of a subset of train images.
 92
 93        Parameters
 94        ----------
 95        batch : Tuple[torch.Tensor, torch.Tensor]
 96            Subset (batch) of images coming from the train dataloader.
 97            Contains input images and corresponding labels.
 98        batch_idx : int
 99            Identifier of the batch within the current epoch.
100
101        Returns
102        -------
103        torch.Tensor
104            Classification error (cross entropy loss) of trained image batch.
105        """
106        xb, yb = batch
107        out = self(xb)
108        loss = self.loss_fn(out, yb)
109        self.log('train_loss',
110                 loss,
111                 on_step=False,
112                 on_epoch=True,
113                 prog_bar=True)
114        return loss

Compute the classification error (cross entropy loss) of a batch of training images; the model parameters are updated according to this error.

Parameters
  • batch (Tuple[torch.Tensor, torch.Tensor]): Subset (batch) of images coming from the train dataloader. Contains input images and corresponding labels.
  • batch_idx (int): Identifier of the batch within the current epoch.
Returns
  • torch.Tensor: Classification error (cross entropy loss) of trained image batch.
def validation_step(self, batch, batch_idx):
116    def validation_step(self, batch, batch_idx):
117        """
118        Compute validation loss and accuracy for a batch of validation images.
119
120        Parameters
121        ----------
122        batch : Tuple[torch.Tensor, torch.Tensor]
123            Subset (batch) of images coming from the validation dataloader.
124            Contains input images and corresponding labels.
125        batch_idx : int
126            Identifier of the batch within the current validation epoch.
127
128        Returns
129        -------
130        torch.Tensor
131            Validation accuracy for the current batch.
132        """
133        xb, yb = batch
134        out = self(xb)
135        loss = self.loss_fn(out, yb)
136        preds = out.argmax(dim=1)
137        acc = (preds == yb).float().mean()
138        self.log('val_loss',
139                 loss,
140                 on_step=False,
141                 on_epoch=True,
142                 prog_bar=False)
143        self.log('val_acc',
144                 acc,
145                 on_step=False,
146                 on_epoch=True,
147                 prog_bar=True)
148        return acc

Compute validation loss and accuracy for a batch of validation images.

Parameters
  • batch (Tuple[torch.Tensor, torch.Tensor]): Subset (batch) of images coming from the validation dataloader. Contains input images and corresponding labels.
  • batch_idx (int): Identifier of the batch within the current validation epoch.
Returns
  • torch.Tensor: Validation accuracy for the current batch.
def configure_optimizers(self):
150    def configure_optimizers(self):
151        """
152        Configure the optimizer for training.
153
154        Returns
155        -------
156        torch.optim.Adam
157            Adam optimizer configured with model parameters and learning rate
158            from hyperparameters.
159        """
160        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

Configure the optimizer for training.

Returns
  • torch.optim.Adam: Adam optimizer configured with model parameters and learning rate from hyperparameters.