predict

Garbage Classification Prediction Script.

This script loads a trained GarbageClassifier model and performs inference on a single image to predict its garbage category. The script can accept an image path as a command-line argument or use a default sample image.

The prediction uses a pretrained ResNet18 model fine-tuned for 6-class garbage classification (cardboard, glass, metal, paper, plastic, trash).
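The model outputs one score per class, and the script maps the winning position back to a name through `cfg.CLASS_NAMES`. That list must therefore be in the same order the model was trained with. The sketch below assumes, without confirmation from this page, that training used torchvision's `ImageFolder`. That class numbers classes by sorting the class-folder names. The folder names come from the class list above, and their unsorted order here is made up.

```python
# Hypothetical: class folder names as they might appear in a training directory.
folder_names = ["trash", "paper", "glass", "plastic", "cardboard", "metal"]

# ImageFolder-style indexing: sort the names, then number them from 0.
class_names = sorted(folder_names)
class_to_idx = {name: idx for idx, name in enumerate(class_names)}

print(class_names)
print(class_to_idx)
```

Under that assumption, the alphabetical order listed above (cardboard, glass, metal, paper, plastic, trash) is exactly the index order, from 0 to 5.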

Usage

Command line:

$ uv run predict.py <path_to_image>

Examples

Predict with custom image:

$ uv run predict.py img.jpg

Predict with default sample image:

$ uv run predict.py

Notes
  • The model checkpoint path is configured in utils.config
  • Images are automatically preprocessed using ImageNet normalization
  • Prediction runs on GPU if available, otherwise falls back to CPU
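The script reads four settings from `utils.config`: `CLASS_NAMES`, `NUM_CLASSES`, `MODEL_PATH`, and `SAMPLE_IMG_PATH`. That module is not shown on this page, so the following is only a sketch of what it could contain. Every value, including both paths and the `.ckpt` file name, is a hypothetical placeholder.

```python
# utils/config.py (hypothetical sketch; the real module is not shown)

# Index order must match the order used during training.
CLASS_NAMES = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]

# Derived from CLASS_NAMES so the two settings cannot drift apart.
NUM_CLASSES = len(CLASS_NAMES)

# Placeholder paths; adjust to the real project layout.
MODEL_PATH = "models/garbage_classifier.ckpt"  # passed to load_from_checkpoint
SAMPLE_IMG_PATH = "data/sample.jpg"            # used when no argument is given
```

Deriving `NUM_CLASSES` from the list is a design choice of this sketch. The original module may define the two independently.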
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Garbage Classification Prediction Script.

This script loads a trained GarbageClassifier model and performs inference
on a single image to predict its garbage category. The script can accept
an image path as a command-line argument or use a default sample image.

The prediction uses a pretrained ResNet18 model fine-tuned for 6-class
garbage classification (cardboard, glass, metal, paper, plastic, trash).

Usage
-----
Command line:

    $ uv run predict.py <path_to_image>

Examples
--------
Predict with custom image:

    $ uv run predict.py img.jpg

Predict with default sample image:

    $ uv run predict.py

Notes
-----
- The model checkpoint path is configured in `utils.config`
- Images are automatically preprocessed using ImageNet normalization
- Prediction runs on GPU if available, otherwise falls back to CPU
"""
__docformat__ = "numpy"

import sys
import torch
from torchvision import models
from PIL import Image
from utils import config as cfg
from utils.custom_classes.GarbageClassifier import GarbageClassifier


def predict_image(image_path):
    """
    Predict the garbage category of an input image.

    Parameters
    ----------
    image_path : str
        Path to the image file to classify.

    Returns
    -------
    tuple of (str, int)
        A tuple containing the predicted class name and class index.

    Examples
    --------
    >>> pred_class, pred_idx = predict_image("sample.jpg")
    >>> print(f"Prediction: {pred_class}")

    Notes
    -----
    The function automatically handles device selection (GPU/CPU) and
    applies the appropriate image transformations for the ResNet18 model.
    """
    class_names = cfg.CLASS_NAMES
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    print("Loading model...")
    model = GarbageClassifier.load_from_checkpoint(
        cfg.MODEL_PATH,
        num_classes=cfg.NUM_CLASSES
    )

    model = model.to(device)
    model.eval()

    print("Transforming image...")
    transform = models.ResNet18_Weights.IMAGENET1K_V1.transforms()
    image = Image.open(image_path).convert("RGB")
    tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(tensor)
        pred_idx = outputs.argmax(1).item()
        pred_class = class_names[pred_idx]

    return pred_class, pred_idx


def main():
    """
    Main entry point for the prediction script.

    Parses command-line arguments and performs prediction on the specified
    image or a default sample image.
    """
    if len(sys.argv) > 2:
        print("Usage: uv run predict.py <path_to_image>")
        print("Example with image in this folder: uv run predict.py img.jpg")
        sys.exit(1)
    elif len(sys.argv) == 1:
        image_path = cfg.SAMPLE_IMG_PATH
    else:
        image_path = sys.argv[1]

    pred_class, pred_idx = predict_image(image_path)
    print(f"Prediction: {pred_class} (class {pred_idx})")


if __name__ == "__main__":
    main()
def predict_image(image_path):

Predict the garbage category of an input image.

Parameters
  • image_path (str): Path to the image file to classify.
Returns
  • tuple of (str, int): A tuple containing the predicted class name and class index.
Examples
>>> pred_class, pred_idx = predict_image("sample.jpg")
>>> print(f"Prediction: {pred_class}")
Notes

The function automatically handles device selection (GPU/CPU) and applies the appropriate image transformations for the ResNet18 model.
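The transform returned by `models.ResNet18_Weights.IMAGENET1K_V1.transforms()` works in four steps:

- It resizes the image so its shorter side is 256 pixels.
- It center-crops the result to 224×224.
- It scales pixel values to [0, 1].
- It normalizes each RGB channel with the ImageNet mean and standard deviation.

`unsqueeze(0)` then adds a batch dimension, which gives a tensor of shape `(1, 3, 224, 224)`. The last two transform steps are plain arithmetic. The sketch below applies them to a single 8-bit pixel without torchvision and is meant to show the formula, not to replace the transform.

```python
# ImageNet channel statistics used by torchvision's ImageNet presets.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Map an 8-bit (R, G, B) pixel to normalized floats, channel by channel."""
    # Scale to [0, 1], then subtract the channel mean and divide by its std.
    return tuple((value / 255.0 - m) / s for value, m, s in zip(rgb, MEAN, STD))


print(normalize_pixel((255, 255, 255)))  # white: roughly (2.25, 2.43, 2.64)
print(normalize_pixel((0, 0, 0)))        # black: roughly (-2.12, -2.04, -1.80)
```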

def main():

Main entry point for the prediction script.

Parses command-line arguments and performs prediction on the specified image or a default sample image.
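`main` accepts zero or one positional argument. The sketch below moves that branching into a pure function, so it can be checked without loading the model. `resolve_image_path` is a hypothetical helper and is not part of the script. `"sample.jpg"` stands in for `cfg.SAMPLE_IMG_PATH`.

```python
def resolve_image_path(argv, default_path):
    """Return the image path implied by argv, or None if there are too many arguments."""
    if len(argv) > 2:
        return None          # caller would print the usage text and exit with status 1
    if len(argv) == 1:
        return default_path  # no argument given: fall back to the sample image
    return argv[1]           # exactly one argument: use it as the image path


print(resolve_image_path(["predict.py"], "sample.jpg"))             # sample.jpg
print(resolve_image_path(["predict.py", "img.jpg"], "sample.jpg"))  # img.jpg
print(resolve_image_path(["predict.py", "a.jpg", "b.jpg"], "sample.jpg"))  # None
```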