
Garbage Classifier Project

A deep learning project for garbage classification using PyTorch Lightning and transfer learning with ResNet18.

Overview

This project implements a garbage classification system that sorts waste images into six classes: cardboard, glass, metal, paper, plastic, and trash. It reuses a ResNet18 pretrained on ImageNet and trains a new classification head on a custom garbage dataset.

Project Structure
  • train.py: Main training script for the classification model
  • predict.py: Script for making predictions on new images
  • utils/: Utility modules containing configuration and custom classes
    • config.py: Configuration parameters and constants
    • custom_classes/: Custom PyTorch Lightning implementations
      • GarbageClassifier.py: ResNet18-based classifier model
      • GarbageDataModule.py: Data loading and preprocessing module
      • LossCurveCallback.py: Callback for visualizing training metrics
  • app/: Interactive demo built with Gradio
    • main.py: Entry point that launches the interactive demo
    • sections/: Windows that make up the interactive GUI
      • data_exploration.py: Interactive window for exploring the data
      • model_training.py: Interactive window for training the model
      • model_evaluation.py: Interactive window for visualizing training metrics

Features
  • Transfer learning with pretrained ResNet18 (ImageNet weights)
  • Stratified train/test split (90/10) that preserves class proportions in both subsets
  • Automatic loss and accuracy curve generation
  • GPU acceleration support with automatic fallback to CPU
  • Command-line interface for training and prediction
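The stratified 90/10 split with a fixed seed can be illustrated in plain Python. This is a sketch of the technique, not the project's code: the helper name `stratified_split`, the seed value 42, and the per-class rounding rule are assumptions. Since the project lists scikit-learn for splitting, the real code more likely calls something like `train_test_split(paths, labels, test_size=0.1, stratify=labels, random_state=seed)`.

```python
import random
from collections import defaultdict


def stratified_split(labels, test_size=0.1, seed=42):
    """Return (train_idx, test_idx), splitting each class in the same ratio.

    Hypothetical helper for illustration only; seed=42 is an assumed value.
    """
    # Group sample indices by class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    rng = random.Random(seed)  # fixed seed -> identical split on every run
    train_idx, test_idx = [], []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        # Hold out about test_size of this class (at least one sample).
        n_test = max(1, round(len(indices) * test_size))
        test_idx.extend(indices[:n_test])
        train_idx.extend(indices[n_test:])
    return sorted(train_idx), sorted(test_idx)


classes = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]
labels = [c for c in classes for _ in range(50)]  # toy data: 50 images per class
train_idx, test_idx = stratified_split(labels)
print(len(train_idx), len(test_idx))  # 270 30
```

Each class contributes exactly 5 of its 50 samples to the test set, so the class proportions are the same in both subsets, and rerunning with the same seed reproduces the split.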
Quick Start

Training the model:

uv run train.py

Making predictions:

uv run predict.py path/to/image.jpg

Generating documentation:

uv run scripts/generate_docs.py
Dependencies
  • PyTorch Lightning: Deep learning framework
  • PyTorch & torchvision: Neural network implementation and pretrained models
  • Pillow (PIL): Image loading and processing
  • scikit-learn: Dataset splitting utilities
  • matplotlib: Visualization of training curves
  • Gradio: Interactive demo in app/
Model Architecture

The classifier uses a ResNet18 architecture with:

  • Pretrained feature extraction layers (frozen)
  • Custom classification head for 6 garbage categories
  • Cross-entropy loss function
  • Adam optimizer
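For intuition about the cross-entropy objective over six classes, the pure-Python sketch below computes the loss for one sample from raw logits, using a numerically stable log-sum-exp. It only illustrates the formula; the project presumably relies on PyTorch's built-in loss (e.g. `torch.nn.functional.cross_entropy`) rather than anything like this helper, whose name is hypothetical. A practical consequence: a freshly initialised head with near-uniform logits should start training at a loss close to ln 6 ≈ 1.79.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for one sample: -log(softmax(logits)[target])."""
    # Subtracting the max before exp() avoids overflow for large logits.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


uniform = [0.0] * 6  # untrained head: no preference among the 6 classes
print(round(cross_entropy(uniform, target=2), 4))  # 1.7918 (= ln 6)

confident = [0.0, 0.0, 8.0, 0.0, 0.0, 0.0]  # strongly favours class 2
print(cross_entropy(confident, target=2))  # small, close to 0
```

If the first logged training loss is far from ln 6, that can hint at a label or head-size mismatch (for example, a head that does not have exactly 6 outputs).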
Dataset

The model is trained on a custom garbage dataset with 6 classes:

  1. Cardboard
  2. Glass
  3. Metal
  4. Paper
  5. Plastic
  6. Trash
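Because `predict_image` in the Examples returns both a class name and an index, the model's output indices must map onto the names above in a fixed order. The sketch below assumes that order is the listed one, which is also alphabetical and would match torchvision's `ImageFolder` (it assigns indices from sorted folder names) if the data module uses it. `CLASS_NAMES` and `decode_prediction` are hypothetical names, not the project's API.

```python
import math

# Hypothetical constant: output index i corresponds to CLASS_NAMES[i].
CLASS_NAMES = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]


def decode_prediction(logits):
    """Map a 6-element logit vector to (class_name, class_index, probability)."""
    # Softmax, with the max subtracted for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    idx = max(range(len(probs)), key=probs.__getitem__)  # argmax (first on ties)
    return CLASS_NAMES[idx], idx, probs[idx]


name, idx, p = decode_prediction([0.1, 2.5, -1.0, 0.3, 1.2, -0.4])
print(name, idx, round(p, 2))  # glass 1 0.64
```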
Performance

Training artifacts are saved automatically to the following locations:

  • Loss curves: models/performance/loss_curves/
  • Model checkpoint: models/weights/model_resnet18_garbage.ckpt
Examples
>>> # Import the classifier
>>> from utils.custom_classes.GarbageClassifier import GarbageClassifier
>>> from utils import config as cfg
>>>
>>> # Load trained model
>>> model = GarbageClassifier.load_from_checkpoint(
...     cfg.MODEL_PATH,
...     num_classes=cfg.NUM_CLASSES
... )
>>>
>>> # Make prediction
>>> from predict import predict_image
>>> pred_class, pred_idx = predict_image("sample.jpg")
>>> print(f"Prediction: {pred_class}")
Notes

This project follows best practices for:

  • Code organization and modularity
  • Documentation (NumPy-style docstrings)
  • Dependency management (using uv)
  • Reproducibility (fixed random seeds in data splitting)

For more detailed information, see the individual module documentation.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from source import train
from source import predict

__all__ = [
    "train",
    "predict",
    "utils",
]