source.utils.custom_classes.GarbageClassifier

Garbage Classification Model Module.

This module implements a PyTorch Lightning module for garbage classification using a pretrained ResNet18 model. The classifier is fine-tuned for a 6-class garbage classification problem (cardboard, glass, metal, paper, plastic, trash).

The model uses transfer learning by freezing the pretrained ResNet18 feature extraction layers and training only the final classification layer.

  1#!/usr/bin/env python3
  2# -*- coding: utf-8 -*-
  3"""
  4Garbage Classification Model Module.
  5
  6This module implements a PyTorch Lightning module for garbage classification
  7using a pretrained ResNet18 model. The classifier is fine-tuned for a 6-class
  8garbage classification problem (cardboard, glass, metal, paper, plastic,
  9trash).
 10
 11The model uses transfer learning by freezing the pretrained ResNet18 feature
 12extraction layers and training only the final classification layer.
 13"""
 14__docformat__ = "numpy"
 15
 16import pytorch_lightning as pl
 17import torch
 18from torch import nn
 19from torchvision import models
 20
 21
 22# ========================
 23# LightningModule
 24# ========================
 25class GarbageClassifier(pl.LightningModule):
 26    """
 27    Pretrained (ImageNet) ResNet18 adapted to Garbage Dataset Classification
 28    problem.
 29    It considers 6 classes: cardboard, glass, metal, paper, plastic and trash.
 30
 31    Attributes
 32    ----------
 33    model : torchvision.models.resnet18
 34        Pretrained ResNet18 model.
 35    loss_fn : torch.nn.CrossEntropyLoss
 36        Cross entropy loss function.
 37
 38    Examples
 39    --------
 40    >>> model = GarbageClassifier(num_classes=6, lr=1e-3)
 41    >>> trainer = pl.Trainer(max_epochs=10, accelerator="auto")
 42    >>> trainer.fit(model, datamodule=data_module)
 43    """
 44
 45    def __init__(self, num_classes, lr=1e-3):
 46        """
 47        Initialize the GarbageClassifier model.
 48
 49        Parameters
 50        ----------
 51        num_classes : int
 52            Number of output classes for classification.
 53        lr : float, optional
 54            Learning rate for the optimizer (default is 1e-3).
 55        """
 56        super().__init__()
 57        self.save_hyperparameters()
 58        self.model = models.resnet18(
 59            weights=models.ResNet18_Weights.IMAGENET1K_V1
 60        )
 61
 62        # Freeze feature layers
 63        for param in self.model.parameters():
 64            param.requires_grad = False
 65
 66        # New classifier layer
 67        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
 68        self.loss_fn = nn.CrossEntropyLoss()
 69
 70    def forward(self, x):
 71        """
 72        Forward pass through the model.
 73
 74        Parameters
 75        ----------
 76        x : torch.Tensor
 77            Input tensor of images with shape (batch_size, channels, height,
 78            width).
 79
 80        Returns
 81        -------
 82        torch.Tensor
 83            Output logits with shape (batch_size, num_classes).
 84        """
 85        return self.model(x)
 86
 87    def training_step(self, batch, batch_idx):
 88        """
 89        Model parameters are updated according to the classification error
 90        of a subset of train images.
 91
 92        Parameters
 93        ----------
 94        batch : Tuple[torch.Tensor, torch.Tensor]
 95            Subset (batch) of images coming from the train dataloader.
 96            Contains input images and corresponding labels.
 97        batch_idx : int
 98            Identifier of the batch within the current epoch.
 99
100        Returns
101        -------
102        torch.Tensor
103            Classification error (cross entropy loss) of trained image batch.
104        """
105        xb, yb = batch
106        out = self(xb)
107        loss = self.loss_fn(out, yb)
108
109        # Accuracy
110        preds = out.argmax(dim=1)
111        acc = (preds == yb).float().mean()
112
113        self.log(
114            "train_loss",
115            loss,
116            on_step=False,
117            on_epoch=True,
118            prog_bar=True
119        )
120
121        self.log("train_acc", acc, on_step=False, on_epoch=True, prog_bar=True)
122
123        return loss
124
125    def validation_step(self, batch, batch_idx):
126        """
127        Compute validation loss and accuracy for a batch of validation images.
128
129        Parameters
130        ----------
131        batch : Tuple[torch.Tensor, torch.Tensor]
132            Subset (batch) of images coming from the validation dataloader.
133            Contains input images and corresponding labels.
134        batch_idx : int
135            Identifier of the batch within the current validation epoch.
136
137        Returns
138        -------
139        torch.Tensor
140            Validation accuracy for the current batch.
141        """
142        xb, yb = batch
143        out = self(xb)
144        loss = self.loss_fn(out, yb)
145        preds = out.argmax(dim=1)
146        acc = (preds == yb).float().mean()
147        self.log(
148            "val_loss",
149            loss,
150            on_step=False,
151            on_epoch=True,
152            prog_bar=False
153        )
154        self.log(
155            "val_acc",
156            acc,
157            on_step=False,
158            on_epoch=True,
159            prog_bar=True
160        )
161        return acc
162
163    def configure_optimizers(self):
164        """
165        Configure the optimizer for training.
166
167        Returns
168        -------
169        torch.optim.Adam
170            Adam optimizer configured with model parameters and learning rate
171            from hyperparameters.
172        """
173        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
class GarbageClassifier(pytorch_lightning.core.module.LightningModule):
 26class GarbageClassifier(pl.LightningModule):
 27    """
 28    Pretrained (ImageNet) ResNet18 adapted to Garbage Dataset Classification
 29    problem.
 30    It considers 6 classes: cardboard, glass, metal, paper, plastic and trash.
 31
 32    Attributes
 33    ----------
 34    model : torchvision.models.resnet18
 35        Pretrained ResNet18 model.
 36    loss_fn : torch.nn.CrossEntropyLoss
 37        Cross entropy loss function.
 38
 39    Examples
 40    --------
 41    >>> model = GarbageClassifier(num_classes=6, lr=1e-3)
 42    >>> trainer = pl.Trainer(max_epochs=10, accelerator="auto")
 43    >>> trainer.fit(model, datamodule=data_module)
 44    """
 45
 46    def __init__(self, num_classes, lr=1e-3):
 47        """
 48        Initialize the GarbageClassifier model.
 49
 50        Parameters
 51        ----------
 52        num_classes : int
 53            Number of output classes for classification.
 54        lr : float, optional
 55            Learning rate for the optimizer (default is 1e-3).
 56        """
 57        super().__init__()
 58        self.save_hyperparameters()
 59        self.model = models.resnet18(
 60            weights=models.ResNet18_Weights.IMAGENET1K_V1
 61        )
 62
 63        # Freeze feature layers
 64        for param in self.model.parameters():
 65            param.requires_grad = False
 66
 67        # New classifier layer
 68        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
 69        self.loss_fn = nn.CrossEntropyLoss()
 70
 71    def forward(self, x):
 72        """
 73        Forward pass through the model.
 74
 75        Parameters
 76        ----------
 77        x : torch.Tensor
 78            Input tensor of images with shape (batch_size, channels, height,
 79            width).
 80
 81        Returns
 82        -------
 83        torch.Tensor
 84            Output logits with shape (batch_size, num_classes).
 85        """
 86        return self.model(x)
 87
 88    def training_step(self, batch, batch_idx):
 89        """
 90        Model parameters are updated according to the classification error
 91        of a subset of train images.
 92
 93        Parameters
 94        ----------
 95        batch : Tuple[torch.Tensor, torch.Tensor]
 96            Subset (batch) of images coming from the train dataloader.
 97            Contains input images and corresponding labels.
 98        batch_idx : int
 99            Identifier of the batch within the current epoch.
100
101        Returns
102        -------
103        torch.Tensor
104            Classification error (cross entropy loss) of trained image batch.
105        """
106        xb, yb = batch
107        out = self(xb)
108        loss = self.loss_fn(out, yb)
109
110        # Accuracy
111        preds = out.argmax(dim=1)
112        acc = (preds == yb).float().mean()
113
114        self.log(
115            "train_loss",
116            loss,
117            on_step=False,
118            on_epoch=True,
119            prog_bar=True
120        )
121
122        self.log("train_acc", acc, on_step=False, on_epoch=True, prog_bar=True)
123
124        return loss
125
126    def validation_step(self, batch, batch_idx):
127        """
128        Compute validation loss and accuracy for a batch of validation images.
129
130        Parameters
131        ----------
132        batch : Tuple[torch.Tensor, torch.Tensor]
133            Subset (batch) of images coming from the validation dataloader.
134            Contains input images and corresponding labels.
135        batch_idx : int
136            Identifier of the batch within the current validation epoch.
137
138        Returns
139        -------
140        torch.Tensor
141            Validation accuracy for the current batch.
142        """
143        xb, yb = batch
144        out = self(xb)
145        loss = self.loss_fn(out, yb)
146        preds = out.argmax(dim=1)
147        acc = (preds == yb).float().mean()
148        self.log(
149            "val_loss",
150            loss,
151            on_step=False,
152            on_epoch=True,
153            prog_bar=False
154        )
155        self.log(
156            "val_acc",
157            acc,
158            on_step=False,
159            on_epoch=True,
160            prog_bar=True
161        )
162        return acc
163
164    def configure_optimizers(self):
165        """
166        Configure the optimizer for training.
167
168        Returns
169        -------
170        torch.optim.Adam
171            Adam optimizer configured with model parameters and learning rate
172            from hyperparameters.
173        """
174        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

