utils.custom_classes.GarbageDataModule
Garbage Dataset DataModule for PyTorch Lightning.
This module provides a LightningDataModule implementation for loading and preparing the garbage classification dataset with stratified train/test splits.
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Garbage Dataset DataModule for PyTorch Lightning.

This module provides a LightningDataModule implementation for loading and
preparing the garbage classification dataset with stratified train/test splits.
"""
__docformat__ = "numpy"

import numpy as np
import pytorch_lightning as pl
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, models

from utils import config as cfg


class GarbageDataModule(pl.LightningDataModule):
    """
    PyTorch Lightning DataModule for the Garbage Classification Dataset.

    This DataModule handles loading, splitting, and creating dataloaders for
    the garbage classification dataset. It performs a stratified 90/10
    train/test split and applies the ResNet18 ImageNet preprocessing
    transforms.

    Attributes
    ----------
    batch_size : int
        Number of samples per batch for training.
    num_workers : int
        Number of subprocesses to use for data loading.
    transform : torchvision.transforms._presets.ImageClassification
        Image preprocessing transform from the ResNet18 ImageNet weights.
    train_dataset : torch.utils.data.Subset
        Training dataset subset.
    test_dataset : torch.utils.data.Subset
        Test/validation dataset subset.
    train_idx : numpy.ndarray
        Indices of samples in the training set.
    test_idx : numpy.ndarray
        Indices of samples in the test set.
    num_classes : int
        Number of classes in the dataset.

    Examples
    --------
    >>> data_module = GarbageDataModule(batch_size=32, num_workers=4)
    >>> data_module.setup()
    >>> train_loader = data_module.train_dataloader()
    >>> val_loader = data_module.val_dataloader()
    """

    def __init__(self, batch_size=32, num_workers=4):
        """
        Initialize the GarbageDataModule.

        Parameters
        ----------
        batch_size : int, optional
            Number of samples per batch for training (default is 32).
        num_workers : int, optional
            Number of subprocesses to use for data loading (default is 4).
        """
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Preprocessing bundled with the pretrained weights: resize, center
        # crop to 224x224, and ImageNet mean/std normalization.
        self.transform = models.ResNet18_Weights.IMAGENET1K_V1.transforms()

    def setup(self, stage=None):
        """
        Prepare the dataset by loading and splitting into train and test sets.

        Loads the full dataset from the configured path, performs a stratified
        90/10 train/test split to ensure balanced class distribution, and
        creates dataset subsets for training and validation.

        Parameters
        ----------
        stage : str, optional
            Current stage ('fit', 'validate', 'test', or 'predict').
            Not used in this implementation (default is None).

        Notes
        -----
        The dataset is split using a stratified approach with random_state=42
        for reproducibility. The split ratio is 90% training and 10% testing.
        """
        # Load the full dataset; ImageFolder infers class labels from the
        # subdirectory names under DATASET_PATH.
        full_dataset = datasets.ImageFolder(
            cfg.DATASET_PATH,
            transform=self.transform
        )
        # ImageFolder already stores one class index per sample, so read the
        # labels from .targets rather than iterating the dataset, which would
        # decode and transform every image just to collect its label.
        targets = full_dataset.targets
        self.num_classes = cfg.NUM_CLASSES

        # Stratified 90/10 split over sample indices.
        train_idx, test_idx = train_test_split(
            np.arange(len(targets)),
            test_size=0.1,
            stratify=targets,
            random_state=42
        )

        self.train_dataset = Subset(full_dataset, train_idx)
        self.test_dataset = Subset(full_dataset, test_idx)
        self.train_idx = train_idx
        self.test_idx = test_idx

    def train_dataloader(self):
        """
        Create and return the training dataloader.

        Returns
        -------
        torch.utils.data.DataLoader
            DataLoader for the training dataset with shuffling enabled.

        Notes
        -----
        The dataloader uses the configured batch_size and num_workers,
        and shuffles the data at each epoch.
        """
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          shuffle=True,
                          num_workers=self.num_workers)

    def val_dataloader(self):
        """
        Create and return the validation dataloader.

        Returns
        -------
        torch.utils.data.DataLoader
            DataLoader for the validation/test dataset without shuffling.

        Notes
        -----
        The dataloader uses a fixed batch_size of 1000 so validation runs in
        a few large batches, with num_workers from configuration. Shuffling
        is disabled to ensure consistent validation metrics.
        """
        return DataLoader(self.test_dataset,
                          batch_size=1000,
                          shuffle=False,
                          num_workers=self.num_workers)
```
class GarbageDataModule(pl.LightningDataModule):
PyTorch Lightning DataModule for the Garbage Classification Dataset.
This DataModule handles loading, splitting, and creating dataloaders for the garbage classification dataset. It performs a stratified 90/10 train/test split and applies ResNet18 ImageNet preprocessing transforms.
Attributes
- batch_size (int): Number of samples per batch for training.
- num_workers (int): Number of subprocesses to use for data loading.
- transform (torchvision.transforms._presets.ImageClassification): Image preprocessing transform from the ResNet18 ImageNet weights.
- train_dataset (torch.utils.data.Subset): Training dataset subset.
- test_dataset (torch.utils.data.Subset): Test/validation dataset subset.
- train_idx (numpy.ndarray): Indices of samples in the training set.
- test_idx (numpy.ndarray): Indices of samples in the test set.
- num_classes (int): Number of classes in the dataset.
Examples
>>> data_module = GarbageDataModule(batch_size=32, num_workers=4)
>>> data_module.setup()
>>> train_loader = data_module.train_dataloader()
>>> val_loader = data_module.val_dataloader()
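In a Trainer-driven run, `setup()` and the dataloader hooks are invoked automatically. The sketch below shows a minimal end-to-end fit; the import path and the `GarbageClassifier` LightningModule are illustrative assumptions, not part of this package:

```python
import pytorch_lightning as pl
import torch
from torch import nn
from torchvision import models

from utils import config as cfg
from utils.custom_classes.GarbageDataModule import GarbageDataModule  # assumed import path


class GarbageClassifier(pl.LightningModule):
    """Illustrative ResNet18 fine-tuning module (not part of this package)."""

    def __init__(self, num_classes=cfg.NUM_CLASSES):
        super().__init__()
        backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
        self.model = backbone
        self.criterion = nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx):
        images, labels = batch
        loss = self.criterion(self.model(images), labels)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        self.log("val_loss", self.criterion(self.model(images), labels))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)


if __name__ == "__main__":
    data_module = GarbageDataModule(batch_size=32, num_workers=4)
    # The Trainer calls setup() and the dataloader hooks itself.
    trainer = pl.Trainer(max_epochs=5)
    trainer.fit(GarbageClassifier(), datamodule=data_module)
```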
def __init__(self, batch_size=32, num_workers=4):
Initialize the GarbageDataModule.
Parameters
- batch_size (int, optional): Number of samples per batch for training (default is 32).
- num_workers (int, optional): Number of subprocesses to use for data loading (default is 4).
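The stored transform can be inspected directly. The preset bundled with `ResNet18_Weights.IMAGENET1K_V1` resizes, center-crops to 224x224, and normalizes with ImageNet statistics; a small sketch using a dummy PIL image:

```python
from PIL import Image
from torchvision import models

# The weights bundle their own preprocessing: resize to 256, center-crop to
# 224x224, scale to [0, 1], then normalize with ImageNet mean/std.
preprocess = models.ResNet18_Weights.IMAGENET1K_V1.transforms()
print(preprocess)

dummy = Image.new("RGB", (640, 480))  # stand-in for a dataset image
tensor = preprocess(dummy)
print(tensor.shape)  # torch.Size([3, 224, 224])
```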
def setup(self, stage=None):
Prepare the dataset by loading and splitting into train and test sets.
Loads the full dataset from the configured path, performs a stratified 90/10 train/test split to ensure balanced class distribution, and creates dataset subsets for training and validation.
Parameters
- stage (str, optional): Current stage ('fit', 'validate', 'test', or 'predict'). Not used in this implementation (default is None).
Notes
The dataset is split using a stratified approach with random_state=42 for reproducibility. The split ratio is 90% training and 10% testing.
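To confirm the stratification empirically, per-class counts in each subset can be compared after `setup()`. A small sketch, assuming `cfg.DATASET_PATH` points at a populated dataset (the import path is an assumption as well):

```python
import numpy as np

from utils.custom_classes.GarbageDataModule import GarbageDataModule  # assumed import path

dm = GarbageDataModule()
dm.setup()

# Each Subset keeps a reference to the underlying ImageFolder, whose
# .targets attribute holds one class index per sample.
all_targets = np.asarray(dm.train_dataset.dataset.targets)

train_counts = np.bincount(all_targets[dm.train_idx], minlength=dm.num_classes)
test_counts = np.bincount(all_targets[dm.test_idx], minlength=dm.num_classes)

# With stratify= the class fractions of both subsets should closely match
# the full dataset's distribution.
print("train fractions:", train_counts / train_counts.sum())
print("test fractions: ", test_counts / test_counts.sum())
```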
def train_dataloader(self):
Create and return the training dataloader.
Returns
- torch.utils.data.DataLoader: DataLoader for the training dataset with shuffling enabled.
Notes
The dataloader uses the configured batch_size and num_workers, and shuffles the data at each epoch.
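A quick sanity check is to pull a single batch and inspect its shapes; with the ImageNet preprocessing, images arrive as float tensors of shape `[batch_size, 3, 224, 224]`. A sketch (import path assumed as before):

```python
from utils.custom_classes.GarbageDataModule import GarbageDataModule  # assumed import path

dm = GarbageDataModule(batch_size=32, num_workers=0)  # workers=0 eases debugging
dm.setup()

images, labels = next(iter(dm.train_dataloader()))
print(images.shape)  # torch.Size([32, 3, 224, 224])
print(labels.shape)  # torch.Size([32])
```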
def val_dataloader(self):
Create and return the validation dataloader.
Returns
- torch.utils.data.DataLoader: DataLoader for the validation/test dataset without shuffling.
Notes
The dataloader uses a fixed batch_size of 1000 so that validation completes in a few large batches, with num_workers from configuration; note this requires enough memory to hold 1000 preprocessed 224x224 images at once. Shuffling is disabled to ensure consistent validation metrics.
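Because the order is fixed and the split indices are stored on the module, validation outputs can be traced back to the original files. A sketch relying on the `samples` attribute of the underlying `ImageFolder` (import path assumed as before):

```python
from utils.custom_classes.GarbageDataModule import GarbageDataModule  # assumed import path

dm = GarbageDataModule()
dm.setup()

# val_dataloader() serves test_dataset in index order, so the i-th
# validation sample corresponds to test_idx[i]. ImageFolder.samples holds
# (file_path, class_index) pairs for every sample.
val_paths = [dm.test_dataset.dataset.samples[i][0] for i in dm.test_idx]
print(val_paths[:3])
```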