predict
Garbage Classification Prediction Script.
This script loads a trained GarbageClassifier model and performs inference on a single image to predict its garbage category. The script can accept an image path as a command-line argument or use a default sample image.
The prediction uses a pretrained ResNet18 model fine-tuned for 6-class garbage classification (cardboard, glass, metal, paper, plastic, trash).
Usage
Command line:
$ uv run predict.py <path_to_image>
Examples
Predict with custom image:
$ uv run predict.py img.jpg
Predict with default sample image:
$ uv run predict.py
Notes
- The model checkpoint path is configured in `utils.config`
- Images are automatically preprocessed using ImageNet normalization
- Prediction runs on GPU if available, otherwise falls back to CPU
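The ImageNet normalization step can be illustrated with the sketch below, which applies the per-channel mean and standard deviation used by torchvision's ImageNet weights to a single RGB pixel. It is a minimal, pure-Python illustration of the arithmetic only. It omits the resize (256) and center-crop (224) steps that `ResNet18_Weights.IMAGENET1K_V1.transforms()` also performs, and the helper name `normalize_pixel` is hypothetical.

```python
# ImageNet channel statistics (R, G, B) used by torchvision's pretrained weights.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Hypothetical helper: scale an 8-bit RGB pixel to [0, 1], then standardize.

    Mirrors only the per-pixel math of the normalization step; the real
    transform also resizes and center-crops the whole image first.
    """
    if len(rgb) != 3:
        raise ValueError("expected an (R, G, B) triple")
    # (x / 255 - mean) / std, computed independently for each channel.
    return tuple(
        (value / 255.0 - mean) / std
        for value, mean, std in zip(rgb, IMAGENET_MEAN, IMAGENET_STD)
    )


# Normalize a mid-gray pixel.
print(normalize_pixel((128, 128, 128)))
```

Channel values below the corresponding mean come out negative, so a normalized tensor is not confined to [0, 1].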
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Garbage Classification Prediction Script.

This script loads a trained GarbageClassifier model and performs inference
on a single image to predict its garbage category. The script can accept
an image path as a command-line argument or use a default sample image.

The prediction uses a pretrained ResNet18 model fine-tuned for 6-class
garbage classification (cardboard, glass, metal, paper, plastic, trash).

Usage
-----
Command line:

    $ uv run predict.py <path_to_image>

Examples
--------
Predict with custom image:

    $ uv run predict.py img.jpg

Predict with default sample image:

    $ uv run predict.py

Notes
-----
- The model checkpoint path is configured in `utils.config`
- Images are automatically preprocessed using ImageNet normalization
- Prediction runs on GPU if available, otherwise falls back to CPU
"""
__docformat__ = "numpy"

import sys
import torch
from torchvision import models
from PIL import Image
from utils import config as cfg
from utils.custom_classes.GarbageClassifier import GarbageClassifier


def predict_image(image_path):
    """
    Predict the garbage category of an input image.

    Parameters
    ----------
    image_path : str
        Path to the image file to classify.

    Returns
    -------
    tuple of (str, int)
        A tuple containing the predicted class name and class index.

    Examples
    --------
    >>> pred_class, pred_idx = predict_image("sample.jpg")
    >>> print(f"Prediction: {pred_class}")

    Notes
    -----
    The function automatically handles device selection (GPU/CPU) and
    applies the appropriate image transformations for the ResNet18 model.
    """
    class_names = cfg.CLASS_NAMES
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    print("Loading model...")
    model = GarbageClassifier.load_from_checkpoint(
        cfg.MODEL_PATH,
        num_classes=cfg.NUM_CLASSES
    )

    model = model.to(device)
    model.eval()

    print("Transforming image...")
    transform = models.ResNet18_Weights.IMAGENET1K_V1.transforms()
    image = Image.open(image_path).convert("RGB")
    tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(tensor)
        pred_idx = outputs.argmax(1).item()
        pred_class = class_names[pred_idx]

    return pred_class, pred_idx


def main():
    """
    Main entry point for the prediction script.

    Parses command-line arguments and performs prediction on the specified
    image or a default sample image.
    """
    if len(sys.argv) > 2:
        print("Usage: uv run predict.py <path_to_image>")
        print("Example with image in this folder: uv run predict.py img.jpg")
        sys.exit(1)
    elif len(sys.argv) == 1:
        image_path = cfg.SAMPLE_IMG_PATH
    else:
        image_path = sys.argv[1]

    pred_class, pred_idx = predict_image(image_path)
    print(f"Prediction: {pred_class} (class {pred_idx})")


if __name__ == "__main__":
    main()
```
def predict_image(image_path):
Predict the garbage category of an input image.
Parameters
- image_path (str): Path to the image file to classify.
Returns
- tuple of (str, int): A tuple containing the predicted class name and class index.
Examples
>>> pred_class, pred_idx = predict_image("sample.jpg")
>>> print(f"Prediction: {pred_class}")
Notes
The function automatically handles device selection (GPU/CPU) and applies the appropriate image transformations for the ResNet18 model.
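The final step of `predict_image`, `outputs.argmax(1).item()` followed by a lookup in `cfg.CLASS_NAMES`, can be sketched in pure Python as below. The class order is an assumption taken from the six categories listed in the module docstring; the real order is whatever `utils.config` defines. The helper `decode_prediction` and the logit values are hypothetical.

```python
from typing import Sequence, Tuple

# Assumed class order; the real order is whatever cfg.CLASS_NAMES defines.
CLASS_NAMES = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]


def decode_prediction(logits: Sequence[float],
                      class_names: Sequence[str] = CLASS_NAMES) -> Tuple[str, int]:
    """Return (class name, class index) for the highest-scoring logit.

    Plays the role of outputs.argmax(1).item() on a single-image batch.
    """
    if len(logits) != len(class_names):
        raise ValueError("one logit per class is required")
    # max() over indices keyed by score returns the first index on ties,
    # which matches torch.argmax's documented tie behavior.
    pred_idx = max(range(len(logits)), key=lambda i: logits[i])
    return class_names[pred_idx], pred_idx


# Hypothetical logits for one image (batch dimension already removed).
print(decode_prediction([0.1, 2.3, -0.4, 0.8, 1.9, -1.2]))  # ('glass', 1)
```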
def main():
Main entry point for the prediction script.
Parses command-line arguments and performs prediction on the specified image or a default sample image.
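The argument handling in `main()` can be isolated into a small pure function so it is testable without loading a model. This is a hypothetical refactoring sketch, not part of the script: `resolve_image_path` is an illustrative name, and the `"sample.jpg"` default stands in for `cfg.SAMPLE_IMG_PATH`.

```python
import sys
from typing import List

USAGE = "Usage: uv run predict.py <path_to_image>"


def resolve_image_path(argv: List[str], default_path: str) -> str:
    """Mirror main(): no argument -> default image, one -> that path, more -> exit."""
    if len(argv) > 2:
        # Same exit status (1) that main() uses when too many arguments are given.
        print(USAGE)
        sys.exit(1)
    if len(argv) == 1:
        return default_path
    return argv[1]


# argv[0] is the script name, exactly as in sys.argv.
print(resolve_image_path(["predict.py"], "sample.jpg"))             # sample.jpg
print(resolve_image_path(["predict.py", "img.jpg"], "sample.jpg"))  # img.jpg
```

Keeping `sys.exit` inside the helper preserves the script's current behavior; raising an exception instead and letting `main()` choose the exit code would be an equally reasonable design.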