source.predict

Garbage Classification Prediction Script. ...

  1# source/predict.py
  2# !/usr/bin/env python3
  3# -*- coding: utf-8 -*-
  4"""
  5Garbage Classification Prediction Script.
  6...
  7"""
  8__docformat__ = "numpy"
  9
 10import sys
 11from pathlib import Path
 12import torch
 13from torchvision import models
 14from PIL import Image
 15from source.utils import config as cfg
 16from source.utils.custom_classes.GarbageClassifier import GarbageClassifier
 17from codecarbon import EmissionsTracker
 18
 19
 20# ========================
 21# CORE PREDICTION FUNCTIONS (importable)
 22# ========================
 23
 24
 25def load_model_for_inference(model_path=None, device=None):
 26    """
 27    Load a trained model and preprocessing pipeline for inference.
 28
 29    Loads a GarbageClassifier model from checkpoint and prepares it for
 30    inference. Automatically selects GPU if available, otherwise uses CPU.
 31
 32    Parameters
 33    ----------
 34    model_path : str or Path, optional
 35        Path to model checkpoint file. If None, uses cfg.MODEL_PATH.
 36        Default is None.
 37    device : torch.device, optional
 38        Device to load model on (CPU or CUDA). If None, auto-detects GPU
 39        availability. Default is None.
 40
 41    Returns
 42    -------
 43    tuple
 44        A tuple containing:
 45        - model : GarbageClassifier
 46            Loaded model in evaluation mode.
 47        - device : torch.device
 48            Device the model is loaded on.
 49        - transform : torchvision.transforms.Compose
 50            Image preprocessing pipeline (ResNet18 ImageNet normalization).
 51
 52    Examples
 53    --------
 54    >>> model, device, transform = load_model_for_inference()
 55    >>> # Use in Gradio app or API
 56    """
 57    if model_path is None:
 58        model_path = cfg.MODEL_PATH
 59
 60    if device is None:
 61        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 62
 63    model = GarbageClassifier.load_from_checkpoint(
 64        model_path, num_classes=cfg.NUM_CLASSES
 65    )
 66    model = model.to(device)
 67    model.eval()
 68
 69    transform = models.ResNet18_Weights.IMAGENET1K_V1.transforms()
 70
 71    return model, device, transform
 72
 73
 74def predict_image(
 75    image_path,
 76    model=None,
 77    transform=None,
 78    device=None,
 79    class_names=None,
 80    track_carbon=False,
 81):
 82    """
 83    Predict the garbage category of a single image.
 84
 85    Loads image from file or PIL Image object, applies preprocessing, and
 86    returns predictions with confidence scores for all classes. Optionally
 87    tracks carbon emissions for the inference operation.
 88
 89    Parameters
 90    ----------
 91    image_path : str, Path, or PIL.Image
 92        Path to image file (str or Path object) or PIL Image object directly.
 93        Supported formats: JPG, PNG, BMP, GIF, TIFF.
 94    model : GarbageClassifier, optional
 95        Pre-loaded model. If None, loads from cfg.MODEL_PATH. Default is None.
 96    transform : torchvision.transforms.Compose, optional
 97        Image preprocessing pipeline. If None, uses default ResNet18's.
 98        Default is None.
 99    device : torch.device, optional
100        Device for inference. If None, auto-selects GPU/CPU. Default is None.
101    class_names : list of str, optional
102        List of class names. If None, uses cfg.CLASS_NAMES. Default is None.
103    track_carbon : bool, default=False
104        Whether to track carbon emissions for this inference operation.
105
106    Returns
107    -------
108    dict
109        Dictionary containing:
110        - 'predicted_class': str
111            Predicted garbage class name.
112        - 'predicted_idx': int
113            Predicted class index (0 to num_classes-1).
114        - 'confidence': float
115            Confidence score for predicted class (0.0 to 1.0).
116        - 'probabilities': dict
117            All class probabilities {class_name: score, ...}.
118        - 'emissions': dict or None
119            Carbon emissions data if track_carbon=True, else None.
120        - 'image': PIL.Image
121            Original input image (RGB format).
122
123    Raises
124    ------
125    FileNotFoundError
126        If image file path does not exist.
127    IOError
128        If image file cannot be read.
129    Exception
130        If image format is not supported.
131
132    Examples
133    --------
134    >>> result = predict_image("sample.jpg", track_carbon=True)
135    >>> print(f"Prediction: {result['predicted_class']}")
136    >>> print(f"Confidence: {result['confidence']:.2%}")
137    """
138    # Load model if not provided
139    if model is None or transform is None or device is None:
140        model, device, transform = load_model_for_inference(device=device)
141
142    if class_names is None:
143        class_names = cfg.CLASS_NAMES
144
145    # Start carbon tracking if enabled
146    emissions_data = None
147    if track_carbon:
148        tracker = EmissionsTracker(
149            project_name="garbage_classifier_inference",
150            output_dir=str(Path(cfg.MODEL_PATH).parent),
151            log_level="warning",
152        )
153        tracker.start()
154
155    try:
156        # Handle both file paths and PIL Images
157        if isinstance(image_path, (str, Path)):
158            image = Image.open(image_path).convert("RGB")
159        elif isinstance(image_path, Image.Image):
160            image = image_path.convert("RGB")
161        else:
162            # Assume numpy array (from Gradio)
163            image = Image.fromarray(image_path).convert("RGB")
164
165        tensor = transform(image).unsqueeze(0).to(device)
166
167        with torch.no_grad():
168            outputs = model(tensor)
169            probs = torch.softmax(outputs, dim=1)[0]
170            pred_idx = outputs.argmax(1).item()
171            pred_class = class_names[pred_idx]
172            confidence = probs[pred_idx].item()
173
174        # Get all probabilities
175        all_probs = {
176            class_names[i]: probs[i].item() for i in range(len(class_names))
177        }
178
179        # Stop carbon tracking
180        if track_carbon:
181            emissions_kg = tracker.stop()
182            from source.utils.carbon_utils import (
183                kg_co2_to_car_distance,
184                format_car_distance,
185            )
186
187            car_distances = kg_co2_to_car_distance(emissions_kg)
188            emissions_data = {
189                "emissions_kg": emissions_kg,
190                "emissions_g": emissions_kg * 1000,
191                "car_distance_km": car_distances["distance_km"],
192                "car_distance_m": car_distances["distance_m"],
193                "car_distance_formatted": format_car_distance(emissions_kg),
194            }
195
196        return {
197            "predicted_class": pred_class,
198            "predicted_idx": pred_idx,
199            "confidence": confidence,
200            "probabilities": all_probs,
201            "emissions": emissions_data,
202            "image": image,  # Return PIL image for display
203        }
204
205    except Exception as e:
206        if track_carbon:
207            tracker.stop()
208        raise e
209
210
def predict_batch(
    image_paths,
    model=None,
    transform=None,
    device=None,
    class_names=None,
    track_carbon=False,
    progress_callback=None,
):
    """
    Predict garbage categories for multiple images.

    A single model is used for the whole list of images. Optionally, the
    carbon emissions of the entire batch are tracked. Progress callbacks are
    supported for UI integration.

    Parameters
    ----------
    image_paths : iterable of (str or Path)
        Image file paths to process. Any iterable is accepted; it is
        consumed once, up front.
    model : GarbageClassifier, optional
        Pre-loaded model. If None, loads one from cfg.MODEL_PATH. A model
        passed here is used as-is. Default is None.
    transform : torchvision.transforms.Compose, optional
        Image preprocessing pipeline. If None, uses the ResNet18 ImageNet
        preprocessing transforms. Default is None.
    device : torch.device, optional
        Device for inference. If None, the behaviour depends on `model`:
        - when a model is given, the device of its parameters is used;
        - otherwise GPU/CPU is auto-detected.
        Default is None.
    class_names : list of str, optional
        List of class names. Default is None (uses cfg.CLASS_NAMES).
    track_carbon : bool, default=False
        Whether to track carbon emissions for the entire batch.
    progress_callback : callable, optional
        Called as ``progress_callback(current, total, message)`` before each
        image, where:
        - current is the 1-indexed image position;
        - total is the number of images;
        - message describes the current operation.
        Default is None.

    Returns
    -------
    dict
        Dictionary containing:
        - 'results': list of dict
            One entry per image, each with 'filename' and 'status'
            ('success' / 'error'). Successful entries also carry the keys
            returned by predict_image ('predicted_class', 'predicted_idx',
            'confidence', 'probabilities', 'emissions' (None), 'image').
            Failed entries carry 'error' (the exception message).
        - 'summary': dict
            Statistics with keys: 'total_images', 'successful', 'failed'.
        - 'emissions': dict or None
            Batch carbon emissions data, including 'emissions_per_image_g'
            (0.0 for an empty batch), if tracked and reported, else None.

    Notes
    -----
    Failed predictions are recorded with status='error' and an error message
    instead of aborting the batch. The model is loaded only once for
    efficiency. Carbon emissions are tracked for the batch as a whole, not
    per image, to avoid overhead.

    Examples
    --------
    >>> results = predict_batch(["img1.jpg", "img2.jpg"], track_carbon=True)
    >>> for r in results['results']:
    ...     if r['status'] == 'success':
    ...         print(f"{r['filename']}: {r['predicted_class']}")
    """
    if model is None:
        # Load only when no model was supplied, and keep any transform the
        # caller did provide instead of overwriting it with the default.
        model, device, default_transform = load_model_for_inference(
            device=device
        )
        if transform is None:
            transform = default_transform
    else:
        if device is None:
            # Run on whatever device the supplied model already lives on.
            try:
                device = next(model.parameters()).device
            except (AttributeError, StopIteration):
                device = torch.device(
                    "cuda" if torch.cuda.is_available() else "cpu"
                )
        if transform is None:
            transform = models.ResNet18_Weights.IMAGENET1K_V1.transforms()

    if class_names is None:
        class_names = cfg.CLASS_NAMES

    # Materialize once so generators work and the total is known up front.
    image_paths = list(image_paths)
    total = len(image_paths)

    tracker = None
    if track_carbon:
        tracker = EmissionsTracker(
            project_name="garbage_classifier_batch_inference",
            output_dir=str(Path(cfg.MODEL_PATH).parent),
            log_level="warning",
        )
        tracker.start()

    results = []
    emissions_kg = None
    try:
        for position, image_path in enumerate(image_paths, start=1):
            filename = Path(image_path).name
            if progress_callback:
                progress_callback(position, total, f"Processing {filename}")

            try:
                result = predict_image(
                    image_path,
                    model=model,
                    transform=transform,
                    device=device,
                    class_names=class_names,
                    track_carbon=False,
                )
            except Exception as e:
                # Best-effort batch: record the failure and keep going.
                results.append(
                    {"filename": filename, "status": "error", "error": str(e)}
                )
                continue

            result["filename"] = filename
            result["status"] = "success"
            results.append(result)
    finally:
        # Always stop the tracker, even if a callback raises or the run is
        # interrupted, so it does not keep measuring in the background.
        if tracker is not None:
            emissions_kg = tracker.stop()

    emissions_data = None
    # NOTE(review): codecarbon's stop() can return None when nothing was
    # measured; skip the derived figures instead of failing on None * 1000.
    if emissions_kg is not None:
        from source.utils.carbon_utils import (
            kg_co2_to_car_distance,
            format_car_distance,
        )

        emissions_g = emissions_kg * 1000
        car_distances = kg_co2_to_car_distance(emissions_kg)
        emissions_data = {
            "emissions_kg": emissions_kg,
            "emissions_g": emissions_g,
            "car_distance_km": car_distances["distance_km"],
            "car_distance_m": car_distances["distance_m"],
            "car_distance_formatted": format_car_distance(emissions_kg),
            # Guard against ZeroDivisionError for an empty batch.
            "emissions_per_image_g": emissions_g / total if total else 0.0,
        }

    successful = sum(1 for r in results if r.get("status") == "success")
    summary = {
        "total_images": total,
        "successful": successful,
        "failed": total - successful,
    }

    return {
        "results": results,
        "summary": summary,
        "emissions": emissions_data,
    }
