utils.custom_classes.GarbageDataModule

Garbage Dataset DataModule for PyTorch Lightning.

This module provides a LightningDataModule implementation for loading and preparing the garbage classification dataset with stratified train/test splits.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Garbage Dataset DataModule for PyTorch Lightning.

This module provides a LightningDataModule implementation for loading and
preparing the garbage classification dataset with stratified train/test splits.
"""
__docformat__ = "numpy"

import numpy as np
import pytorch_lightning as pl
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, models

from utils import config as cfg


class GarbageDataModule(pl.LightningDataModule):
    """
    PyTorch Lightning DataModule for Garbage Classification Dataset.

    This DataModule handles loading, splitting, and creating dataloaders for
    the garbage classification dataset. It performs a stratified 90/10
    train/test split and applies ResNet18 ImageNet preprocessing transforms.

    Attributes
    ----------
    batch_size : int
        Number of samples per batch for training.
    num_workers : int
        Number of subprocesses to use for data loading.
    transform : callable
        Image preprocessing preset bundled with the ResNet18 ImageNet weights.
    train_dataset : torch.utils.data.Subset
        Training dataset subset.
    test_dataset : torch.utils.data.Subset
        Test/validation dataset subset.
    train_idx : numpy.ndarray
        Indices of samples in the training set.
    test_idx : numpy.ndarray
        Indices of samples in the test set.
    num_classes : int
        Number of classes in the dataset.

    Examples
    --------
    >>> data_module = GarbageDataModule(batch_size=32, num_workers=4)
    >>> data_module.setup()
    >>> train_loader = data_module.train_dataloader()
    >>> val_loader = data_module.val_dataloader()
    """

    def __init__(self, batch_size=32, num_workers=4):
        """
        Initialize the GarbageDataModule.

        Parameters
        ----------
        batch_size : int, optional
            Number of samples per batch for training (default is 32).
        num_workers : int, optional
            Number of subprocesses to use for data loading (default is 4).
        """
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Preprocessing preset shipped with the ResNet18 ImageNet weights
        # (resize, center-crop to 224, tensor conversion, normalization).
        self.transform = models.ResNet18_Weights.IMAGENET1K_V1.transforms()

    def setup(self, stage=None):
        """
        Prepare the dataset by loading and splitting into train and test sets.

        Loads the full dataset from the configured path, performs a stratified
        90/10 train/test split to ensure balanced class distribution, and
        creates dataset subsets for training and validation.

        Parameters
        ----------
        stage : str, optional
            Current stage ('fit', 'validate', 'test', or 'predict').
            Not used in this implementation (default is None).

        Notes
        -----
        The dataset is split using a stratified approach with random_state=42
        for reproducibility. The split ratio is 90% training and 10% testing.
        """
        # Load the full dataset from the configured path.
        full_dataset = datasets.ImageFolder(
            cfg.DATASET_PATH,
            transform=self.transform
        )
        # Read labels from ImageFolder's precomputed list; iterating the
        # dataset instead would load and transform every image just to
        # recover its label.
        targets = full_dataset.targets
        self.num_classes = cfg.NUM_CLASSES

        # Stratified 90/10 split: every class keeps the same proportion
        # in both subsets.
        train_idx, test_idx = train_test_split(
            np.arange(len(targets)),
            test_size=0.1,
            stratify=targets,
            random_state=42
        )

        self.train_dataset = Subset(full_dataset, train_idx)
        self.test_dataset = Subset(full_dataset, test_idx)
        self.train_idx = train_idx
        self.test_idx = test_idx

    def train_dataloader(self):
        """
        Create and return the training dataloader.

        Returns
        -------
        torch.utils.data.DataLoader
            DataLoader for the training dataset with shuffling enabled.

        Notes
        -----
        The dataloader uses the configured batch_size and num_workers,
        and shuffles the data at each epoch.
        """
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          shuffle=True,
                          num_workers=self.num_workers)

    def val_dataloader(self):
        """
        Create and return the validation dataloader.

        Returns
        -------
        torch.utils.data.DataLoader
            DataLoader for the validation/test dataset without shuffling.

        Notes
        -----
        The dataloader uses a fixed batch_size of 1000 for faster validation,
        with num_workers from configuration. Shuffling is disabled to ensure
        consistent validation metrics.
        """
        return DataLoader(self.test_dataset,
                          batch_size=1000,
                          shuffle=False,
                          num_workers=self.num_workers)

class GarbageDataModule(pytorch_lightning.core.datamodule.LightningDataModule):

PyTorch Lightning DataModule for Garbage Classification Dataset.

This DataModule handles loading, splitting, and creating dataloaders for the garbage classification dataset. It performs a stratified 90/10 train/test split and applies ResNet18 ImageNet preprocessing transforms.

Attributes
  • batch_size (int): Number of samples per batch for training.
  • num_workers (int): Number of subprocesses to use for data loading.
  • transform (callable): Image preprocessing preset bundled with the ResNet18 ImageNet weights.
  • train_dataset (torch.utils.data.Subset): Training dataset subset.
  • test_dataset (torch.utils.data.Subset): Test/validation dataset subset.
  • train_idx (numpy.ndarray): Indices of samples in the training set.
  • test_idx (numpy.ndarray): Indices of samples in the test set.
  • num_classes (int): Number of classes in the dataset.
Examples
>>> data_module = GarbageDataModule(batch_size=32, num_workers=4)
>>> data_module.setup()
>>> train_loader = data_module.train_dataloader()
>>> val_loader = data_module.val_dataloader()
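The module does not include the training loop itself. As a minimal sketch, assuming a LightningModule subclass named GarbageClassifier (hypothetical, not part of this module), the DataModule plugs straight into a Trainer:

>>> import pytorch_lightning as pl
>>> model = GarbageClassifier(num_classes=data_module.num_classes)  # hypothetical model
>>> trainer = pl.Trainer(max_epochs=5)
>>> trainer.fit(model, datamodule=data_module)

When a datamodule is handed to Trainer.fit this way, Lightning invokes its setup() hook automatically, so the manual setup() call above is only needed for standalone use.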
GarbageDataModule(batch_size=32, num_workers=4)

Initialize the GarbageDataModule.

Parameters
  • batch_size (int, optional): Number of samples per batch for training (default is 32).
  • num_workers (int, optional): Number of subprocesses to use for data loading (default is 4).
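The preprocessing pipeline is not assembled by hand; __init__ reuses the preset bundled with the ResNet18 ImageNet weights (the weights-enum API assumed here is available in torchvision 0.13 and later). For this checkpoint the preset resizes the shorter side to 256, center-crops to 224, converts to a tensor, and normalizes with the ImageNet mean and std. The same callable can also be applied to a single PIL image:

>>> from torchvision import models
>>> preprocess = models.ResNet18_Weights.IMAGENET1K_V1.transforms()
>>> # tensor = preprocess(pil_image)  # shape (3, 224, 224)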
def setup(self, stage=None):

Prepare the dataset by loading and splitting into train and test sets.

Loads the full dataset from the configured path, performs a stratified 90/10 train/test split to ensure balanced class distribution, and creates dataset subsets for training and validation.

Parameters
  • stage (str, optional): Current stage ('fit', 'validate', 'test', or 'predict'). Not used in this implementation (default is None).
Notes

The dataset is split using a stratified approach with random_state=42 for reproducibility. The split ratio is 90% training and 10% testing.
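To see concretely what stratification buys, here is a self-contained sketch with synthetic, deliberately imbalanced labels (not data from this project); the 10% test split preserves the 6:3:1 class ratio exactly:

>>> import numpy as np
>>> from sklearn.model_selection import train_test_split
>>> labels = np.repeat([0, 1, 2], [600, 300, 100])
>>> tr, te = train_test_split(np.arange(len(labels)), test_size=0.1,
...                           stratify=labels, random_state=42)
>>> np.bincount(labels[te])
array([60, 30, 10])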

def train_dataloader(self):

Create and return the training dataloader.

Returns
  • torch.utils.data.DataLoader: DataLoader for the training dataset with shuffling enabled.
Notes

The dataloader uses the configured batch_size and num_workers, and shuffles the data at each epoch.
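As a quick sanity check on the loader (a sketch; the shapes below assume the default batch_size of 32, the preset's 224x224 crop, and that setup() has already run):

>>> images, labels = next(iter(data_module.train_dataloader()))
>>> images.shape
torch.Size([32, 3, 224, 224])
>>> labels.shape
torch.Size([32])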

def val_dataloader(self):

Create and return the validation dataloader.

Returns
  • torch.utils.data.DataLoader: DataLoader for the validation/test dataset without shuffling.
Notes

The dataloader uses a fixed batch_size of 1000 for faster validation, with num_workers from configuration. Shuffling is disabled to ensure consistent validation metrics.
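Because shuffling is disabled, two passes over the loader yield the batches in identical order, which is what keeps validation metrics comparable across epochs. A small sketch to confirm:

>>> import torch
>>> loader = data_module.val_dataloader()
>>> first = [labels for _, labels in loader]
>>> second = [labels for _, labels in loader]
>>> all(torch.equal(a, b) for a, b in zip(first, second))
True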