train

Training Script for Garbage Classification Model.

This script orchestrates the training process for a garbage classification model using PyTorch Lightning. It initializes the data module, model, callbacks, and trainer, then executes the training loop and saves the trained model checkpoint.

The script performs the following steps:

  1. Initializes the GarbageDataModule with stratified train/test split
  2. Creates a GarbageClassifier model (ResNet18-based)
  3. Sets up loss curve visualization callback
  4. Configures PyTorch Lightning Trainer with specified hyperparameters
  5. Trains the model on the garbage dataset
  6. Saves the trained model checkpoint

Usage

Command line:

$ uv run train.py

Notes

Configuration parameters are loaded from the utils.config module:

  • MAX_EPOCHS: Maximum number of training epochs
  • LOSS_CURVES_PATH: Directory for saving loss and accuracy plots
  • MODEL_PATH: Path where the trained model checkpoint will be saved
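The contents of utils.config are not shown on this page. As a rough sketch, such a module might look like the following. The three parameter names come from the list above; OUTPUT_DIR, every value, and the directory layout are placeholder assumptions, not the project's actual settings.

```python
# Hypothetical sketch of utils/config.py; all values below are placeholders.
import os

# Maximum number of training epochs (assumed value).
MAX_EPOCHS = 10

# Assumed root directory for training artifacts (not named in the original).
OUTPUT_DIR = "outputs"

# Directory for saving loss and accuracy plots.
LOSS_CURVES_PATH = os.path.join(OUTPUT_DIR, "loss_curves")

# Path where the trained model checkpoint will be saved.
MODEL_PATH = os.path.join(OUTPUT_DIR, "garbage_classifier.ckpt")
```

Keeping these values in one module lets the training script and any later evaluation code agree on where plots and checkpoints are stored.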

Training runs on a single automatically selected device (GPU if available, otherwise CPU) and skips the sanity validation steps for faster startup.
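Step 1 above refers to a stratified train/test split, meaning each class keeps roughly the same proportion in both subsets. How GarbageDataModule implements this is not shown on this page. The following is a minimal, standard-library sketch of the idea only; the function name stratified_split, the test_fraction and seed parameters, and the toy labels are all illustrative assumptions.

```python
import random
from collections import defaultdict


def stratified_split(labels, test_fraction=0.2, seed=0):
    """Return (train_indices, test_indices) with per-class proportions preserved."""
    rng = random.Random(seed)

    # Group sample indices by their class label.
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    train_indices, test_indices = [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        # Take the same fraction of every class for the test set.
        n_test = round(len(indices) * test_fraction)
        test_indices.extend(indices[:n_test])
        train_indices.extend(indices[n_test:])

    return sorted(train_indices), sorted(test_indices)


# Toy labels (hypothetical): 10 "glass" samples and 20 "paper" samples.
labels = ["glass"] * 10 + ["paper"] * 20
train_idx, test_idx = stratified_split(labels, test_fraction=0.2)
print(len(train_idx), len(test_idx))
```

Splitting per class rather than over the whole dataset keeps rare classes represented in the test set, which matters when the garbage categories are imbalanced.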

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Training Script for Garbage Classification Model.

This script orchestrates the training process for a garbage classification
model using PyTorch Lightning. It initializes the data module, model,
callbacks, and trainer, then executes the training loop and saves the
trained model checkpoint.

The script performs the following steps:
1. Initializes the GarbageDataModule with stratified train/test split
2. Creates a GarbageClassifier model (ResNet18-based)
3. Sets up loss curve visualization callback
4. Configures PyTorch Lightning Trainer with specified hyperparameters
5. Trains the model on the garbage dataset
6. Saves the trained model checkpoint

Usage
-----
Command line:

    $ uv run train.py

Notes
-----
Configuration parameters are loaded from the `utils.config` module:
- `MAX_EPOCHS`: Maximum number of training epochs
- `LOSS_CURVES_PATH`: Directory for saving loss and accuracy plots
- `MODEL_PATH`: Path where the trained model checkpoint will be saved

Training runs on a single automatically selected device (GPU if available,
otherwise CPU) and skips the sanity validation steps for faster startup.
"""
__docformat__ = "numpy"

import pytorch_lightning as pl
from utils import config as cfg
from utils.custom_classes.GarbageDataModule import GarbageDataModule
from utils.custom_classes.GarbageClassifier import GarbageClassifier
from utils.custom_classes.LossCurveCallback import LossCurveCallback


# ========================
# Training
# ========================
if __name__ == "__main__":
    # Main entry point for the training script.

    # setup() is called up front so that num_classes is available
    # before the model is constructed.
    data_module = GarbageDataModule(batch_size=32)
    data_module.setup()

    model = GarbageClassifier(num_classes=data_module.num_classes, lr=1e-3)

    loss_curve_callback = LossCurveCallback(save_dir=cfg.LOSS_CURVES_PATH)

    trainer = pl.Trainer(
        max_epochs=cfg.MAX_EPOCHS,
        accelerator="auto",      # pick the accelerator automatically
        devices=1,               # train on a single device
        callbacks=[loss_curve_callback],
        num_sanity_val_steps=0,  # skip the pre-training validation check
    )
    trainer.fit(model, datamodule=data_module)
    trainer.save_checkpoint(cfg.MODEL_PATH)
    print(f"Model saved at {cfg.MODEL_PATH}")