Pretrained (ImageNet) ResNet18 adapted to the garbage dataset classification problem. It considers 6 classes: cardboard, glass, metal, paper, plastic, and trash.

Attributes
  • model (torchvision.models.resnet18): Pretrained ResNet18 model.
  • loss_fn (torch.nn.CrossEntropyLoss): Cross entropy loss function.
Examples
>>> model = GarbageClassifier(num_classes=6, lr=1e-3)
>>> trainer = pl.Trainer(max_epochs=10, accelerator="auto")
>>> trainer.fit(model, datamodule=data_module)
GarbageClassifier(num_classes, lr=0.001)
46    def __init__(self, num_classes, lr=1e-3):
47        """
48        Initialize the GarbageClassifier model.
49
50        Parameters
51        ----------
52        num_classes : int
53            Number of output classes for classification.
54        lr : float, optional
55            Learning rate for the optimizer (default is 1e-3).
56        """
57        super().__init__()
58        self.save_hyperparameters()
59        self.model = models.resnet18(
60            weights=models.ResNet18_Weights.IMAGENET1K_V1
61        )
62
63        # Freeze feature layers
64        for param in self.model.parameters():
65            param.requires_grad = False
66
67        # New classifier layer
68        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
69        self.loss_fn = nn.CrossEntropyLoss()

