source.predict
Garbage Classification Prediction Script. ...
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# source/predict.py
"""
Garbage Classification Prediction Script.

Importable helpers for loading the trained ``GarbageClassifier`` and running
single-image or batch inference (optionally with CodeCarbon emissions
tracking), plus a small command-line interface driven by ``main``.
"""
__docformat__ = "numpy"

import sys
from pathlib import Path
import torch
from torchvision import models
from PIL import Image
from source.utils import config as cfg
from source.utils.custom_classes.GarbageClassifier import GarbageClassifier
from codecarbon import EmissionsTracker


# Lower-case file extensions accepted by get_image_files.
_VALID_IMAGE_EXTENSIONS = frozenset(
    {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff", ".tif"}
)


# ========================
# CORE PREDICTION FUNCTIONS (importable)
# ========================


def load_model_for_inference(model_path=None, device=None):
    """
    Load a trained model and preprocessing pipeline for inference.

    Loads a GarbageClassifier model from checkpoint and prepares it for
    inference. Automatically selects GPU if available, otherwise uses CPU.

    Parameters
    ----------
    model_path : str or Path, optional
        Path to model checkpoint file. If None, uses cfg.MODEL_PATH.
        Default is None.
    device : torch.device, optional
        Device to load model on (CPU or CUDA). If None, auto-detects GPU
        availability. Default is None.

    Returns
    -------
    tuple
        A tuple containing:
        - model : GarbageClassifier
            Loaded model in evaluation mode.
        - device : torch.device
            Device the model is loaded on.
        - transform : torchvision.transforms.Compose
            Image preprocessing pipeline (ResNet18 ImageNet normalization).

    Examples
    --------
    >>> model, device, transform = load_model_for_inference()
    >>> # Use in Gradio app or API
    """
    if model_path is None:
        model_path = cfg.MODEL_PATH

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = GarbageClassifier.load_from_checkpoint(
        model_path, num_classes=cfg.NUM_CLASSES
    )
    model = model.to(device)
    model.eval()

    transform = models.ResNet18_Weights.IMAGENET1K_V1.transforms()

    return model, device, transform


def _resolve_inference_components(model, transform, device):
    """
    Fill in whichever of ``model``, ``transform`` and ``device`` is missing.

    A caller-supplied model is never replaced; only the missing pieces are
    derived. When ``model`` is None the checkpoint at ``cfg.MODEL_PATH`` is
    loaded via ``load_model_for_inference``.

    Returns
    -------
    tuple
        ``(model, transform, device)``, all non-None.
    """
    if model is None:
        model, device, default_transform = load_model_for_inference(
            device=device
        )
        if transform is None:
            transform = default_transform
        return model, transform, device

    if device is None:
        # Run on the device the supplied model's weights already live on.
        first_param = next(iter(model.parameters()), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu"
            )
    if transform is None:
        # Same preprocessing that load_model_for_inference returns.
        transform = models.ResNet18_Weights.IMAGENET1K_V1.transforms()
    return model, transform, device


def _start_emissions_tracker(project_name):
    """Create and start a CodeCarbon tracker writing next to the model."""
    tracker = EmissionsTracker(
        project_name=project_name,
        output_dir=str(Path(cfg.MODEL_PATH).parent),
        log_level="warning",
    )
    tracker.start()
    return tracker


def _emissions_payload(emissions_kg):
    """
    Convert the value returned by ``tracker.stop()`` into the emissions dict.

    Keys: 'emissions_kg', 'emissions_g', 'car_distance_km', 'car_distance_m',
    'car_distance_formatted'.
    """
    # Imported here so carbon_utils is only loaded when tracking is enabled.
    from source.utils.carbon_utils import (
        kg_co2_to_car_distance,
        format_car_distance,
    )

    car_distances = kg_co2_to_car_distance(emissions_kg)
    return {
        "emissions_kg": emissions_kg,
        "emissions_g": emissions_kg * 1000,
        "car_distance_km": car_distances["distance_km"],
        "car_distance_m": car_distances["distance_m"],
        "car_distance_formatted": format_car_distance(emissions_kg),
    }


def _to_rgb_image(image_source):
    """Return an RGB PIL image from a path, a PIL image, or an array."""
    if isinstance(image_source, (str, Path)):
        # The context manager closes the file handle PIL keeps open for lazy
        # loading; convert() has already produced an independent image.
        with Image.open(image_source) as opened:
            return opened.convert("RGB")
    if isinstance(image_source, Image.Image):
        return image_source.convert("RGB")
    # Anything else is assumed to be an array (e.g. NumPy from Gradio).
    return Image.fromarray(image_source).convert("RGB")


def predict_image(
    image_path,
    model=None,
    transform=None,
    device=None,
    class_names=None,
    track_carbon=False,
):
    """
    Predict the garbage category of a single image.

    Loads the image from a file, a PIL Image object or an array, applies
    preprocessing, and returns predictions with confidence scores for all
    classes. Optionally tracks carbon emissions for the inference operation.

    Parameters
    ----------
    image_path : str, Path, PIL.Image, or array
        Path to image file (str or Path object), a PIL Image object, or an
        array (e.g. a NumPy array from Gradio) converted with
        ``Image.fromarray``. Supported file formats: JPG, PNG, BMP, GIF, TIFF.
    model : GarbageClassifier, optional
        Pre-loaded model, used as-is. If None, loads from cfg.MODEL_PATH.
        Default is None.
    transform : torchvision.transforms.Compose, optional
        Image preprocessing pipeline. If None, uses the default ResNet18
        ImageNet (IMAGENET1K_V1) transforms. Default is None.
    device : torch.device, optional
        Device for inference. If None, uses the device of the supplied
        model's parameters, or auto-selects GPU/CPU. Default is None.
    class_names : list of str, optional
        List of class names. If None, uses cfg.CLASS_NAMES. Default is None.
    track_carbon : bool, default=False
        Whether to track carbon emissions for this inference operation.

    Returns
    -------
    dict
        Dictionary containing:
        - 'predicted_class': str
            Predicted garbage class name.
        - 'predicted_idx': int
            Predicted class index (0 to num_classes-1).
        - 'confidence': float
            Confidence score for predicted class (0.0 to 1.0).
        - 'probabilities': dict
            All class probabilities {class_name: score, ...}.
        - 'emissions': dict or None
            Carbon emissions data if track_carbon=True, else None.
        - 'image': PIL.Image
            Original input image (RGB format).

    Raises
    ------
    FileNotFoundError
        If image file path does not exist.
    OSError
        If the image file cannot be read or its format is not recognised.

    Examples
    --------
    >>> result = predict_image("sample.jpg", track_carbon=True)
    >>> print(f"Prediction: {result['predicted_class']}")
    >>> print(f"Confidence: {result['confidence']:.2%}")
    """
    model, transform, device = _resolve_inference_components(
        model, transform, device
    )

    if class_names is None:
        class_names = cfg.CLASS_NAMES

    tracker = (
        _start_emissions_tracker("garbage_classifier_inference")
        if track_carbon
        else None
    )

    try:
        image = _to_rgb_image(image_path)
        tensor = transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = model(tensor)
            probs = torch.softmax(outputs, dim=1)[0]

        pred_idx = outputs.argmax(1).item()
        # One host transfer instead of one .item() per class.
        prob_values = probs.tolist()
        pred_class = class_names[pred_idx]
        confidence = prob_values[pred_idx]
        all_probs = {
            name: prob_values[i] for i, name in enumerate(class_names)
        }
    except Exception:
        # Don't leave the tracker running when inference fails.
        if tracker is not None:
            tracker.stop()
        raise

    emissions_data = None
    if tracker is not None:
        emissions_data = _emissions_payload(tracker.stop())

    return {
        "predicted_class": pred_class,
        "predicted_idx": pred_idx,
        "confidence": confidence,
        "probabilities": all_probs,
        "emissions": emissions_data,
        "image": image,  # Return PIL image for display
    }


def predict_batch(
    image_paths,
    model=None,
    transform=None,
    device=None,
    class_names=None,
    track_carbon=False,
    progress_callback=None,
):
    """
    Predict garbage categories for multiple images.

    Processes a list of images efficiently using a single loaded model.
    Optionally tracks carbon emissions for the entire batch. Supports progress
    callbacks for UI integration.

    Parameters
    ----------
    image_paths : iterable of (str or Path)
        Image file paths to process.
    model : GarbageClassifier, optional
        Pre-loaded model, used as-is. If None, loads from cfg.MODEL_PATH.
        Default is None.
    transform : torchvision.transforms.Compose, optional
        Image preprocessing pipeline. Default is None (ResNet18 defaults).
    device : torch.device, optional
        Device for inference. Default is None (model's device / auto-detect).
    class_names : list of str, optional
        List of class names. Default is None (uses cfg.CLASS_NAMES).
    track_carbon : bool, default=False
        Whether to track carbon emissions for the entire batch.
    progress_callback : callable, optional
        Callback function called as progress_callback(current, total, message)
        where current is the image index (1-indexed), total is the total count,
        and message describes the current operation. Default is None.

    Returns
    -------
    dict
        Dictionary containing:
        - 'results': list of dict
            List of prediction results (one per image) with keys:
            'filename', 'predicted_class', 'predicted_idx', 'confidence',
            'probabilities', 'status' ('success' / 'error'), [and 'error'].
            Successful entries also carry the 'emissions' (None) and
            'image' keys returned by predict_image.
        - 'summary': dict
            Statistics with keys: 'total_images', 'successful', 'failed'.
        - 'emissions': dict or None
            Carbon emissions data with 'emissions_per_image_g' if tracked,
            else None. 'emissions_per_image_g' is 0.0 for an empty batch.

    Notes
    -----
    Failed predictions are recorded with status='error' and error message.
    Model is loaded only once for efficiency. Individual image
    inference is not tracked for carbon (only the batch total)
    to avoid overhead.

    Examples
    --------
    >>> results = predict_batch(["img1.jpg", "img2.jpg"], track_carbon=True)
    >>> for r in results['results']:
    ...     if r['status'] == 'success':
    ...         print(f"{r['filename']}: {r['predicted_class']}")
    """
    # Load / fill in the inference components once for the whole batch.
    model, transform, device = _resolve_inference_components(
        model, transform, device
    )

    if class_names is None:
        class_names = cfg.CLASS_NAMES

    # Materialise once so generators work and len() is well defined.
    image_paths = list(image_paths)
    total = len(image_paths)

    tracker = (
        _start_emissions_tracker("garbage_classifier_batch_inference")
        if track_carbon
        else None
    )

    results = []
    try:
        for position, image_path in enumerate(image_paths, start=1):
            filename = Path(image_path).name
            if progress_callback:
                progress_callback(position, total, f"Processing {filename}")
            try:
                result = predict_image(
                    image_path,
                    model=model,
                    transform=transform,
                    device=device,
                    class_names=class_names,
                    track_carbon=False,
                )
            except Exception as e:
                # Per-image failures are recorded rather than raised.
                results.append(
                    {"filename": filename, "status": "error", "error": str(e)}
                )
                continue
            result["filename"] = filename
            result["status"] = "success"
            results.append(result)
    except Exception:
        # e.g. a failing progress_callback: don't leave the tracker running.
        if tracker is not None:
            tracker.stop()
        raise

    emissions_data = None
    if tracker is not None:
        emissions_kg = tracker.stop()
        emissions_data = _emissions_payload(emissions_kg)
        # Guard against ZeroDivisionError for an empty batch.
        emissions_data["emissions_per_image_g"] = (
            emissions_data["emissions_g"] / total if total else 0.0
        )

    successful = sum(1 for r in results if r.get("status") == "success")
    summary = {
        "total_images": total,
        "successful": successful,
        "failed": total - successful,
    }

    return {
        "results": results,
        "summary": summary,
        "emissions": emissions_data,
    }


def get_image_files(path):
    """
    Get all valid image files directly inside a directory.

    Lists the immediate children of ``path`` (subdirectories are not
    searched) and keeps regular files with a supported image extension.

    Parameters
    ----------
    path : Path or str
        Directory to scan. A ``str`` is converted to ``Path``.

    Returns
    -------
    list of Path
        Sorted list of image file paths. Supported extensions: .jpg, .jpeg,
        .png, .bmp, .gif, .tiff, .tif (case-insensitive).

    Notes
    -----
    Extensions are matched case-insensitively. Returns empty list if no
    valid image files are found.
    """
    directory = Path(path)
    return sorted(
        entry
        for entry in directory.iterdir()
        if entry.is_file() and entry.suffix.lower() in _VALID_IMAGE_EXTENSIONS
    )


# ========================
# CLI INTERFACE (for terminal use)
# ========================


def predict_single_image_cli(image_path):
    """
    CLI wrapper for single image prediction.

    Command-line interface function that loads model, predicts on a single
    image, and prints formatted results to stdout.

    Parameters
    ----------
    image_path : str or Path
        Path to the image file to predict.

    Returns
    -------
    None

    Side Effects
    ------------
    - Prints device information to stdout.
    - Prints prediction result with confidence and probabilities to stdout.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    print("Loading model...")

    # Pass the reported device through so inference runs on exactly it.
    result = predict_image(image_path, device=device, track_carbon=False)

    # Implicit concatenation: a backslash-newline inside an f-string would
    # embed the next source line's indentation in the printed text.
    print(
        f"\nPrediction: {result['predicted_class']} "
        f"(class {result['predicted_idx']})"
    )
    print(f"Confidence: {result['confidence']:.2%}")
    print("\nAll probabilities:")
    for class_name, prob in result["probabilities"].items():
        print(f"  {class_name}: {prob:.2%}")


def predict_folder_cli(folder_path):
    """
    CLI wrapper for batch prediction on folder of images.

    Command-line interface function that processes all valid images in a
    directory and prints formatted results to stdout.

    Parameters
    ----------
    folder_path : str or Path
        Path to directory containing images to predict.

    Returns
    -------
    None

    Raises
    ------
    SystemExit
        If folder path does not exist or is not a directory, or if no valid
        image files are found.

    Side Effects
    ------------
    - Prints device information, progress updates, and summary to stdout.
    """
    folder = Path(folder_path)

    if not folder.exists():
        print(f"Error: Folder '{folder_path}' does not exist.")
        sys.exit(1)

    if not folder.is_dir():
        print(f"Error: '{folder_path}' is not a directory.")
        sys.exit(1)

    image_files = get_image_files(folder)

    if not image_files:
        print(f"No valid image files found in '{folder_path}'")
        sys.exit(1)

    print(f"Found {len(image_files)} image(s) to process\n")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    def progress_callback(current, total, message):
        print(f"[{current}/{total}] {message}")

    batch_result = predict_batch(
        image_files,
        device=device,
        track_carbon=False,
        progress_callback=progress_callback,
    )

    # Print summary
    print("\n" + "=" * 60)
    print("PREDICTION SUMMARY")
    print("=" * 60)
    for result in batch_result["results"]:
        if result["status"] == "success":
            # Implicit concatenation keeps single spaces in the output.
            print(
                f"{result['filename']:<40} -> {result['predicted_class']} "
                f"({result['confidence']:.2%})"
            )
        else:
            print(f"{result['filename']:<40} -> ERROR: {result['error']}")
    print("=" * 60)


def main():
    """
    Main entry point for the prediction script when run from command line.

    Parses command-line arguments to determine whether to predict on a single
    image or batch process a folder. Falls back to cfg.SAMPLE_IMG_PATH if no
    arguments provided.

    Command-line Usage
    ------------------
    $ python predict.py                    # Uses default sample image
    $ python predict.py image.jpg          # Single image prediction
    $ python predict.py /path/to/folder/   # Batch folder prediction

    Parameters
    ----------
    None (reads from sys.argv)

    Returns
    -------
    None

    Raises
    ------
    SystemExit
        If invalid arguments or path does not exist.
    """
    if len(sys.argv) > 2:
        print("Usage: uv run predict.py <path_to_image_or_folder>")
        print("Examples:")
        print("  uv run predict.py img.jpg")
        print("  uv run predict.py /path/to/images/")
        sys.exit(1)
    elif len(sys.argv) == 1:
        image_path = cfg.SAMPLE_IMG_PATH
        predict_single_image_cli(image_path)
    else:
        input_path = Path(sys.argv[1])

        if not input_path.exists():
            print(f"Error: Path '{input_path}' does not exist.")
            sys.exit(1)

        if input_path.is_dir():
            predict_folder_cli(input_path)
        else:
            predict_single_image_cli(input_path)


if __name__ == "__main__":
    main()
def load_model_for_inference(model_path=None, device=None):
    """
    Build an inference-ready classifier plus its preprocessing transform.

    Restores a GarbageClassifier from a checkpoint, moves it to the target
    device and switches it to evaluation mode. When no device is given, CUDA
    is used if available and the CPU otherwise.

    Parameters
    ----------
    model_path : str or Path, optional
        Checkpoint to restore. Falls back to cfg.MODEL_PATH when None.
        Default is None.
    device : torch.device, optional
        Target device (CPU or CUDA). Auto-detected when None.
        Default is None.

    Returns
    -------
    tuple
        ``(model, device, transform)`` where:
        - model : GarbageClassifier
            The restored model, in evaluation mode.
        - device : torch.device
            The device the model was moved to.
        - transform : torchvision.transforms.Compose
            ResNet18 IMAGENET1K_V1 preprocessing pipeline.

    Examples
    --------
    >>> model, device, transform = load_model_for_inference()
    >>> # Use in Gradio app or API
    """
    checkpoint = cfg.MODEL_PATH if model_path is None else model_path

    target_device = device
    if target_device is None:
        has_gpu = torch.cuda.is_available()
        target_device = torch.device("cuda" if has_gpu else "cpu")

    classifier = GarbageClassifier.load_from_checkpoint(
        checkpoint, num_classes=cfg.NUM_CLASSES
    ).to(target_device)
    classifier.eval()

    preprocess = models.ResNet18_Weights.IMAGENET1K_V1.transforms()
    return classifier, target_device, preprocess
Load a trained model and preprocessing pipeline for inference.
Loads a GarbageClassifier model from checkpoint and prepares it for inference. Automatically selects GPU if available, otherwise uses CPU.
Parameters
- model_path (str or Path, optional): Path to model checkpoint file. If None, uses cfg.MODEL_PATH. Default is None.
- device (torch.device, optional): Device to load model on (CPU or CUDA). If None, auto-detects GPU availability. Default is None.
Returns
- tuple: A tuple containing:
- model : GarbageClassifier Loaded model in evaluation mode.
- device : torch.device Device the model is loaded on.
- transform : torchvision.transforms.Compose Image preprocessing pipeline (ResNet18 ImageNet normalization).
Examples
>>> model, device, transform = load_model_for_inference()
>>> # Use in Gradio app or API
def predict_image(
    image_path,
    model=None,
    transform=None,
    device=None,
    class_names=None,
    track_carbon=False,
):
    """
    Predict the garbage category of a single image.

    Loads the image from a file, a PIL Image object or an array, applies
    preprocessing, and returns predictions with confidence scores for all
    classes. Optionally tracks carbon emissions for the inference operation.

    Parameters
    ----------
    image_path : str, Path, PIL.Image, or array
        Path to image file (str or Path object), a PIL Image object, or an
        array (e.g. a NumPy array from Gradio) converted with
        ``Image.fromarray``. Supported file formats: JPG, PNG, BMP, GIF, TIFF.
    model : GarbageClassifier, optional
        Pre-loaded model, used as-is. If None, loads from cfg.MODEL_PATH.
        Default is None.
    transform : torchvision.transforms.Compose, optional
        Image preprocessing pipeline. If None, uses the default ResNet18
        ImageNet (IMAGENET1K_V1) transforms. Default is None.
    device : torch.device, optional
        Device for inference. If None, uses the device of the supplied
        model's parameters, or auto-selects GPU/CPU. Default is None.
    class_names : list of str, optional
        List of class names. If None, uses cfg.CLASS_NAMES. Default is None.
    track_carbon : bool, default=False
        Whether to track carbon emissions for this inference operation.

    Returns
    -------
    dict
        Dictionary containing:
        - 'predicted_class': str
            Predicted garbage class name.
        - 'predicted_idx': int
            Predicted class index (0 to num_classes-1).
        - 'confidence': float
            Confidence score for predicted class (0.0 to 1.0).
        - 'probabilities': dict
            All class probabilities {class_name: score, ...}.
        - 'emissions': dict or None
            Carbon emissions data if track_carbon=True, else None.
        - 'image': PIL.Image
            Original input image (RGB format).

    Raises
    ------
    FileNotFoundError
        If image file path does not exist.
    OSError
        If the image file cannot be read or its format is not recognised.

    Examples
    --------
    >>> result = predict_image("sample.jpg", track_carbon=True)
    >>> print(f"Prediction: {result['predicted_class']}")
    >>> print(f"Confidence: {result['confidence']:.2%}")
    """
    # Fill in only the missing components; never discard a caller's model.
    if model is None:
        model, device, default_transform = load_model_for_inference(
            device=device
        )
        if transform is None:
            transform = default_transform
    else:
        if device is None:
            # Run on the device the supplied model's weights live on.
            first_param = next(iter(model.parameters()), None)
            device = (
                first_param.device
                if first_param is not None
                else torch.device(
                    "cuda" if torch.cuda.is_available() else "cpu"
                )
            )
        if transform is None:
            transform = models.ResNet18_Weights.IMAGENET1K_V1.transforms()

    if class_names is None:
        class_names = cfg.CLASS_NAMES

    # Start carbon tracking if enabled
    tracker = None
    if track_carbon:
        tracker = EmissionsTracker(
            project_name="garbage_classifier_inference",
            output_dir=str(Path(cfg.MODEL_PATH).parent),
            log_level="warning",
        )
        tracker.start()

    try:
        # Handle file paths, PIL Images and arrays
        if isinstance(image_path, (str, Path)):
            # Close the file handle PIL keeps open for lazy loading;
            # convert() has already produced an independent image.
            with Image.open(image_path) as opened:
                image = opened.convert("RGB")
        elif isinstance(image_path, Image.Image):
            image = image_path.convert("RGB")
        else:
            # Assume numpy array (from Gradio)
            image = Image.fromarray(image_path).convert("RGB")

        tensor = transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = model(tensor)
            probs = torch.softmax(outputs, dim=1)[0]

        pred_idx = outputs.argmax(1).item()
        # One host transfer instead of one .item() per class.
        prob_values = probs.tolist()
        pred_class = class_names[pred_idx]
        confidence = prob_values[pred_idx]
        all_probs = {
            name: prob_values[i] for i, name in enumerate(class_names)
        }
    except Exception:
        # Don't leave the tracker running when inference fails.
        if tracker is not None:
            tracker.stop()
        raise

    # Stop carbon tracking
    emissions_data = None
    if tracker is not None:
        emissions_kg = tracker.stop()
        from source.utils.carbon_utils import (
            kg_co2_to_car_distance,
            format_car_distance,
        )

        car_distances = kg_co2_to_car_distance(emissions_kg)
        emissions_data = {
            "emissions_kg": emissions_kg,
            "emissions_g": emissions_kg * 1000,
            "car_distance_km": car_distances["distance_km"],
            "car_distance_m": car_distances["distance_m"],
            "car_distance_formatted": format_car_distance(emissions_kg),
        }

    return {
        "predicted_class": pred_class,
        "predicted_idx": pred_idx,
        "confidence": confidence,
        "probabilities": all_probs,
        "emissions": emissions_data,
        "image": image,  # Return PIL image for display
    }
Predict the garbage category of a single image.
Loads image from file or PIL Image object, applies preprocessing, and returns predictions with confidence scores for all classes. Optionally tracks carbon emissions for the inference operation.
Parameters
- image_path (str, Path, PIL.Image, or array): Path to an image file (str or Path object), a PIL Image object, or an array (e.g. a NumPy array from Gradio), which is converted with `Image.fromarray`. Supported file formats: JPG, PNG, BMP, GIF, TIFF.
- model (GarbageClassifier, optional): Pre-loaded model. If None, loads from cfg.MODEL_PATH. Default is None.
- transform (torchvision.transforms.Compose, optional): Image preprocessing pipeline. If None, uses the default ResNet18 ImageNet (IMAGENET1K_V1) preprocessing transforms. Default is None.
- device (torch.device, optional): Device for inference. If None, auto-selects GPU/CPU. Default is None.
- class_names (list of str, optional): List of class names. If None, uses cfg.CLASS_NAMES. Default is None.
- track_carbon (bool, default=False): Whether to track carbon emissions for this inference operation.
Returns
- dict: Dictionary containing:
- 'predicted_class': str Predicted garbage class name.
- 'predicted_idx': int Predicted class index (0 to num_classes-1).
- 'confidence': float Confidence score for predicted class (0.0 to 1.0).
- 'probabilities': dict All class probabilities {class_name: score, ...}.
- 'emissions': dict or None Carbon emissions data if track_carbon=True, else None.
- 'image': PIL.Image Original input image (RGB format).
Raises
- FileNotFoundError: If image file path does not exist.
- IOError: If image file cannot be read.
- Exception: If image format is not supported.
Examples
>>> result = predict_image("sample.jpg", track_carbon=True)
>>> print(f"Prediction: {result['predicted_class']}")
>>> print(f"Confidence: {result['confidence']:.2%}")
def predict_batch(
    image_paths,
    model=None,
    transform=None,
    device=None,
    class_names=None,
    track_carbon=False,
    progress_callback=None,
):
    """
    Predict garbage categories for multiple images.

    Processes a list of images efficiently using a single loaded model.
    Optionally tracks carbon emissions for the entire batch. Supports progress
    callbacks for UI integration.

    Parameters
    ----------
    image_paths : iterable of (str or Path)
        Image file paths to process.
    model : GarbageClassifier, optional
        Pre-loaded model, used as-is. If None, loads from cfg.MODEL_PATH.
        Default is None.
    transform : torchvision.transforms.Compose, optional
        Image preprocessing pipeline. Default is None (ResNet18 defaults).
    device : torch.device, optional
        Device for inference. Default is None (model's device / auto-detect).
    class_names : list of str, optional
        List of class names. Default is None (uses cfg.CLASS_NAMES).
    track_carbon : bool, default=False
        Whether to track carbon emissions for the entire batch.
    progress_callback : callable, optional
        Callback function called as progress_callback(current, total, message)
        where current is the image index (1-indexed), total is the total count,
        and message describes the current operation. Default is None.

    Returns
    -------
    dict
        Dictionary containing:
        - 'results': list of dict
            List of prediction results (one per image) with keys:
            'filename', 'predicted_class', 'predicted_idx', 'confidence',
            'probabilities', 'status' ('success' / 'error'), [and 'error'].
            Successful entries also carry the 'emissions' (None) and
            'image' keys returned by predict_image.
        - 'summary': dict
            Statistics with keys: 'total_images', 'successful', 'failed'.
        - 'emissions': dict or None
            Carbon emissions data with 'emissions_per_image_g' if tracked,
            else None. 'emissions_per_image_g' is 0.0 for an empty batch.

    Notes
    -----
    Failed predictions are recorded with status='error' and error message.
    Model is loaded only once for efficiency. Individual image
    inference is not tracked for carbon (only the batch total)
    to avoid overhead.

    Examples
    --------
    >>> results = predict_batch(["img1.jpg", "img2.jpg"], track_carbon=True)
    >>> for r in results['results']:
    ...     if r['status'] == 'success':
    ...         print(f"{r['filename']}: {r['predicted_class']}")
    """
    # Resolve inference components once for all images, keeping any model
    # the caller supplied instead of reloading the checkpoint.
    if model is None:
        model, device, default_transform = load_model_for_inference(
            device=device
        )
        if transform is None:
            transform = default_transform
    else:
        if device is None:
            first_param = next(iter(model.parameters()), None)
            device = (
                first_param.device
                if first_param is not None
                else torch.device(
                    "cuda" if torch.cuda.is_available() else "cpu"
                )
            )
        if transform is None:
            transform = models.ResNet18_Weights.IMAGENET1K_V1.transforms()

    if class_names is None:
        class_names = cfg.CLASS_NAMES

    # Materialise once so generators work and len() is well defined.
    image_paths = list(image_paths)
    total = len(image_paths)

    # Start carbon tracking if enabled
    tracker = None
    if track_carbon:
        tracker = EmissionsTracker(
            project_name="garbage_classifier_batch_inference",
            output_dir=str(Path(cfg.MODEL_PATH).parent),
            log_level="warning",
        )
        tracker.start()

    results = []
    try:
        for position, image_path in enumerate(image_paths, start=1):
            filename = Path(image_path).name
            if progress_callback:
                progress_callback(position, total, f"Processing {filename}")
            try:
                result = predict_image(
                    image_path,
                    model=model,
                    transform=transform,
                    device=device,
                    class_names=class_names,
                    track_carbon=False,
                )
            except Exception as e:
                # Per-image failures are recorded rather than raised.
                results.append(
                    {"filename": filename, "status": "error", "error": str(e)}
                )
                continue
            result["filename"] = filename
            result["status"] = "success"
            results.append(result)
    except Exception:
        # e.g. a failing progress_callback: don't leave the tracker running.
        if tracker is not None:
            tracker.stop()
        raise

    # Stop carbon tracking
    emissions_data = None
    if tracker is not None:
        emissions_kg = tracker.stop()
        from source.utils.carbon_utils import (
            kg_co2_to_car_distance,
            format_car_distance,
        )

        car_distances = kg_co2_to_car_distance(emissions_kg)
        emissions_g = emissions_kg * 1000
        emissions_data = {
            "emissions_kg": emissions_kg,
            "emissions_g": emissions_g,
            "car_distance_km": car_distances["distance_km"],
            "car_distance_m": car_distances["distance_m"],
            "car_distance_formatted": format_car_distance(emissions_kg),
            # Guard against ZeroDivisionError for an empty batch.
            "emissions_per_image_g": emissions_g / total if total else 0.0,
        }

    # Summary
    successful = sum(1 for r in results if r.get("status") == "success")
    summary = {
        "total_images": total,
        "successful": successful,
        "failed": total - successful,
    }

    return {
        "results": results,
        "summary": summary,
        "emissions": emissions_data,
    }
Predict garbage categories for multiple images.
Processes a list of images efficiently using a single loaded model. Optionally tracks carbon emissions for the entire batch. Supports progress callbacks for UI integration.
Parameters
- image_paths (list of (str or Path)): List of image file paths to process.
- model (GarbageClassifier, optional): Pre-loaded model. If None, loads from cfg.MODEL_PATH. Default is None.
- transform (torchvision.transforms.Compose, optional): Image preprocessing pipeline. Default is None.
- device (torch.device, optional): Device for inference. Default is None (auto-detect).
- class_names (list of str, optional): List of class names. Default is None (uses cfg.CLASS_NAMES).
- track_carbon (bool, default=False): Whether to track carbon emissions for the entire batch.
- progress_callback (callable, optional): Callback function called as progress_callback(current, total, message) where current is the image index (1-indexed), total is the total count, and message describes the current operation. Default is None.
Returns
- dict: Dictionary containing:
- 'results': list of dict List of prediction results (one per image) with keys: 'filename', 'predicted_class', 'predicted_idx', 'confidence', 'probabilities', 'status' ('success' / 'error'), [and 'error'].
- 'summary': dict Statistics with keys: 'total_images', 'successful', 'failed'.
- 'emissions': dict or None Carbon emissions data with 'emissions_per_image_g' if tracked, else None.
Notes
Failed predictions are recorded with status='error' and error message. Model is loaded only once for efficiency. Individual image inference is not tracked for carbon (only the batch total) to avoid overhead.
Examples
>>> results = predict_batch(["img1.jpg", "img2.jpg"], track_carbon=True)
>>> for r in results['results']:
... if r['status'] == 'success':
... print(f"{r['filename']}: {r['predicted_class']}")
def get_image_files(path):
    """
    Get all valid image files directly inside a directory.

    Lists the immediate children of ``path`` (subdirectories are not
    searched) and keeps regular files with a supported image extension.

    Parameters
    ----------
    path : Path or str
        Directory to scan. A ``str`` is converted to ``Path``.

    Returns
    -------
    list of Path
        Sorted list of image file paths. Supported extensions: .jpg, .jpeg,
        .png, .bmp, .gif, .tiff, .tif (case-insensitive).

    Notes
    -----
    Extensions are matched case-insensitively. Returns empty list if no
    valid image files are found. Errors for a missing path or a non-directory
    are those raised by ``Path.iterdir``.
    """
    valid_extensions = {
        ".jpg",
        ".jpeg",
        ".png",
        ".bmp",
        ".gif",
        ".tiff",
        ".tif",
    }

    # Path(...) accepts the documented str input (str has no .iterdir()).
    directory = Path(path)
    return sorted(
        entry
        for entry in directory.iterdir()
        if entry.is_file() and entry.suffix.lower() in valid_extensions
    )
Get all valid image files from a directory.
Searches a directory for image files with supported extensions; only its immediate children are listed, and subdirectories are not searched. Returns a sorted list of image file paths.
Parameters
- path (Path or str): Directory path to search for image files.
Returns
- list of Path: Sorted list of image file paths. Supported extensions: .jpg, .jpeg, .png, .bmp, .gif, .tiff, .tif (case-insensitive).
Notes
Extensions are matched case-insensitively. Returns empty list if no valid image files are found.
def predict_single_image_cli(image_path):
    """
    CLI wrapper for single image prediction.

    Command-line interface function that loads model, predicts on a single
    image, and prints formatted results to stdout.

    Parameters
    ----------
    image_path : str or Path
        Path to the image file to predict.

    Returns
    -------
    None

    Side Effects
    ------------
    - Prints device information to stdout.
    - Prints prediction result with confidence and probabilities to stdout.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    print("Loading model...")

    # Pass the reported device through so inference runs on exactly it.
    result = predict_image(image_path, device=device, track_carbon=False)

    # Implicit concatenation: a backslash-newline inside an f-string would
    # embed the next source line's indentation in the printed text.
    print(
        f"\nPrediction: {result['predicted_class']} "
        f"(class {result['predicted_idx']})"
    )
    print(f"Confidence: {result['confidence']:.2%}")
    print("\nAll probabilities:")
    for class_name, prob in result["probabilities"].items():
        print(f"  {class_name}: {prob:.2%}")
CLI wrapper for single image prediction.
Command-line interface function that loads model, predicts on a single image, and prints formatted results to stdout.
Parameters
- image_path (str or Path): Path to the image file to predict.
Returns
- None
Side Effects
- Prints device information to stdout.
- Prints prediction result with confidence and probabilities to stdout.
def predict_folder_cli(folder_path):
    """
    CLI wrapper for batch prediction on folder of images.

    Command-line interface function that processes all valid images in a
    directory (top level only, via ``get_image_files``) with
    ``predict_batch`` and prints formatted results to stdout.

    Parameters
    ----------
    folder_path : str or Path
        Path to directory containing images to predict.

    Returns
    -------
    None

    Raises
    ------
    SystemExit
        With code 1 if the folder path does not exist, is not a directory,
        or contains no valid image files.

    Side Effects
    ------------
    - Prints device information, progress updates, and summary to stdout.
    """
    folder = Path(folder_path)

    if not folder.exists():
        print(f"Error: Folder '{folder_path}' does not exist.")
        sys.exit(1)

    if not folder.is_dir():
        print(f"Error: '{folder_path}' is not a directory.")
        sys.exit(1)

    image_files = get_image_files(folder)

    if not image_files:
        print(f"No valid image files found in '{folder_path}'")
        sys.exit(1)

    print(f"Found {len(image_files)} image(s) to process\n")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    def progress_callback(current, total, message):
        # Invoked by predict_batch to report per-image progress.
        print(f"[{current}/{total}] {message}")

    batch_result = predict_batch(
        image_files, track_carbon=False, progress_callback=progress_callback
    )

    # Print summary
    print("\n" + "=" * 60)
    print("PREDICTION SUMMARY")
    print("=" * 60)
    for result in batch_result["results"]:
        if result["status"] == "success":
            # Implicit literal concatenation keeps source indentation out of
            # the output (a backslash-newline inside the literal would not).
            print(
                f"{result['filename']:<40} "
                f"-> {result['predicted_class']} "
                f"({result['confidence']:.2%})"
            )
        else:
            print(f"{result['filename']:<40} -> ERROR: {result['error']}")
    print("=" * 60)
CLI wrapper for batch prediction on folder of images.
Command-line interface function that processes all valid images in a directory and prints formatted results to stdout.
Parameters
- folder_path (str or Path): Path to directory containing images to predict.
Returns
- None
Raises
- SystemExit: If folder path does not exist or is not a directory, or if no valid image files are found.
Side Effects
- Prints device information, progress updates, and summary to stdout.
def main():
    """
    Command-line entry point for the prediction script.

    Inspects ``sys.argv`` and dispatches to one of two paths. With no
    argument, it predicts on ``cfg.SAMPLE_IMG_PATH``. With one existing
    path, it predicts on that image, or on every image in that folder.

    Command-line Usage
    ------------------
    $ python predict.py                   # Uses default sample image
    $ python predict.py image.jpg         # Single image prediction
    $ python predict.py /path/to/folder/  # Batch folder prediction

    Parameters
    ----------
    None (reads from sys.argv)

    Returns
    -------
    None

    Raises
    ------
    SystemExit
        If more than one argument is given or the path does not exist.
    """
    argv = sys.argv
    extra_args = len(argv) - 1

    # No argument: fall back to the bundled sample image.
    if extra_args == 0:
        predict_single_image_cli(cfg.SAMPLE_IMG_PATH)
        return

    # Too many arguments: show usage and bail out.
    if extra_args > 1:
        usage_lines = (
            "Usage: uv run predict.py <path_to_image_or_folder>",
            "Examples:",
            " uv run predict.py img.jpg",
            " uv run predict.py /path/to/images/",
        )
        for line in usage_lines:
            print(line)
        sys.exit(1)

    input_path = Path(argv[1])
    if not input_path.exists():
        print(f"Error: Path '{input_path}' does not exist.")
        sys.exit(1)

    handler = (
        predict_folder_cli if input_path.is_dir() else predict_single_image_cli
    )
    handler(input_path)
Main entry point for the prediction script when run from command line.
Parses command-line arguments to determine whether to predict on a single image or batch process a folder. Falls back to cfg.SAMPLE_IMG_PATH if no arguments provided.
Command-line Usage
$ python predict.py                   # Uses default sample image
$ python predict.py image.jpg         # Single image prediction
$ python predict.py /path/to/folder/  # Batch folder prediction
Parameters
- None (reads from sys.argv)
Returns
- None
Raises
- SystemExit: If invalid arguments or path does not exist.