train
Training Script for Garbage Classification Model.
This script orchestrates the training process for a garbage classification model using PyTorch Lightning. It initializes the data module, model, callbacks, and trainer, then executes the training loop and saves the trained model checkpoint.
The script performs the following steps:
- Initializes the GarbageDataModule with stratified train/test split
- Creates a GarbageClassifier model (ResNet18-based)
- Sets up loss curve visualization callback
- Configures PyTorch Lightning Trainer with specified hyperparameters
- Trains the model on the garbage dataset
- Saves the trained model checkpoint
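The stratified train/test split in the first step can be illustrated without any framework. The sketch below is not the GarbageDataModule implementation, which this page does not show. It assumes a plain list of labels, and the function name `stratified_split`, the class names, and the 20% test fraction are all hypothetical.

```python
import random
from collections import defaultdict


def stratified_split(labels, test_fraction=0.2, seed=42):
    """Return (train_indices, test_indices) that keep per-class proportions."""
    # Group sample indices by their class label.
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    rng = random.Random(seed)
    train_indices, test_indices = [], []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        # At least one sample per class goes to the test set, so a class
        # with a single sample would end up in the test set only.
        n_test = max(1, round(len(indices) * test_fraction))
        test_indices.extend(indices[:n_test])
        train_indices.extend(indices[n_test:])
    return sorted(train_indices), sorted(test_indices)


# Hypothetical, imbalanced label list used only for illustration.
labels = ["glass"] * 10 + ["paper"] * 20 + ["metal"] * 5
train_idx, test_idx = stratified_split(labels, test_fraction=0.2)
print(len(train_idx), len(test_idx))
```

Each class contributes the same fraction of its samples to the test set, so rare classes are not left out of evaluation by chance.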
Usage
Command line:
$ uv run train.py
Notes
Configuration parameters are loaded from the utils.config module:
- MAX_EPOCHS: Maximum number of training epochs
- LOSS_CURVES_PATH: Directory for saving loss and accuracy plots
- MODEL_PATH: Path where the trained model checkpoint will be saved
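For orientation, a configuration module exposing these three names might look like the sketch below. The actual contents of utils.config are not shown on this page, so every value here is a placeholder assumption.

```python
# utils/config.py (hypothetical contents; all values are placeholders)

# Upper bound on the number of training epochs passed to the Trainer.
MAX_EPOCHS = 10

# Directory where the loss and accuracy plots are written.
LOSS_CURVES_PATH = "outputs/loss_curves"

# File path of the checkpoint written after training.
MODEL_PATH = "outputs/garbage_classifier.ckpt"
```

Keeping these values in one module lets the training script and any later evaluation or inference script agree on where the checkpoint lives.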
The training uses automatic device selection (GPU if available, otherwise CPU) and disables sanity validation steps for faster startup.
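The GPU-or-CPU choice described above can be approximated by hand as follows. This is a simplified sketch, not how PyTorch Lightning resolves accelerator="auto" internally, which can also consider other accelerator types. The helper name `pick_accelerator` is hypothetical, and the import is guarded so the snippet also runs where PyTorch is not installed.

```python
def pick_accelerator():
    """Return "gpu" when a CUDA device is usable, otherwise "cpu"."""
    try:
        import torch
    except ImportError:
        # Without PyTorch there is nothing to query, so fall back to CPU.
        return "cpu"
    return "gpu" if torch.cuda.is_available() else "cpu"


accelerator = pick_accelerator()
print(f"Training would run on: {accelerator}")
```

Regarding the second half of the note: setting num_sanity_val_steps=0 skips the short validation pass that the Trainer otherwise runs before the first training epoch. This shortens startup, but errors in the validation code then only surface once the first real validation pass runs.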
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Training Script for Garbage Classification Model.

This script orchestrates the training process for a garbage classification
model using PyTorch Lightning. It initializes the data module, model,
callbacks, and trainer, then executes the training loop and saves the
trained model checkpoint.

The script performs the following steps:
1. Initializes the GarbageDataModule with stratified train/test split
2. Creates a GarbageClassifier model (ResNet18-based)
3. Sets up loss curve visualization callback
4. Configures PyTorch Lightning Trainer with specified hyperparameters
5. Trains the model on the garbage dataset
6. Saves the trained model checkpoint

Usage
-----
Command line:

    $ uv run train.py

Notes
-----
Configuration parameters are loaded from `utils.config` module:
- `MAX_EPOCHS`: Maximum number of training epochs
- `LOSS_CURVES_PATH`: Directory for saving loss and accuracy plots
- `MODEL_PATH`: Path where the trained model checkpoint will be saved

The training uses automatic device selection (GPU if available, otherwise CPU)
and disables sanity validation steps for faster startup.
"""
__docformat__ = "numpy"

import pytorch_lightning as pl
from utils import config as cfg
from utils.custom_classes.GarbageDataModule import GarbageDataModule
from utils.custom_classes.GarbageClassifier import GarbageClassifier
from utils.custom_classes.LossCurveCallback import LossCurveCallback


# ========================
# Training
# ========================
if __name__ == "__main__":
    """
    Main entry point for the training script.
    """
    data_module = GarbageDataModule(batch_size=32)
    data_module.setup()

    model = GarbageClassifier(num_classes=data_module.num_classes, lr=1e-3)

    loss_curve_callback = LossCurveCallback(save_dir=cfg.LOSS_CURVES_PATH)

    trainer = pl.Trainer(
        max_epochs=cfg.MAX_EPOCHS,
        accelerator="auto",
        devices=1,
        callbacks=[loss_curve_callback],
        num_sanity_val_steps=0
    )
    trainer.fit(model, datamodule=data_module)
    trainer.save_checkpoint(cfg.MODEL_PATH)
    print(f"Model saved at {cfg.MODEL_PATH}")
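Once the script has written a checkpoint, it could be reloaded for inference roughly as sketched below. The helper `load_trained_model` is hypothetical and not part of this project. It relies on the standard `load_from_checkpoint` classmethod of Lightning modules and assumes that GarbageClassifier stores its constructor arguments in the checkpoint (for example via `save_hyperparameters()`). If it does not, arguments such as `num_classes` would have to be passed explicitly.

```python
from pathlib import Path


def load_trained_model(checkpoint_path):
    """Load a trained GarbageClassifier from a checkpoint file (sketch)."""
    path = Path(checkpoint_path)
    if not path.is_file():
        raise FileNotFoundError(f"No checkpoint found at {path}")

    # Imported lazily so the path check above works without the project
    # package or PyTorch Lightning being importable.
    from utils.custom_classes.GarbageClassifier import GarbageClassifier

    # Assumption: hyperparameters were saved in the checkpoint; otherwise
    # pass them here as keyword arguments.
    model = GarbageClassifier.load_from_checkpoint(str(path), map_location="cpu")
    model.eval()  # switch dropout and batch norm layers to inference mode
    return model
```

A typical call would be `load_trained_model(cfg.MODEL_PATH)`, run from the project root so that the `utils` package resolves.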