source.train
Training Script for Garbage Classification Model.
This script orchestrates the training process for a garbage classification model using PyTorch Lightning. It can be used both as a standalone script and as an importable module. Includes carbon emissions tracking.
Usage
Command line:
    $ uv run source/train.py
As a module:
    from source.train import train_model
    train_model(batch_size=32, lr=1e-3, max_epochs=10)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Training Script for Garbage Classification Model.

This script orchestrates the training process for a garbage classification
model using PyTorch Lightning. It can be used both as a standalone script
and as an importable module. Includes carbon emissions tracking.

Usage
-----
Command line:
    $ uv run source/train.py

As a module:
    from source.train import train_model
    train_model(batch_size=32, lr=1e-3, max_epochs=10)
"""
__docformat__ = "numpy"

from pathlib import Path
from typing import Callable, Optional

import pytorch_lightning as pl
from codecarbon import EmissionsTracker

from source.utils import config as cfg
from source.utils.carbon_utils import (
    kg_co2_to_car_distance,
    format_car_distance,
)
from source.utils.custom_classes.GarbageDataModule import GarbageDataModule
from source.utils.custom_classes.GarbageClassifier import GarbageClassifier
from source.utils.custom_classes.LossCurveCallback import LossCurveCallback


def train_model(
    batch_size: int = 32,
    lr: float = 1e-3,
    max_epochs: Optional[int] = None,
    model_save_path: Optional[str] = None,
    loss_curves_dir: Optional[str] = None,
    track_carbon: bool = True,
    progress_callback: Optional[Callable[[str], None]] = None,
) -> dict:
    """
    Train the garbage classification model using PyTorch Lightning.

    Orchestrates the complete training pipeline including data loading, model
    initialization, training, and evaluation. Optionally tracks
    carbon emissions during training and provides progress callbacks
    for UI integration.

    Parameters
    ----------
    batch_size : int, default=32
        Batch size for training and validation data loaders.
    lr : float, default=1e-3
        Learning rate for the optimizer.
    max_epochs : int, optional
        Maximum number of training epochs. If None, uses cfg.MAX_EPOCHS.
        Default is None.
    model_save_path : str, optional
        Path to save the trained model checkpoint.
        If None, uses cfg.MODEL_PATH.
        Default is None.
    loss_curves_dir : str, optional
        Directory to save loss curve visualizations.
        If None, uses cfg.LOSS_CURVES_PATH.
        Default is None.
    track_carbon : bool, default=True
        Whether to track carbon emissions during training using CodeCarbon.
    progress_callback : callable, optional
        Callback function to report training progress.
        Called with a message string for UI updates.
        Default is None (no progress reporting).

    Returns
    -------
    dict
        Dictionary containing:
        - 'trainer': pl.Trainer
            PyTorch Lightning trainer instance.
        - 'model': GarbageClassifier
            Trained model instance.
        - 'data_module': GarbageDataModule
            Data module used for training.
        - 'emissions': dict or None
            Carbon emissions data with keys: 'emissions_kg', 'emissions_g',
            'car_distance_km', 'car_distance_m', 'car_distance_formatted',
            'duration_seconds'. None if track_carbon=False.
        - 'metrics': dict
            Training and validation metrics with keys: 'train_acc', 'val_acc',
            'train_loss', 'val_loss'.

    Raises
    ------
    Exception
        Any exception during training is re-raised after stopping
        emissions tracker if applicable.

    Notes
    -----
    Carbon emissions are converted to equivalent car driving distance for
    intuitive understanding. Model checkpoint is automatically saved to disk.
    If progress_callback is provided, it receives status updates at key points
    during training initialization and completion.
    """
    # Use config defaults if not provided
    if max_epochs is None:
        max_epochs = cfg.MAX_EPOCHS
    if model_save_path is None:
        model_save_path = cfg.MODEL_PATH
    if loss_curves_dir is None:
        loss_curves_dir = cfg.LOSS_CURVES_PATH

    # Initialize emissions tracker
    emissions_data = None
    if track_carbon:
        tracker = EmissionsTracker(
            project_name="garbage_classifier_training",
            output_dir=str(Path(model_save_path).parent),
            log_level="warning",  # Reduce console output
        )
        tracker.start()

    try:
        # Initialize data module
        if progress_callback:
            progress_callback("Initializing data module...")
        data_module = GarbageDataModule(batch_size=batch_size)
        data_module.setup()

        # Initialize model
        if progress_callback:
            progress_callback("Creating model...")
        model = GarbageClassifier(num_classes=data_module.num_classes, lr=lr)

        # Setup callback
        loss_curve_callback = LossCurveCallback(save_dir=loss_curves_dir)

        # Configure trainer
        if progress_callback:
            progress_callback(f"Starting training for {max_epochs} epochs...")
        trainer = pl.Trainer(
            max_epochs=max_epochs,
            accelerator="auto",
            devices=1,
            callbacks=[loss_curve_callback],
            num_sanity_val_steps=0,
        )

        # Train
        trainer.fit(model, datamodule=data_module)

        # Extract final metrics, converting tensors to floats where possible
        metrics = {}
        if trainer.callback_metrics:
            for key in ("train_acc", "val_acc", "train_loss", "val_loss"):
                value = trainer.callback_metrics.get(key)
                metrics[key] = value.item() if hasattr(value, "item") else value

        # Save model
        if progress_callback:
            progress_callback("Saving model...")
        Path(model_save_path).parent.mkdir(parents=True, exist_ok=True)
        trainer.save_checkpoint(model_save_path)

        # Stop emissions tracking
        if track_carbon:
            emissions_kg = tracker.stop()
            car_distances = kg_co2_to_car_distance(emissions_kg)
            emissions_data = {
                "emissions_kg": emissions_kg,
                "emissions_g": emissions_kg * 1000,
                "car_distance_km": car_distances["distance_km"],
                "car_distance_m": car_distances["distance_m"],
                "car_distance_formatted": format_car_distance(emissions_kg),
                # _total_duration is a private tracker attribute, so guard it
                "duration_seconds": (
                    tracker._total_duration.total_seconds()
                    if hasattr(tracker, "_total_duration")
                    else None
                ),
            }

        # Adjacent string literals are used instead of backslash
        # continuations, which would embed the indentation in the message.
        if progress_callback:
            msg = f"✅ Training complete! Model saved at {model_save_path}"
            if emissions_data:
                msg += (
                    "\n🌍 Carbon footprint: "
                    f"{emissions_data['emissions_g']:.2f}g CO₂eq"
                )
                msg += (
                    "\n🚗 Equivalent to driving: "
                    f"{emissions_data['car_distance_formatted']}"
                )
            progress_callback(msg)

        print(f"Model saved at {model_save_path}")
        if emissions_data:
            print(
                f"Carbon emissions: {emissions_data['emissions_kg']:.6f}kg "
                f"CO₂eq ({emissions_data['emissions_g']:.2f}g)"
            )
            print(
                "Equivalent to driving: "
                f"{emissions_data['car_distance_formatted']}"
            )

        return {
            "trainer": trainer,
            "model": model,
            "data_module": data_module,
            "emissions": emissions_data,
            "metrics": metrics,
        }

    except Exception:
        if track_carbon:
            tracker.stop()
        raise  # Bare raise preserves the original traceback


# ========================
# CLI Entry Point
# ========================
if __name__ == "__main__":
    # Main entry point for the training script when run from command line.
    # Executes the complete training pipeline using default configuration
    # from the config module, then displays final metrics and carbon
    # emissions statistics.
    print("Starting training with default configuration...")
    result = train_model()
    if result["emissions"]:
        print(
            "\n🌍 Total carbon footprint: "
            f"{result['emissions']['emissions_g']:.2f}g CO₂eq"
        )
        print(
            "🚗 Equivalent to driving: "
            f"{result['emissions']['car_distance_formatted']}"
        )
    if result["metrics"]:
        print("\n📊 Final Metrics:")
        if result["metrics"].get("train_acc") is not None:
            print(f"  Train Accuracy: {result['metrics']['train_acc']:.4f}")
        if result["metrics"].get("val_acc") is not None:
            print(f"  Validation Accuracy: {result['metrics']['val_acc']:.4f}")
Train the garbage classification model using PyTorch Lightning.
Orchestrates the complete training pipeline including data loading, model initialization, training, and evaluation. Optionally tracks carbon emissions during training and provides progress callbacks for UI integration.
Parameters
- batch_size (int, default=32): Batch size for training and validation data loaders.
- lr (float, default=1e-3): Learning rate for the optimizer.
- max_epochs (int, optional): Maximum number of training epochs. If None, uses cfg.MAX_EPOCHS. Default is None.
- model_save_path (str, optional): Path to save the trained model checkpoint. If None, uses cfg.MODEL_PATH. Default is None.
- loss_curves_dir (str, optional): Directory to save loss curve visualizations. If None, uses cfg.LOSS_CURVES_PATH. Default is None.
- track_carbon (bool, default=True): Whether to track carbon emissions during training using CodeCarbon.
- progress_callback (callable, optional): Callback function to report training progress. Called with a message string for UI updates. Default is None (no progress reporting).
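The progress_callback contract described above (a callable that receives one message string) could be exercised roughly as follows. This is a minimal sketch, not part of source.train: ProgressLog and fake_training_run are hypothetical names, and fake_training_run only imitates the order of status messages that train_model emits, without training anything.

```python
from typing import Callable, List, Optional


class ProgressLog:
    """Hypothetical collector: stores each progress message it receives."""

    def __init__(self) -> None:
        self.messages: List[str] = []

    def __call__(self, message: str) -> None:
        self.messages.append(message)


def fake_training_run(
    max_epochs: int,
    progress_callback: Optional[Callable[[str], None]] = None,
) -> None:
    """Hypothetical stand-in for train_model that only emits status messages."""
    steps = [
        "Initializing data module...",
        "Creating model...",
        f"Starting training for {max_epochs} epochs...",
        "Saving model...",
    ]
    for step in steps:
        # Mirror the guard in train_model: do nothing when no callback is set
        if progress_callback:
            progress_callback(step)


log = ProgressLog()
fake_training_run(max_epochs=3, progress_callback=log)
print("\n".join(log.messages))
```

With the real function, the same object would presumably be passed as train_model(progress_callback=log), for example to feed a UI status area.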
Returns
- dict: Dictionary containing:
  - 'trainer' (pl.Trainer): PyTorch Lightning trainer instance.
  - 'model' (GarbageClassifier): Trained model instance.
  - 'data_module' (GarbageDataModule): Data module used for training.
  - 'emissions' (dict or None): Carbon emissions data with keys: 'emissions_kg', 'emissions_g', 'car_distance_km', 'car_distance_m', 'car_distance_formatted', 'duration_seconds' (which may itself be None). None if track_carbon=False.
  - 'metrics' (dict): Final training and validation metrics with keys: 'train_acc', 'val_acc', 'train_loss', 'val_loss', converted to floats. A value is None if that metric was not logged; the dict is empty if the trainer reported no metrics.
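Because 'emissions' can be None and individual metrics can be missing, callers may want to read the returned dictionary defensively. The sketch below shows one way to do that; summarize_result is a hypothetical helper, and the sample values are illustrative only, not measured results.

```python
from typing import Any, Dict, List


def summarize_result(result: Dict[str, Any]) -> List[str]:
    """Build printable summary lines from a train_model-style result dict."""
    lines: List[str] = []

    emissions = result.get("emissions")
    if emissions:  # None when track_carbon=False
        lines.append(f"Carbon footprint: {emissions['emissions_g']:.2f}g CO2eq")
        lines.append(
            f"Equivalent to driving: {emissions['car_distance_formatted']}"
        )

    labels = (("train_acc", "Train Accuracy"), ("val_acc", "Validation Accuracy"))
    metrics = result.get("metrics") or {}
    for key, label in labels:
        value = metrics.get(key)
        if value is not None:  # skip metrics that were never logged
            lines.append(f"{label}: {value:.4f}")
    return lines


# Illustrative values only; a real dict would come from train_model().
sample = {
    "emissions": None,
    "metrics": {"train_acc": 0.9, "val_acc": None},
}
print(summarize_result(sample))
```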
Raises
- Exception: Any exception raised during training is re-raised after the emissions tracker has been stopped (when track_carbon=True).
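The stop-then-re-raise behaviour can be illustrated in isolation. The following is a sketch under stated assumptions: FakeTracker is a hypothetical stand-in with start/stop methods (it is not the CodeCarbon class), and run_with_tracker is a hypothetical wrapper that mirrors the try/except structure of train_model.

```python
from typing import Any, Callable, Optional, Tuple


class FakeTracker:
    """Hypothetical stand-in for an emissions tracker with start/stop."""

    def __init__(self) -> None:
        self.running = False

    def start(self) -> None:
        self.running = True

    def stop(self) -> float:
        self.running = False
        return 0.0  # placeholder emissions value in kg


def run_with_tracker(
    work: Callable[[], Any],
    tracker: Optional[FakeTracker] = None,
) -> Tuple[Any, Optional[float]]:
    """Run work(); on failure, stop the tracker and re-raise unchanged."""
    if tracker is not None:
        tracker.start()
    try:
        result = work()
        emissions = tracker.stop() if tracker is not None else None
        return result, emissions
    except Exception:
        if tracker is not None:
            tracker.stop()
        raise  # bare raise keeps the original traceback


def failing_work() -> None:
    raise ValueError("simulated training failure")


tracker = FakeTracker()
try:
    run_with_tracker(failing_work, tracker)
except ValueError as err:
    print(f"caught: {err}; tracker running: {tracker.running}")
```

The caller still sees the original exception type and message, while the tracker is guaranteed not to be left running.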
Notes
Carbon emissions are converted to an equivalent car driving distance to make them easier to interpret. The model checkpoint is saved to model_save_path once training finishes. If progress_callback is provided, it receives a status message before data module initialization, model creation, training, and checkpoint saving, plus a final summary on completion.
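The conversion to driving distance is implemented in source.utils.carbon_utils, which is not shown on this page. The sketch below only illustrates the idea: the emission factor is an assumed placeholder, the function names carry a _sketch suffix to mark them as hypothetical, and only the 'distance_km' and 'distance_m' keys are taken from how train_model uses the real helper.

```python
from typing import Dict

# Assumed emission factor in kg CO2 per km, chosen for illustration only;
# the value used by source.utils.carbon_utils may differ.
ASSUMED_KG_CO2_PER_KM = 0.12


def kg_co2_to_car_distance_sketch(
    emissions_kg: float,
    kg_per_km: float = ASSUMED_KG_CO2_PER_KM,
) -> Dict[str, float]:
    """Convert emissions in kg CO2 to an equivalent driving distance."""
    distance_km = emissions_kg / kg_per_km
    return {"distance_km": distance_km, "distance_m": distance_km * 1000}


def format_car_distance_sketch(
    emissions_kg: float,
    kg_per_km: float = ASSUMED_KG_CO2_PER_KM,
) -> str:
    """Format the distance in km when it is at least 1 km, else in metres."""
    distances = kg_co2_to_car_distance_sketch(emissions_kg, kg_per_km)
    if distances["distance_km"] >= 1:
        return f"{distances['distance_km']:.2f} km"
    return f"{distances['distance_m']:.1f} m"


print(format_car_distance_sketch(0.12))
print(format_car_distance_sketch(0.006))
```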