Garbage Classifier Project
A deep learning project for garbage classification using PyTorch Lightning and transfer learning with ResNet18.
Overview
This project implements a garbage classification system that categorizes waste into 6 different classes: cardboard, glass, metal, paper, plastic, and trash. The model uses a pretrained ResNet18 architecture fine-tuned on a custom garbage dataset.
Project Structure
- train.py: Main training script for the classification model
- predict.py: Script for making predictions on new images
- utils/: Utility modules containing configuration and custom classes
  - config.py: Configuration parameters and constants
  - custom_classes/: Custom PyTorch Lightning implementations
    - GarbageClassifier.py: ResNet18-based classifier model
    - GarbageDataModule.py: Data loading and preprocessing module
    - LossCurveCallback.py: Callback for visualizing training metrics
- app/: Interactive demo built with Gradio
  - main.py: Entry point that launches the interactive demo
  - sections/: Windows that make up the interactive GUI
    - data_exploration.py: Interactive window for exploring the dataset
    - model_training.py: Interactive window for training the model
    - model_evaluation.py: Interactive window for visualizing training metrics
Features
- Transfer learning with pretrained ResNet18 (ImageNet weights)
- Stratified train/test split (90/10) that preserves each class's proportion in both splits
- Automatic loss and accuracy curve generation
- GPU acceleration support with automatic fallback to CPU
- Command-line interface for training and prediction
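The stratified 90/10 split with a fixed seed can be sketched with scikit-learn's `train_test_split`. This is a minimal illustration, not the code in GarbageDataModule.py: the function name `stratified_split` and the seed value 42 are hypothetical.

```python
from sklearn.model_selection import train_test_split


def stratified_split(paths, labels, test_size=0.1, seed=42):
    """Hypothetical helper: split image paths into train/test sets.

    The seed value 42 is an assumption; the project's actual seed may differ.
    """
    train_paths, test_paths, train_labels, test_labels = train_test_split(
        paths,
        labels,
        test_size=test_size,  # hold out 10% of the images -> 90/10 split
        stratify=labels,      # preserve each class's proportion in both splits
        random_state=seed,    # fixed seed -> the same split on every run
    )
    return train_paths, test_paths, train_labels, test_labels
```

With 20 images per class (120 in total), this holds out 12 test images, 2 from each of the 6 classes.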
Quick Start
Training the model:
uv run train.py
Making predictions:
uv run predict.py path/to/image.jpg
Generating documentation:
uv run scripts/generate_docs.py
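A positional image-path argument like the one `predict.py` accepts could be parsed with argparse. This sketch rests on assumptions: `build_parser` and `main` are hypothetical names, and the real script may define its arguments differently.

```python
import argparse


def build_parser():
    """Hypothetical parser for `uv run predict.py path/to/image.jpg`."""
    parser = argparse.ArgumentParser(
        description="Classify a garbage image into one of 6 categories."
    )
    # Single positional argument: the image to classify
    parser.add_argument("image_path", help="Path to the image to classify")
    return parser


def main(argv=None):
    """Parse arguments (from sys.argv when argv is None) and return the path.

    In a real script this is where the image would be passed to the classifier.
    """
    args = build_parser().parse_args(argv)
    return args.image_path
```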
Dependencies
- PyTorch Lightning: Deep learning framework
- PyTorch & torchvision: Neural network implementation and pretrained models
- PIL (Pillow): Image loading and processing
- scikit-learn: Dataset splitting utilities
- matplotlib: Visualization
Model Architecture
The classifier uses a ResNet18 architecture with:
- Pretrained feature extraction layers (frozen)
- Custom classification head for 6 garbage categories
- Cross-entropy loss function
- Adam optimizer
Dataset
The model is trained on a custom garbage dataset with 6 classes:
- Cardboard
- Glass
- Metal
- Paper
- Plastic
- Trash
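The six classes map to the integer labels the model's output layer predicts. Below is a minimal sketch of what config.py might hold. It assumes the class folders are read in sorted order, as torchvision's ImageFolder does, and that `cfg.MODEL_PATH` points to the checkpoint path listed under Performance. All names except `NUM_CLASSES` and `MODEL_PATH` are hypothetical.

```python
from pathlib import Path

# Class names in alphabetical order (the order ImageFolder would assign if the
# dataset is loaded that way; this is an assumption about the project)
CLASSES = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]
NUM_CLASSES = len(CLASSES)

# Lookup tables between label indices and class names (hypothetical names)
CLASS_TO_IDX = {name: idx for idx, name in enumerate(CLASSES)}
IDX_TO_CLASS = dict(enumerate(CLASSES))

# Assumed to match the checkpoint location listed under Performance
MODEL_PATH = Path("models/weights/model_resnet18_garbage.ckpt")
```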
Performance
Training artifacts (performance plots and model weights) are saved automatically to:
- Loss curves: models/performance/loss_curves/
- Model checkpoint: models/weights/model_resnet18_garbage.ckpt
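A loss-curve plot like the ones LossCurveCallback saves could be produced with matplotlib. The helper `save_loss_curve` and its default filename are hypothetical, and the callback's actual file naming and contents may differ.

```python
from pathlib import Path

import matplotlib

matplotlib.use("Agg")  # non-interactive backend: write files, no window
import matplotlib.pyplot as plt


def save_loss_curve(train_losses, val_losses,
                    out_dir="models/performance/loss_curves",
                    filename="loss_curve.png"):
    """Hypothetical helper: plot per-epoch losses and save them as a PNG."""
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(6, 4))
    # Epochs are numbered from 1; each series uses its own length
    ax.plot(range(1, len(train_losses) + 1), train_losses, label="train loss")
    ax.plot(range(1, len(val_losses) + 1), val_losses, label="val loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()

    file_path = out_path / filename
    fig.savefig(file_path)
    plt.close(fig)  # free the figure's memory
    return file_path
```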
Examples
>>> # Import the classifier
>>> from utils.custom_classes.GarbageClassifier import GarbageClassifier
>>> from utils import config as cfg
>>>
>>> # Load trained model
>>> model = GarbageClassifier.load_from_checkpoint(
... cfg.MODEL_PATH,
... num_classes=cfg.NUM_CLASSES
... )
>>>
>>> # Make prediction
>>> from predict import predict_image
>>> pred_class, pred_idx = predict_image("sample.jpg")
>>> print(f"Prediction: {pred_class}")
Notes
This project follows best practices for:
- Code organization and modularity
- Documentation (NumPy-style docstrings)
- Dependency management (using uv)
- Reproducibility (fixed random seeds in data splitting)
For more detailed information, see the individual module documentation.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from source import train
from source import predict

__all__ = [
    "train",
    "predict",
    "utils",
]