source.train

Training Script for Garbage Classification Model.

This script orchestrates the training process for a garbage classification model using PyTorch Lightning. It can be used both as a standalone script and as an importable module. Includes carbon emissions tracking.

Usage

Command line:
    $ uv run source/train.py

As a module:
    from source.train import train_model
    train_model(batch_size=32, lr=1e-3, max_epochs=10)

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Training Script for Garbage Classification Model.

This script orchestrates the training process for a garbage classification
model using PyTorch Lightning. It can be used both as a standalone script
and as an importable module. Includes carbon emissions tracking.

Usage
-----
Command line:
    $ uv run source/train.py

As a module:
    from source.train import train_model
    train_model(batch_size=32, lr=1e-3, max_epochs=10)
"""
__docformat__ = "numpy"

import pytorch_lightning as pl
from pathlib import Path
from codecarbon import EmissionsTracker
from source.utils import config as cfg
from source.utils.carbon_utils import (
    kg_co2_to_car_distance,
    format_car_distance
)
from source.utils.custom_classes.GarbageDataModule import GarbageDataModule
from source.utils.custom_classes.GarbageClassifier import GarbageClassifier
from source.utils.custom_classes.LossCurveCallback import LossCurveCallback


def train_model(
    batch_size: int = 32,
    lr: float = 1e-3,
    max_epochs: int = None,
    model_save_path: str = None,
    loss_curves_dir: str = None,
    track_carbon: bool = True,
    progress_callback=None,
):
    """
    Train the garbage classification model using PyTorch Lightning.

    Orchestrates the complete training pipeline including data loading, model
    initialization, training, and evaluation. Optionally tracks
    carbon emissions during training and provides progress callbacks
    for UI integration.

    Parameters
    ----------
    batch_size : int, default=32
        Batch size for training and validation data loaders.
    lr : float, default=1e-3
        Learning rate for the optimizer.
    max_epochs : int, optional
        Maximum number of training epochs. If None, uses cfg.MAX_EPOCHS.
        Default is None.
    model_save_path : str, optional
        Path to save the trained model checkpoint.
        If None, uses cfg.MODEL_PATH.
        Default is None.
    loss_curves_dir : str, optional
        Directory to save loss curve visualizations.
        If None, uses cfg.LOSS_CURVES_PATH.
        Default is None.
    track_carbon : bool, default=True
        Whether to track carbon emissions during training using CodeCarbon.
    progress_callback : callable, optional
        Callback function to report training progress.
        Called with a message string for UI updates.
        Default is None (no progress reporting).

    Returns
    -------
    dict
        Dictionary containing:
        - 'trainer': pl.Trainer
            PyTorch Lightning trainer instance.
        - 'model': GarbageClassifier
            Trained model instance.
        - 'data_module': GarbageDataModule
            Data module used for training.
        - 'emissions': dict or None
            Carbon emissions data with keys: 'emissions_kg', 'emissions_g',
            'car_distance_km', 'car_distance_m', 'car_distance_formatted',
            'duration_seconds'. None if track_carbon=False.
        - 'metrics': dict
            Training and validation metrics with keys: 'train_acc', 'val_acc',
            'train_loss', 'val_loss'.

    Raises
    ------
    Exception
        Any exception during training is re-raised after stopping
        emissions tracker if applicable.

    Notes
    -----
    Carbon emissions are converted to equivalent car driving distance for
    intuitive understanding. Model checkpoint is automatically saved to disk.
    If progress_callback is provided, it receives status updates at key points
    during training initialization and completion.
    """
    # Use config defaults if not provided
    if max_epochs is None:
        max_epochs = cfg.MAX_EPOCHS
    if model_save_path is None:
        model_save_path = cfg.MODEL_PATH
    if loss_curves_dir is None:
        loss_curves_dir = cfg.LOSS_CURVES_PATH

    # Initialize emissions tracker
    emissions_data = None
    if track_carbon:
        tracker = EmissionsTracker(
            project_name="garbage_classifier_training",
            output_dir=str(Path(model_save_path).parent),
            log_level="warning",  # Reduce console output
        )
        tracker.start()

    try:
        # Initialize data module
        if progress_callback:
            progress_callback("Initializing data module...")
        data_module = GarbageDataModule(batch_size=batch_size)
        data_module.setup()

        # Initialize model
        if progress_callback:
            progress_callback("Creating model...")
        model = GarbageClassifier(num_classes=data_module.num_classes, lr=lr)

        # Setup callback
        loss_curve_callback = LossCurveCallback(save_dir=loss_curves_dir)

        # Configure trainer
        if progress_callback:
            progress_callback(f"Starting training for {max_epochs} epochs...")
        trainer = pl.Trainer(
            max_epochs=max_epochs,
            accelerator="auto",
            devices=1,
            callbacks=[loss_curve_callback],
            num_sanity_val_steps=0,
        )

        # Train
        trainer.fit(model, datamodule=data_module)

        # Extract final metrics (a key maps to None if it was never logged)
        metrics = {}
        if trainer.callback_metrics:
            for key in ("train_acc", "val_acc", "train_loss", "val_loss"):
                value = trainer.callback_metrics.get(key)
                # Convert tensors to float if needed
                if hasattr(value, "item"):
                    value = value.item()
                metrics[key] = value

        # Save model
        if progress_callback:
            progress_callback("Saving model...")
        Path(model_save_path).parent.mkdir(parents=True, exist_ok=True)
        trainer.save_checkpoint(model_save_path)

        # Stop emissions tracking
        if track_carbon:
            emissions_kg = tracker.stop()
            car_distances = kg_co2_to_car_distance(emissions_kg)
            emissions_data = {
                "emissions_kg": emissions_kg,
                "emissions_g": emissions_kg * 1000,
                "car_distance_km": car_distances["distance_km"],
                "car_distance_m": car_distances["distance_m"],
                "car_distance_formatted": format_car_distance(emissions_kg),
                # _total_duration is a private tracker attribute, hence
                # the hasattr guard
                "duration_seconds": (
                    tracker._total_duration.total_seconds()
                    if hasattr(tracker, "_total_duration")
                    else None
                ),
            }

        # Adjacent string literals are used instead of backslash
        # continuations, which would embed the source indentation
        # in the messages.
        if progress_callback:
            msg = f"✅ Training complete! Model saved at {model_save_path}"
            if emissions_data:
                msg += (
                    "\n🌍 Carbon footprint: "
                    f"{emissions_data['emissions_g']:.2f}g CO₂eq"
                )
                msg += (
                    "\n🚗 Equivalent to driving: "
                    f"{emissions_data['car_distance_formatted']}"
                )
            progress_callback(msg)

        print(f"Model saved at {model_save_path}")
        if emissions_data:
            print(
                f"Carbon emissions: {emissions_data['emissions_kg']:.6f}kg "
                f"CO₂eq ({emissions_data['emissions_g']:.2f}g)"
            )
            print(
                "Equivalent to driving: "
                f"{emissions_data['car_distance_formatted']}"
            )

        return {
            "trainer": trainer,
            "model": model,
            "data_module": data_module,
            "emissions": emissions_data,
            "metrics": metrics,
        }

    except Exception:
        # Stop the tracker before propagating the error
        if track_carbon:
            tracker.stop()
        # Bare raise preserves the original traceback
        raise


# ========================
# CLI Entry Point
# ========================
if __name__ == "__main__":
    # Run the complete training pipeline with the default configuration
    # from the config module, then display final metrics and carbon
    # emissions statistics.
    print("Starting training with default configuration...")
    result = train_model()
    if result["emissions"]:
        print(
            "\n🌍 Total carbon footprint: "
            f"{result['emissions']['emissions_g']:.2f}g CO₂eq"
        )
        print(
            "🚗 Equivalent to driving: "
            f"{result['emissions']['car_distance_formatted']}"
        )
    if result["metrics"]:
        print("\n📊 Final Metrics:")
        if result["metrics"].get("train_acc") is not None:
            print(f"  Train Accuracy: {result['metrics']['train_acc']:.4f}")
        if result["metrics"].get("val_acc") is not None:
            print(f"  Validation Accuracy: {result['metrics']['val_acc']:.4f}")
def train_model(batch_size: int = 32, lr: float = 0.001, max_epochs: int = None, model_save_path: str = None, loss_curves_dir: str = None, track_carbon: bool = True, progress_callback=None):

Train the garbage classification model using PyTorch Lightning.

Orchestrates the complete training pipeline including data loading, model initialization, training, and evaluation. Optionally tracks carbon emissions during training and provides progress callbacks for UI integration.

Parameters
  • batch_size (int, default=32): Batch size for training and validation data loaders.
  • lr (float, default=1e-3): Learning rate for the optimizer.
  • max_epochs (int, optional): Maximum number of training epochs. If None, uses cfg.MAX_EPOCHS. Default is None.
  • model_save_path (str, optional): Path to save the trained model checkpoint. If None, uses cfg.MODEL_PATH. Default is None.
  • loss_curves_dir (str, optional): Directory to save loss curve visualizations. If None, uses cfg.LOSS_CURVES_PATH. Default is None.
  • track_carbon (bool, default=True): Whether to track carbon emissions during training using CodeCarbon.
  • progress_callback (callable, optional): Callback function to report training progress. Called with a message string for UI updates. Default is None (no progress reporting).
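Since progress_callback only needs to accept a single message string, any callable works. The following is a minimal sketch of one that records messages for later display; the ProgressLog class is hypothetical and not part of the project, and the callback is invoked by hand here so the example runs without the training dependencies.

```python
from datetime import datetime
from typing import List, Tuple


class ProgressLog:
    """Hypothetical callable that records each progress message with a timestamp."""

    def __init__(self) -> None:
        self.entries: List[Tuple[datetime, str]] = []

    def __call__(self, message: str) -> None:
        # train_model passes one message string per status update
        self.entries.append((datetime.now(), message))

    def messages(self) -> List[str]:
        """Return the recorded message texts in arrival order."""
        return [text for _, text in self.entries]


log = ProgressLog()
# In real use this would be: train_model(progress_callback=log)
# Here the callback is called directly to illustrate the contract.
log("Initializing data module...")
log("Creating model...")
print(log.messages())
```

A UI layer could then poll log.messages() or read the timestamps in log.entries to show elapsed time between stages.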
Returns
  • dict: Dictionary containing:
    • 'trainer' (pl.Trainer): PyTorch Lightning trainer instance.
    • 'model' (GarbageClassifier): Trained model instance.
    • 'data_module' (GarbageDataModule): Data module used for training.
    • 'emissions' (dict or None): Carbon emissions data with keys: 'emissions_kg', 'emissions_g', 'car_distance_km', 'car_distance_m', 'car_distance_formatted', 'duration_seconds'. None if track_carbon=False.
    • 'metrics' (dict): Training and validation metrics with keys: 'train_acc', 'val_acc', 'train_loss', 'val_loss'.
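Callers should be prepared for 'emissions' to be None and for individual metric values to be None when a metric was never logged. The sketch below shows one way to consume the returned dictionary defensively; summarize_result is a hypothetical helper, and fake_result is a hand-written stand-in with made-up numbers rather than real training output.

```python
from typing import Any, Dict, List


def summarize_result(result: Dict[str, Any]) -> List[str]:
    """Build display lines from a train_model-style result dict (hypothetical helper)."""
    lines: List[str] = []
    metrics = result.get("metrics") or {}
    for key in ("train_acc", "val_acc", "train_loss", "val_loss"):
        value = metrics.get(key)
        if value is not None:  # skip metrics that were never logged
            lines.append(f"{key}: {value:.4f}")
    emissions = result.get("emissions")
    if emissions is not None:  # None when track_carbon=False
        lines.append(f"emissions: {emissions['emissions_g']:.2f} g CO2eq")
    return lines


# Stand-in for a real return value; the numbers are invented for illustration.
fake_result = {
    "metrics": {
        "train_acc": 0.9,
        "val_acc": None,
        "train_loss": 0.25,
        "val_loss": None,
    },
    "emissions": None,
}
print(summarize_result(fake_result))
```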
Raises
  • Exception: Any exception during training is re-raised after stopping emissions tracker if applicable.
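The stop-then-re-raise behaviour can be illustrated in isolation. This is a sketch under stated assumptions: FakeTracker and run_with_tracker are hypothetical stand-ins written for this example (FakeTracker only mimics the start/stop shape of an emissions tracker and measures nothing), so the pattern can be run without CodeCarbon installed.

```python
from typing import Callable


class FakeTracker:
    """Hypothetical stand-in for an emissions tracker; only counts stop() calls."""

    def __init__(self) -> None:
        self.running = False
        self.stop_calls = 0

    def start(self) -> None:
        self.running = True

    def stop(self) -> float:
        self.running = False
        self.stop_calls += 1
        return 0.0  # placeholder value, not a real measurement


def run_with_tracker(tracker: FakeTracker, work: Callable[[], None]) -> float:
    """Run work() under the tracker, stopping it on success or failure."""
    tracker.start()
    try:
        work()
    except Exception:
        tracker.stop()  # clean up first
        raise  # bare raise keeps the original traceback
    return tracker.stop()


def failing_work() -> None:
    raise RuntimeError("simulated training failure")


tracker = FakeTracker()
caught = None
try:
    run_with_tracker(tracker, failing_work)
except RuntimeError as err:
    caught = str(err)
print(caught, tracker.running, tracker.stop_calls)
```

The caller still sees the original exception, while the tracker is left in a stopped state.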
Notes

Carbon emissions are converted to equivalent car driving distance for intuitive understanding. Model checkpoint is automatically saved to disk. If progress_callback is provided, it receives status updates at key points during training initialization and completion.
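
The conversion helpers in source.utils.carbon_utils are not shown on this page, so the following is only a sketch of how such a conversion might work. The function names, the formatting rule, and especially the emission factor of 0.12 kg CO2 per km are assumptions made for illustration; only the 'distance_km' and 'distance_m' keys are taken from how train_model reads the result.

```python
from typing import Dict

# ASSUMPTION: illustrative emission factor, not the value used by the project.
ASSUMED_KG_CO2_PER_KM = 0.12


def car_distance_sketch(
    emissions_kg: float,
    kg_co2_per_km: float = ASSUMED_KG_CO2_PER_KM,
) -> Dict[str, float]:
    """Convert kg of CO2 into an equivalent driving distance (hypothetical)."""
    distance_km = emissions_kg / kg_co2_per_km
    return {"distance_km": distance_km, "distance_m": distance_km * 1000}


def format_distance_sketch(distance_km: float) -> str:
    """Pick metres for short distances and kilometres otherwise (hypothetical)."""
    if distance_km < 1:
        return f"{distance_km * 1000:.0f} m"
    return f"{distance_km:.2f} km"


distances = car_distance_sketch(0.006)  # 6 g of CO2
print(format_distance_sketch(distances["distance_km"]))
```

Note that the project's format_car_distance is called with the emissions in kg rather than a distance, so its real signature differs from this sketch.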