355
356
def get_image_files(path, recursive=False):
    """
    Get all valid image files from a directory.

    By default, only the immediate contents of the directory are searched.
    Pass ``recursive=True`` to also descend into subdirectories. Returns a
    sorted list of image file paths.

    Parameters
    ----------
    path : Path or str
        Directory path to search for image files.
    recursive : bool, default=False
        Whether to include image files located in subdirectories.

    Returns
    -------
    list of Path
        Sorted list of image file paths. Supported extensions: .jpg, .jpeg,
        .png, .bmp, .gif, .tiff, .tif (case-insensitive).

    Raises
    ------
    OSError
        With ``recursive=False``, the error raised by ``Path.iterdir`` when
        `path` does not exist or is not a directory (e.g. FileNotFoundError,
        NotADirectoryError).

    Notes
    -----
    Extensions are matched case-insensitively and only regular files are
    returned (a directory named like an image is skipped). Returns an empty
    list if no valid image files are found.
    """
    valid_extensions = {
        ".jpg",
        ".jpeg",
        ".png",
        ".bmp",
        ".gif",
        ".tiff",
        ".tif",
    }

    # Wrapping in Path lets callers pass a plain string, as documented.
    root = Path(path)
    candidates = root.rglob("*") if recursive else root.iterdir()
    return sorted(
        entry
        for entry in candidates
        if entry.is_file() and entry.suffix.lower() in valid_extensions
    )