Initialize the GarbageClassifier model.

Parameters
  • num_classes (int): Number of output classes for classification.
  • lr (float, optional): Learning rate for the optimizer (default is 1e-3).
model
loss_fn
def forward(self, x):
71    def forward(self, x):
72        """
73        Forward pass through the model.
74
75        Parameters
76        ----------
77        x : torch.Tensor
78            Input tensor of images with shape (batch_size, channels, height,
79            width).
80
81        Returns
82        -------
83        torch.Tensor
84            Output logits with shape (batch_size, num_classes).
85        """
86        return self.model(x)

Forward pass through the model.

Parameters
  • x (torch.Tensor): Input tensor of images with shape (batch_size, channels, height, width).
Returns
  • torch.Tensor: Output logits with shape (batch_size, num_classes).
def training_step(self, batch, batch_idx):
 88    def training_step(self, batch, batch_idx):
 89        """
 90        Model parameters are updated according to the classification error
 91        of a subset of train images.
 92
 93        Parameters
 94        ----------
 95        batch : Tuple[torch.Tensor, torch.Tensor]
 96            Subset (batch) of images coming from the train dataloader.
 97            Contains input images and corresponding labels.
 98        batch_idx : int
 99            Identifier of the batch within the current epoch.
100
101        Returns
102        -------
103        torch.Tensor
104            Classification error (cross entropy loss) of trained image batch.
105        """
106        xb, yb = batch
107        out = self(xb)
108        loss = self.loss_fn(out, yb)
109
110        # Accuracy
111        preds = out.argmax(dim=1)
112        acc = (preds == yb).float().mean()
113
114        self.log(
115            "train_loss",
116            loss,
117            on_step=False,
118            on_epoch=True,
119            prog_bar=True
120        )
121
122        self.log("train_acc", acc, on_step=False, on_epoch=True, prog_bar=True)
123
124        return loss

Model parameters are updated according to the classification error on a batch of training images.

Parameters
  • batch (Tuple[torch.Tensor, torch.Tensor]): Subset (batch) of images coming from the train dataloader. Contains input images and corresponding labels.
  • batch_idx (int): Identifier of the batch within the current epoch.
Returns
  • torch.Tensor: Classification error (cross-entropy loss) of the training image batch.
def validation_step(self, batch, batch_idx):
126    def validation_step(self, batch, batch_idx):
127        """
128        Compute validation loss and accuracy for a batch of validation images.
129
130        Parameters
131        ----------
132        batch : Tuple[torch.Tensor, torch.Tensor]
133            Subset (batch) of images coming from the validation dataloader.
134            Contains input images and corresponding labels.
135        batch_idx : int
136            Identifier of the batch within the current validation epoch.
137
138        Returns
139        -------
140        torch.Tensor
141            Validation accuracy for the current batch.
142        """
143        xb, yb = batch
144        out = self(xb)
145        loss = self.loss_fn(out, yb)
146        preds = out.argmax(dim=1)
147        acc = (preds == yb).float().mean()
148        self.log(
149            "val_loss",
150            loss,
151            on_step=False,
152            on_epoch=True,
153            prog_bar=False
154        )
155        self.log(
156            "val_acc",
157            acc,
158            on_step=False,
159            on_epoch=True,
160            prog_bar=True
161        )
162        return acc

Compute validation loss and accuracy for a batch of validation images.

Parameters
  • batch (Tuple[torch.Tensor, torch.Tensor]): Subset (batch) of images coming from the validation dataloader. Contains input images and corresponding labels.
  • batch_idx (int): Identifier of the batch within the current validation epoch.
Returns
  • torch.Tensor: Validation accuracy for the current batch.
def configure_optimizers(self):
164    def configure_optimizers(self):
165        """
166        Configure the optimizer for training.
167
168        Returns
169        -------
170        torch.optim.Adam
171            Adam optimizer configured with model parameters and learning rate
172            from hyperparameters.
173        """
174        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

Configure the optimizer for training.

Returns
  • torch.optim.Adam: Adam optimizer configured with model parameters and learning rate from hyperparameters.