396
397
398# ========================
399# CLI INTERFACE (for terminal use)
400# ========================
401
402
def predict_single_image_cli(image_path):
    """
    CLI wrapper for single image prediction.

    The model is loaded once on the auto-detected device (CUDA if available,
    else CPU) and that same device is reported and used for inference. The
    function then predicts the class of one image and prints formatted
    results to stdout.

    Parameters
    ----------
    image_path : str or Path
        Path to the image file to predict.

    Returns
    -------
    None

    Side Effects
    ------------
    - Prints device information to stdout.
    - Prints the prediction with its confidence and all class probabilities
      to stdout.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    print("Loading model...")
    # Load explicitly so the device printed above is the one actually used.
    model, device, transform = load_model_for_inference(device=device)

    result = predict_image(
        image_path,
        model=model,
        transform=transform,
        device=device,
        track_carbon=False,
    )

    # Implicit string concatenation: a backslash continuation inside the
    # literal would embed the next line's indentation in the output.
    print(
        f"\nPrediction: {result['predicted_class']} "
        f"(class {result['predicted_idx']})"
    )
    print(f"Confidence: {result['confidence']:.2%}")
    print("\nAll probabilities:")
    for class_name, prob in result["probabilities"].items():
        print(f"  {class_name}: {prob:.2%}")
438
439
def predict_folder_cli(folder_path):
    """
    CLI wrapper for batch prediction on folder of images.

    Processes every valid image directly inside a directory (see
    get_image_files) on the auto-detected device. Prints progress and a
    formatted per-file summary to stdout.

    Parameters
    ----------
    folder_path : str or Path
        Path to directory containing images to predict.

    Returns
    -------
    None

    Raises
    ------
    SystemExit
        With exit code 1 in any of these cases:
        - the folder does not exist;
        - the path is not a directory;
        - no valid image files are found.

    Side Effects
    ------------
    - Prints device information, progress updates, and summary to stdout.
    """
    folder = Path(folder_path)

    if not folder.exists():
        print(f"Error: Folder '{folder_path}' does not exist.")
        sys.exit(1)

    if not folder.is_dir():
        print(f"Error: '{folder_path}' is not a directory.")
        sys.exit(1)

    image_files = get_image_files(folder)

    if not image_files:
        print(f"No valid image files found in '{folder_path}'")
        sys.exit(1)

    print(f"Found {len(image_files)} image(s) to process\n")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    def progress_callback(current, total, message):
        print(f"[{current}/{total}] {message}")

    # Pass the device so the one printed above is the one actually used.
    batch_result = predict_batch(
        image_files,
        device=device,
        track_carbon=False,
        progress_callback=progress_callback,
    )

    # Print summary
    print("\n" + "=" * 60)
    print("PREDICTION SUMMARY")
    print("=" * 60)
    for result in batch_result["results"]:
        if result["status"] == "success":
            # Implicit concatenation; a backslash continuation inside the
            # literal would embed source indentation in the printed line.
            print(
                f"{result['filename']:<40} -> {result['predicted_class']} "
                f"({result['confidence']:.2%})"
            )
        else:
            print(f"{result['filename']:<40} -> ERROR: {result['error']}")
    print("=" * 60)
507
508
def main():
    """
    Command-line entry point of the prediction script.

    Reads at most one positional argument from ``sys.argv``:
    - with no argument, the sample image cfg.SAMPLE_IMG_PATH is classified;
    - with an image path, that image is classified;
    - with a directory path, every image inside the directory is
      classified.

    Command-line Usage
    ------------------
    $ python predict.py                    # Uses default sample image
    $ python predict.py image.jpg          # Single image prediction
    $ python predict.py /path/to/folder/   # Batch folder prediction

    Returns
    -------
    None

    Raises
    ------
    SystemExit
        With exit code 1 when more than one argument is given or the given
        path does not exist.
    """
    argc = len(sys.argv)

    if argc > 2:
        print("Usage: uv run predict.py <path_to_image_or_folder>")
        print("Examples:")
        print("  uv run predict.py img.jpg")
        print("  uv run predict.py /path/to/images/")
        sys.exit(1)

    if argc == 1:
        predict_single_image_cli(cfg.SAMPLE_IMG_PATH)
        return

    target = Path(sys.argv[1])
    if not target.exists():
        print(f"Error: Path '{target}' does not exist.")
        sys.exit(1)

    # Directories go through the batch CLI, anything else is one image.
    run_cli = predict_folder_cli if target.is_dir() else predict_single_image_cli
    run_cli(target)


if __name__ == "__main__":
    main()
def load_model_for_inference(model_path=None, device=None):
    """
    Build everything needed to run inference: model, device and transform.

    The GarbageClassifier checkpoint is restored, moved to the target device
    and switched to evaluation mode. When no device is requested, CUDA is
    used if available and the CPU otherwise.

    Parameters
    ----------
    model_path : str or Path, optional
        Checkpoint to restore. Falls back to cfg.MODEL_PATH when None.
        Default is None.
    device : torch.device, optional
        Target device for the model. Auto-detected (CUDA if available, else
        CPU) when None. Default is None.

    Returns
    -------
    tuple
        ``(model, device, transform)`` where
        - model : GarbageClassifier
            The restored model, in evaluation mode, on `device`.
        - device : torch.device
            The device the model was moved to.
        - transform : torchvision.transforms.Compose
            The ResNet18 ImageNet preprocessing transforms.

    Examples
    --------
    >>> model, device, transform = load_model_for_inference()
    >>> # Use in Gradio app or API
    """
    checkpoint = cfg.MODEL_PATH if model_path is None else model_path

    if device is None:
        has_gpu = torch.cuda.is_available()
        device = torch.device("cuda" if has_gpu else "cpu")

    classifier = GarbageClassifier.load_from_checkpoint(
        checkpoint, num_classes=cfg.NUM_CLASSES
    ).to(device)
    classifier.eval()

    preprocess = models.ResNet18_Weights.IMAGENET1K_V1.transforms()
    return classifier, device, preprocess

Load a trained model and preprocessing pipeline for inference.

Loads a GarbageClassifier model from checkpoint and prepares it for inference. Automatically selects GPU if available, otherwise uses CPU.

Parameters
  • model_path (str or Path, optional): Path to model checkpoint file. If None, uses cfg.MODEL_PATH. Default is None.
  • device (torch.device, optional): Device to load model on (CPU or CUDA). If None, auto-detects GPU availability. Default is None.
Returns
  • tuple: A tuple containing:
    • model (GarbageClassifier): Loaded model in evaluation mode.
    • device (torch.device): Device the model is loaded on.
    • transform (torchvision.transforms.Compose): Image preprocessing pipeline (ResNet18 ImageNet preprocessing transforms).
Examples
>>> model, device, transform = load_model_for_inference()
>>> # Use in Gradio app or API
def predict_image(
    image_path,
    model=None,
    transform=None,
    device=None,
    class_names=None,
    track_carbon=False,
):
    """
    Predict the garbage category of a single image.

    The input image can be a file path, a PIL Image or a NumPy array. It is
    converted to RGB and passed through the preprocessing pipeline. The
    predicted class is returned together with the softmax probabilities of
    all classes. Optionally, the carbon emissions of the inference are
    tracked with codecarbon.

    Parameters
    ----------
    image_path : str, Path, PIL.Image or numpy.ndarray
        Path to an image file, a PIL Image object, or an array accepted by
        ``PIL.Image.fromarray`` (e.g. the value of a Gradio image input).
        Supported file formats: JPG, PNG, BMP, GIF, TIFF.
    model : GarbageClassifier, optional
        Pre-loaded model. If None, loads one from cfg.MODEL_PATH. A model
        passed here is used as-is: it is never reloaded or moved.
        Default is None.
    transform : torchvision.transforms.Compose, optional
        Image preprocessing pipeline. If None, uses the ResNet18 ImageNet
        preprocessing transforms. Default is None.
    device : torch.device, optional
        Device the input tensor is moved to. If None, the behaviour depends
        on `model`:
        - when a model is given, the device of its parameters is used;
        - when the model is loaded here, GPU/CPU is auto-selected.
        Default is None.
    class_names : list of str, optional
        Class names indexed by model output position. If None, uses
        cfg.CLASS_NAMES. Default is None.
    track_carbon : bool, default=False
        Whether to track carbon emissions for this inference operation.

    Returns
    -------
    dict
        Dictionary containing:
        - 'predicted_class': str
            Predicted garbage class name.
        - 'predicted_idx': int
            Predicted class index (0 to num_classes-1).
        - 'confidence': float
            Softmax probability of the predicted class (0.0 to 1.0).
        - 'probabilities': dict
            Probability of every class {class_name: score, ...}.
        - 'emissions': dict or None
            Carbon emissions data if track_carbon=True and the tracker
            reported a value, else None.
        - 'image': PIL.Image
            The input image converted to RGB.

    Raises
    ------
    FileNotFoundError
        If `image_path` is a path that does not exist.
    OSError
        If the file cannot be read or identified as an image
        (e.g. ``PIL.UnidentifiedImageError``).
    IndexError
        If `class_names` has fewer entries than the predicted index, or more
        entries than the model has outputs.

    Examples
    --------
    >>> result = predict_image("sample.jpg", track_carbon=True)
    >>> print(f"Prediction: {result['predicted_class']}")
    >>> print(f"Confidence: {result['confidence']:.2%}")
    """
    if model is None:
        # Load only when no model was supplied, and keep any transform the
        # caller did provide instead of overwriting it with the default.
        model, device, default_transform = load_model_for_inference(
            device=device
        )
        if transform is None:
            transform = default_transform
    else:
        if device is None:
            # Run on whatever device the supplied model already lives on.
            try:
                device = next(model.parameters()).device
            except (AttributeError, StopIteration):
                device = torch.device(
                    "cuda" if torch.cuda.is_available() else "cpu"
                )
        if transform is None:
            transform = models.ResNet18_Weights.IMAGENET1K_V1.transforms()

    if class_names is None:
        class_names = cfg.CLASS_NAMES

    tracker = None
    if track_carbon:
        tracker = EmissionsTracker(
            project_name="garbage_classifier_inference",
            output_dir=str(Path(cfg.MODEL_PATH).parent),
            log_level="warning",
        )
        tracker.start()

    emissions_kg = None
    try:
        if isinstance(image_path, (str, Path)):
            # The context manager closes the file handle once the RGB copy
            # (which owns its pixel data) has been created.
            with Image.open(image_path) as source_image:
                image = source_image.convert("RGB")
        elif isinstance(image_path, Image.Image):
            image = image_path.convert("RGB")
        else:
            # Assume numpy array (from Gradio)
            image = Image.fromarray(image_path).convert("RGB")

        tensor = transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = model(tensor)
            probs = torch.softmax(outputs, dim=1)[0].tolist()
            pred_idx = outputs.argmax(1).item()
    finally:
        # Stop the tracker on success and on any failure (including
        # interrupts) so it never keeps running in the background.
        if tracker is not None:
            emissions_kg = tracker.stop()

    pred_class = class_names[pred_idx]
    confidence = probs[pred_idx]
    all_probs = {name: probs[i] for i, name in enumerate(class_names)}

    emissions_data = None
    # NOTE(review): codecarbon's stop() can return None when nothing was
    # measured; skip the derived figures instead of failing on None * 1000.
    if emissions_kg is not None:
        from source.utils.carbon_utils import (
            kg_co2_to_car_distance,
            format_car_distance,
        )

        car_distances = kg_co2_to_car_distance(emissions_kg)
        emissions_data = {
            "emissions_kg": emissions_kg,
            "emissions_g": emissions_kg * 1000,
            "car_distance_km": car_distances["distance_km"],
            "car_distance_m": car_distances["distance_m"],
            "car_distance_formatted": format_car_distance(emissions_kg),
        }

    return {
        "predicted_class": pred_class,
        "predicted_idx": pred_idx,
        "confidence": confidence,
        "probabilities": all_probs,
        "emissions": emissions_data,
        "image": image,  # Return PIL image for display
    }

Predict the garbage category of a single image.

Loads image from file or PIL Image object, applies preprocessing, and returns predictions with confidence scores for all classes. Optionally tracks carbon emissions for the inference operation.

Parameters
  • image_path (str, Path, PIL.Image, or numpy.ndarray): Path to an image file (str or Path object), a PIL Image object, or a NumPy array (e.g. from Gradio) that is converted with PIL.Image.fromarray. Supported file formats: JPG, PNG, BMP, GIF, TIFF.
  • model (GarbageClassifier, optional): Pre-loaded model. If None, loads from cfg.MODEL_PATH. Default is None.
  • transform (torchvision.transforms.Compose, optional): Image preprocessing pipeline. If None, uses the default ResNet18 ImageNet preprocessing transforms. Default is None.
  • device (torch.device, optional): Device for inference. If None, auto-selects GPU/CPU. Default is None.
  • class_names (list of str, optional): List of class names. If None, uses cfg.CLASS_NAMES. Default is None.
  • track_carbon (bool, default=False): Whether to track carbon emissions for this inference operation.
Returns
  • dict: Dictionary containing:
    • 'predicted_class' (str): Predicted garbage class name.
    • 'predicted_idx' (int): Predicted class index (0 to num_classes-1).
    • 'confidence' (float): Confidence score for the predicted class (0.0 to 1.0).
    • 'probabilities' (dict): All class probabilities {class_name: score, ...}.
    • 'emissions' (dict or None): Carbon emissions data if track_carbon=True, else None.
    • 'image' (PIL.Image): Input image converted to RGB format.
Raises
  • FileNotFoundError: If image file path does not exist.
  • IOError: If image file cannot be read.
  • Exception: If image format is not supported.
Examples
>>> result = predict_image("sample.jpg", track_carbon=True)
>>> print(f"Prediction: {result['predicted_class']}")
>>> print(f"Confidence: {result['confidence']:.2%}")
def predict_batch(
    image_paths,
    model=None,
    transform=None,
    device=None,
    class_names=None,
    track_carbon=False,
    progress_callback=None,
):
    """
    Predict garbage categories for multiple images.

    A single model is used for the whole list of images. Optionally, the
    carbon emissions of the entire batch are tracked. Progress callbacks are
    supported for UI integration.

    Parameters
    ----------
    image_paths : iterable of (str or Path)
        Image file paths to process. Any iterable is accepted; it is
        consumed once, up front.
    model : GarbageClassifier, optional
        Pre-loaded model. If None, loads one from cfg.MODEL_PATH. A model
        passed here is used as-is. Default is None.
    transform : torchvision.transforms.Compose, optional
        Image preprocessing pipeline. If None, uses the ResNet18 ImageNet
        preprocessing transforms. Default is None.
    device : torch.device, optional
        Device for inference. If None, the behaviour depends on `model`:
        - when a model is given, the device of its parameters is used;
        - otherwise GPU/CPU is auto-detected.
        Default is None.
    class_names : list of str, optional
        List of class names. Default is None (uses cfg.CLASS_NAMES).
    track_carbon : bool, default=False
        Whether to track carbon emissions for the entire batch.
    progress_callback : callable, optional
        Called as ``progress_callback(current, total, message)`` before each
        image, where:
        - current is the 1-indexed image position;
        - total is the number of images;
        - message describes the current operation.
        Default is None.

    Returns
    -------
    dict
        Dictionary containing:
        - 'results': list of dict
            One entry per image, each with 'filename' and 'status'
            ('success' / 'error'). Successful entries also carry the keys
            returned by predict_image ('predicted_class', 'predicted_idx',
            'confidence', 'probabilities', 'emissions' (None), 'image').
            Failed entries carry 'error' (the exception message).
        - 'summary': dict
            Statistics with keys: 'total_images', 'successful', 'failed'.
        - 'emissions': dict or None
            Batch carbon emissions data, including 'emissions_per_image_g'
            (0.0 for an empty batch), if tracked and reported, else None.

    Notes
    -----
    Failed predictions are recorded with status='error' and an error message
    instead of aborting the batch. The model is loaded only once for
    efficiency. Carbon emissions are tracked for the batch as a whole, not
    per image, to avoid overhead.

    Examples
    --------
    >>> results = predict_batch(["img1.jpg", "img2.jpg"], track_carbon=True)
    >>> for r in results['results']:
    ...     if r['status'] == 'success':
    ...         print(f"{r['filename']}: {r['predicted_class']}")
    """
    if model is None:
        # Load only when no model was supplied, and keep any transform the
        # caller did provide instead of overwriting it with the default.
        model, device, default_transform = load_model_for_inference(
            device=device
        )
        if transform is None:
            transform = default_transform
    else:
        if device is None:
            # Run on whatever device the supplied model already lives on.
            try:
                device = next(model.parameters()).device
            except (AttributeError, StopIteration):
                device = torch.device(
                    "cuda" if torch.cuda.is_available() else "cpu"
                )
        if transform is None:
            transform = models.ResNet18_Weights.IMAGENET1K_V1.transforms()

    if class_names is None:
        class_names = cfg.CLASS_NAMES

    # Materialize once so generators work and the total is known up front.
    image_paths = list(image_paths)
    total = len(image_paths)

    tracker = None
    if track_carbon:
        tracker = EmissionsTracker(
            project_name="garbage_classifier_batch_inference",
            output_dir=str(Path(cfg.MODEL_PATH).parent),
            log_level="warning",
        )
        tracker.start()

    results = []
    emissions_kg = None
    try:
        for position, image_path in enumerate(image_paths, start=1):
            filename = Path(image_path).name
            if progress_callback:
                progress_callback(position, total, f"Processing {filename}")

            try:
                result = predict_image(
                    image_path,
                    model=model,
                    transform=transform,
                    device=device,
                    class_names=class_names,
                    track_carbon=False,
                )
            except Exception as e:
                # Best-effort batch: record the failure and keep going.
                results.append(
                    {"filename": filename, "status": "error", "error": str(e)}
                )
                continue

            result["filename"] = filename
            result["status"] = "success"
            results.append(result)
    finally:
        # Always stop the tracker, even if a callback raises or the run is
        # interrupted, so it does not keep measuring in the background.
        if tracker is not None:
            emissions_kg = tracker.stop()

    emissions_data = None
    # NOTE(review): codecarbon's stop() can return None when nothing was
    # measured; skip the derived figures instead of failing on None * 1000.
    if emissions_kg is not None:
        from source.utils.carbon_utils import (
            kg_co2_to_car_distance,
            format_car_distance,
        )

        emissions_g = emissions_kg * 1000
        car_distances = kg_co2_to_car_distance(emissions_kg)
        emissions_data = {
            "emissions_kg": emissions_kg,
            "emissions_g": emissions_g,
            "car_distance_km": car_distances["distance_km"],
            "car_distance_m": car_distances["distance_m"],
            "car_distance_formatted": format_car_distance(emissions_kg),
            # Guard against ZeroDivisionError for an empty batch.
            "emissions_per_image_g": emissions_g / total if total else 0.0,
        }

    successful = sum(1 for r in results if r.get("status") == "success")
    summary = {
        "total_images": total,
        "successful": successful,
        "failed": total - successful,
    }

    return {
        "results": results,
        "summary": summary,
        "emissions": emissions_data,
    }

Predict garbage categories for multiple images.

Processes a list of images efficiently using a single loaded model. Optionally tracks carbon emissions for the entire batch. Supports progress callbacks for UI integration.

Parameters
  • image_paths (list of (str or Path)): List of image file paths to process.
  • model (GarbageClassifier, optional): Pre-loaded model. If None, loads from cfg.MODEL_PATH. Default is None.
  • transform (torchvision.transforms.Compose, optional): Image preprocessing pipeline. Default is None.
  • device (torch.device, optional): Device for inference. Default is None (auto-detect).
  • class_names (list of str, optional): List of class names. Default is None (uses cfg.CLASS_NAMES).
  • track_carbon (bool, default=False): Whether to track carbon emissions for the entire batch.
  • progress_callback (callable, optional): Callback function called as progress_callback(current, total, message) where current is the image index (1-indexed), total is the total count, and message describes the current operation. Default is None.
Returns
  • dict: Dictionary containing:
    • 'results' (list of dict): One entry per image, each with the keys 'filename' and 'status' ('success' / 'error'). Successful entries also contain the keys returned by predict_image ('predicted_class', 'predicted_idx', 'confidence', 'probabilities', 'emissions', 'image'); failed entries contain 'error' with the exception message.
    • 'summary' (dict): Statistics with keys 'total_images', 'successful' and 'failed'.
    • 'emissions' (dict or None): Carbon emissions data for the whole batch, including 'emissions_per_image_g', if track_carbon=True; otherwise None.
Notes

Failed predictions are recorded with status='error' and an error message. The model is loaded only once for efficiency. Carbon emissions are tracked only for the batch as a whole, not for each image individually, to avoid overhead.

Examples
>>> results = predict_batch(["img1.jpg", "img2.jpg"], track_carbon=True)
>>> for r in results['results']:
...     if r['status'] == 'success':
...         print(f"{r['filename']}: {r['predicted_class']}")
def get_image_files(path):
def get_image_files(path):
    """
    Get all valid image files from a directory.

    Scans the top level of a directory (subdirectories are NOT searched)
    for files with a supported image extension and returns them sorted.

    Parameters
    ----------
    path : Path or str
        Directory path to search for image files. Strings are converted to
        :class:`pathlib.Path` before listing.

    Returns
    -------
    list of Path
        Sorted list of image file paths. Supported extensions: .jpg, .jpeg,
        .png, .bmp, .gif, .tiff, .tif (case-insensitive).

    Notes
    -----
    Extensions are matched case-insensitively. Directories whose names end
    in an image extension are ignored. Returns an empty list if no valid
    image files are found.
    """
    # Accept plain strings as documented: ``str`` has no ``iterdir`` method,
    # so the original implementation raised AttributeError for them.
    directory = Path(path)

    valid_extensions = frozenset(
        {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff", ".tif"}
    )

    # ``iterdir`` lists direct children only (non-recursive); ``is_file``
    # drops subdirectories, including ones named like "x.jpg".
    image_files = [
        entry
        for entry in directory.iterdir()
        if entry.is_file() and entry.suffix.lower() in valid_extensions
    ]
    return sorted(image_files)

Get all valid image files from a directory.

Searches the top level of a directory (subdirectories are not scanned) for image files with supported extensions. Returns a sorted list of image file paths.

Parameters
  • path (Path or str): Directory path to search for image files.
Returns
  • list of Path: Sorted list of image file paths. Supported extensions: .jpg, .jpeg, .png, .bmp, .gif, .tiff, .tif (case-insensitive).
Notes

Extensions are matched case-insensitively. Returns empty list if no valid image files are found.

def predict_single_image_cli(image_path):
def predict_single_image_cli(image_path):
    """
    CLI wrapper for single image prediction.

    Command-line interface function that reports the inference device,
    runs ``predict_image`` on a single image (without carbon tracking), and
    prints formatted results to stdout.

    Parameters
    ----------
    image_path : str or Path
        Path to the image file to predict.

    Returns
    -------
    None

    Side Effects
    ------------
    - Prints device information to stdout.
    - Prints prediction result with confidence and probabilities to stdout.
    """
    # Display only: predict_image is called without this device and
    # presumably performs its own device auto-detection.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    print("Loading model...")

    result = predict_image(image_path, track_carbon=False)

    # Implicit string concatenation rather than a backslash-newline inside
    # the f-string: the latter is a continuation *within* the literal and
    # spliced the next line's indentation into the printed message.
    print(
        f"\nPrediction: {result['predicted_class']} "
        f"(class {result['predicted_idx']})"
    )
    print(f"Confidence: {result['confidence']:.2%}")
    print("\nAll probabilities:")
    for class_name, prob in result["probabilities"].items():
        print(f"  {class_name}: {prob:.2%}")

CLI wrapper for single image prediction.

Command-line interface function that loads the model, predicts on a single image, and prints formatted results to stdout.

Parameters
  • image_path (str or Path): Path to the image file to predict.
Returns
  • None
Side Effects
  • Prints device information to stdout.
  • Prints prediction result with confidence and probabilities to stdout.
def predict_folder_cli(folder_path):
def predict_folder_cli(folder_path):
    """
    CLI wrapper for batch prediction on folder of images.

    Command-line interface function that validates ``folder_path``, runs
    ``predict_batch`` (without carbon tracking) on every image returned by
    ``get_image_files`` for that folder, and prints progress plus a summary
    table to stdout.

    Parameters
    ----------
    folder_path : str or Path
        Path to directory containing images to predict.

    Returns
    -------
    None

    Raises
    ------
    SystemExit
        With status 1 if the folder path does not exist, is not a
        directory, or contains no valid image files.

    Side Effects
    ------------
    - Prints device information, progress updates, and summary to stdout.
    """
    folder = Path(folder_path)

    if not folder.exists():
        print(f"Error: Folder '{folder_path}' does not exist.")
        sys.exit(1)

    if not folder.is_dir():
        print(f"Error: '{folder_path}' is not a directory.")
        sys.exit(1)

    image_files = get_image_files(folder)

    if not image_files:
        print(f"No valid image files found in '{folder_path}'")
        sys.exit(1)

    print(f"Found {len(image_files)} image(s) to process\n")
    # Display only: predict_batch is called without this device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    def progress_callback(current, total, message):
        # Called by predict_batch as (current, total, message).
        print(f"[{current}/{total}] {message}")

    batch_result = predict_batch(
        image_files, track_carbon=False, progress_callback=progress_callback
    )

    # Print summary
    separator = "=" * 60
    print("\n" + separator)
    print("PREDICTION SUMMARY")
    print(separator)
    for result in batch_result["results"]:
        # str() so the width spec also works if the filename is a Path.
        name = str(result["filename"])
        if result["status"] == "success":
            # Implicit concatenation instead of backslash-newline inside the
            # f-string, which leaked ~20/24 spaces of source indentation
            # into every printed line.
            print(
                f"{name:<40} -> {result['predicted_class']} "
                f"({result['confidence']:.2%})"
            )
        else:
            error = result.get("error", "unknown error")
            print(f"{name:<40} -> ERROR: {error}")
    print(separator)

CLI wrapper for batch prediction on folder of images.

Command-line interface function that processes all valid images in a directory and prints formatted results to stdout.

Parameters
  • folder_path (str or Path): Path to directory containing images to predict.
Returns
  • None
Raises
  • SystemExit: If folder path does not exist or is not a directory, or if no valid image files are found.
Side Effects
  • Prints device information, progress updates, and summary to stdout.
def main():
def main():
    """
    Run the prediction script from the command line.

    Dispatches on the number of command-line arguments:

    * no argument -> single-image prediction on ``cfg.SAMPLE_IMG_PATH``;
    * one argument that is a directory -> batch prediction over its images;
    * one argument that is any other existing path -> single-image
      prediction;
    * more than one argument -> print usage and exit.

    Command-line Usage
    ------------------
    $ python predict.py                    # Uses default sample image
    $ python predict.py image.jpg          # Single image prediction
    $ python predict.py /path/to/folder/   # Batch folder prediction

    Parameters
    ----------
    None (reads from sys.argv)

    Returns
    -------
    None

    Raises
    ------
    SystemExit
        Exit status 1 when more than one argument is supplied or when the
        supplied path does not exist.
    """
    argc = len(sys.argv)

    if argc == 1:
        # No path given: demo run on the configured sample image.
        predict_single_image_cli(cfg.SAMPLE_IMG_PATH)
        return

    if argc > 2:
        usage_lines = (
            "Usage: uv run predict.py <path_to_image_or_folder>",
            "Examples:",
            "  uv run predict.py img.jpg",
            "  uv run predict.py /path/to/images/",
        )
        for line in usage_lines:
            print(line)
        sys.exit(1)

    target = Path(sys.argv[1])
    if not target.exists():
        print(f"Error: Path '{target}' does not exist.")
        sys.exit(1)

    # Directories are batch-processed; any other existing path is treated
    # as a single image.
    runner = predict_folder_cli if target.is_dir() else predict_single_image_cli
    runner(target)

Main entry point for the prediction script when run from command line.

Parses command-line arguments to determine whether to predict on a single image or batch process a folder. Falls back to cfg.SAMPLE_IMG_PATH if no arguments provided.

Command-line Usage

$ python predict.py                    # Uses default sample image
$ python predict.py image.jpg          # Single image prediction
$ python predict.py /path/to/folder/   # Batch folder prediction

Parameters
  • None (reads from sys.argv)
Returns
  • None
Raises
  • SystemExit: If invalid arguments or path does not exist.