app.sections.model_evaluation

Model Evaluation and Inference Interface for Gradio Application.

This module provides comprehensive model evaluation tools and inference capabilities through a Gradio interface. It includes visualization of model performance metrics, confusion matrices, calibration curves, and both single-image and batch prediction functionality.

Key features include:

  • Multi-model support (best model and latest trained model)
  • Cached computation of expensive metrics (confusion matrices, calibration)
  • Automatic cache invalidation based on model modification time
  • Real-time inference with carbon emissions tracking
  • Training metrics visualization (loss/accuracy curves)
Notes

Cache files are stored in app/sections/cached_data/ and are automatically invalidated when the corresponding model file is updated. This ensures that evaluation metrics reflect the current model state while avoiding redundant computation.
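
The validity check itself reduces to a file-modification-time comparison; a minimal sketch of the pattern (a hypothetical cache_is_fresh helper mirroring is_cache_valid in the source below):

    from pathlib import Path

    def cache_is_fresh(cache_file: Path, model_path) -> bool:
        # Both files must exist, and the cache must postdate the checkpoint
        if not cache_file.exists() or not Path(model_path).exists():
            return False
        return cache_file.stat().st_mtime > Path(model_path).stat().st_mtime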

   1#!/usr/bin/env python3
   2# -*- coding: utf-8 -*-
   3"""
   4Model Evaluation and Inference Interface for Gradio Application.
   5
   6This module provides comprehensive model evaluation tools and inference
   7capabilities through a Gradio interface. It includes visualization of
   8model performance metrics, confusion matrices, calibration curves, and
   9both single-image and batch prediction functionality.
  10
  11Key features include:
  12- Multi-model support (best model and latest trained model)
  13- Cached computation of expensive metrics (confusion matrices, calibration)
  14- Automatic cache invalidation based on model modification time
  15- Real-time inference with carbon emissions tracking
  16- Training metrics visualization (loss/accuracy curves)
  17
  18Notes
  19-----
  20Cache files are stored in `app/sections/cached_data/` and are automatically
  21invalidated when the corresponding model file is updated. This ensures
  22that evaluation metrics reflect the current model state while avoiding
  23redundant computation.
  24"""
  25__docformat__ = "numpy"
  26
  27import gradio as gr
  28from source.utils.carbon_utils import format_total_emissions_display
  29from source.utils import config as cfg
  30from source.utils.custom_classes.EvalAnalyzer import GarbageModelAnalyzer
  31from source.predict import (
  32    predict_image,
  33    predict_batch,
  34    load_model_for_inference
  35)
  36from pathlib import Path
  37import matplotlib.pyplot as plt
  38import numpy as np
  39import json
  40import pandas as pd
  41import pickle
  42
  43
  44def get_emissions_path():
  45    """
  46    Get the path to the emissions CSV file.
  47
  48    Returns
  49    -------
  50    pathlib.Path
  51        Path object pointing to the emissions.csv file in the model directory.
  52
  53    Notes
  54    -----
  55    The emissions file is stored in the same directory as the trained model
  56    checkpoint, as defined in the global configuration.
  57    """
  58    return Path(cfg.MODEL_PATH).parent / "emissions.csv"
  59
  60
  61def get_available_models():
  62    """
  63    Get dictionary of available trained models.
  64
  65    Returns
  66    -------
  67    dict of {str: str}
  68        Dictionary mapping model display names to their checkpoint file paths.
  69        Includes both the best provided model and the latest trained model.
  70
  71    Examples
  72    --------
  73    >>> models = get_available_models()
  74    >>> models
  75    {
  76        'Best Model (Provided)': 'models/best/model_resnet18_garbage.ckpt',
  77        'Latest Trained Model': 'models/resnet18_garbage.ckpt'
  78    }
  79    """
  80    models_dict = {
  81        "Best Model (Provided)": str(Path(
  82            "models/best/model_resnet18_garbage.ckpt"
  83        )),
  84        "Latest Trained Model": str(cfg.MODEL_PATH),
  85    }
  86    return models_dict
  87
  88
  89# Setup cache directory
  90CACHE_DIR = Path("app/sections/cached_data")
  91CACHE_DIR.mkdir(parents=True, exist_ok=True)
  92
  93
  94def is_cache_valid(cache_file, model_path):
  95    """
  96    Check if cached file is newer than the model checkpoint.
  97
  98    Validates cache by comparing modification timestamps. Cache is considered
  99    valid only if it was created/modified after the model checkpoint.
 100
 101    Parameters
 102    ----------
 103    cache_file : pathlib.Path
 104        Path to the cached file to validate.
 105    model_path : str or pathlib.Path
 106        Path to the model checkpoint file.
 107
 108    Returns
 109    -------
 110    bool
 111        True if cache exists and is newer than the model, False otherwise.
 112
 113    Notes
 114    -----
 115    This function implements automatic cache invalidation to ensure that
 116    evaluation metrics are regenerated when a model is retrained.
 117
 118    Examples
 119    --------
 120    >>> cache_file = Path("cached_data/cm_raw_Latest.pkl")
 121    >>> model_path = "models/resnet18_garbage.ckpt"
 122    >>> is_cache_valid(cache_file, model_path)
 123    False  # Cache doesn't exist or is older than model
 124    """
 125    if not cache_file.exists():
 126        return False
 127
 128    if not Path(model_path).exists():
 129        return False
 130
 131    cache_time = cache_file.stat().st_mtime
 132    model_time = Path(model_path).stat().st_mtime
 133
 134    # Cache is valid only if it was written after the model
 135    # checkpoint, so retraining invalidates it automatically.
 136
 137    return cache_time > model_time
 138
 139
 140# State to hold both confusion matrices
 141confusion_matrices_state = {
 142    "raw": None,
 143    "normalized": None,
 144    "model_choice": None,  # Track which model generated these
 145}
 146
 147
 148# ========================
 149# METRICS FUNCTIONS
 150# ========================
 151
 152
 153def generate_confusion_matrix(
 154    model_choice,
 155    show_normalized,
 156    progress=gr.Progress()
 157):
 158    """
 159    Generate and cache both raw and normalized confusion matrices.
 160
 161    This function generates confusion matrices for the validation set,
 162    caching both raw and normalized versions to enable instant toggling
 163    without recomputation. Implements three-tier caching: memory, disk,
 164    and automatic invalidation.
 165
 166    Parameters
 167    ----------
 168    model_choice : str
 169        Name of the selected model (from get_available_models()).
 170    show_normalized : bool
 171        Whether to display the normalized version initially.
 172    progress : gr.Progress, optional
 173        Gradio progress tracker for UI updates.
 174
 175    Returns
 176    -------
 177    tuple of (matplotlib.figure.Figure or None, str, gr.update)
 178        - Figure: The confusion matrix plot (raw or normalized)
 179        - str: Status message describing the operation result
 180        - gr.update: Gradio update object to control checkbox visibility
 181
 182    Notes
 183    -----
 184    Caching strategy:
 185    1. Check memory cache (fastest)
 186    2. Check disk cache with validation (fast)
 187    3. Regenerate if cache invalid or missing (slow)
 188
 189    Both raw and normalized matrices are always generated together and
 190    stored separately to enable instant toggling via the checkbox.
 191
 192    The cache is automatically invalidated when the model file is modified,
 193    ensuring metrics reflect the current model state.
 194
 195    See Also
 196    --------
 197    toggle_confusion_matrix : Switch between raw/normalized without
 198        regenerating.
 199    is_cache_valid : Cache validation logic
 200    """
 201    # confusion_matrices_state is mutated in place, so no global is needed
 202
 203    if not model_choice:
 204        return None, "Please select a model first", gr.update(visible=False)
 205
 206    # Get model path
 207    models_dict = get_available_models()
 208    model_path = models_dict.get(model_choice)
 209
 210    # Check if we already have matrices for this model in memory
 211    if (
 212        confusion_matrices_state["model_choice"] == model_choice
 213        and confusion_matrices_state["raw"] is not None
 214        and confusion_matrices_state["normalized"] is not None
 215    ):
 216
 217        selected_matrix = (
 218            confusion_matrices_state["normalized"]
 219            if show_normalized
 220            else confusion_matrices_state["raw"]
 221        )
 222        matrix_type = "Normalized" if show_normalized else "Raw"
 223        return (
 224            selected_matrix,
 225            f"✅ {matrix_type} confusion matrix (from memory)",
 226            gr.update(visible=True, interactive=True),
 227        )
 228
 229    # Check disk cache
 230    cache_file_raw = \
 231        CACHE_DIR / f"cm_raw_{model_choice.replace(' ', '_')}.pkl"
 232    cache_file_norm = \
 233        CACHE_DIR / f"cm_norm_{model_choice.replace(' ', '_')}.pkl"
 234
 235    if is_cache_valid(cache_file_raw, model_path) and is_cache_valid(
 236        cache_file_norm, model_path
 237    ):
 238        try:
 239            progress(0.2, desc="Loading from cache...")
 240            with open(cache_file_raw, "rb") as f:
 241                fig_raw = pickle.load(f)
 242            with open(cache_file_norm, "rb") as f:
 243                fig_norm = pickle.load(f)
 244
 245            confusion_matrices_state["raw"] = fig_raw
 246            confusion_matrices_state["normalized"] = fig_norm
 247            confusion_matrices_state["model_choice"] = model_choice
 248
 249            selected_matrix = fig_norm if show_normalized else fig_raw
 250            matrix_type = "Normalized" if show_normalized else "Raw"
 251            progress(1.0, desc="Done!")
 252            return (
 253                selected_matrix,
 254                f"✅ {matrix_type} confusion matrix loaded from cache",
 255                gr.update(visible=True, interactive=True),
 256            )
 257        except Exception as e:
 258            print(f"[WARN] Failed to load cache: {e}")
 259
 260    # Generate new matrices
 261    try:
 262        progress(0.1, desc="Loading model...")
 263        analyzer = GarbageModelAnalyzer()
 264        analyzer.load_model(checkpoint_path=model_path)
 265
 266        progress(0.3, desc="Setting up data...")
 267        analyzer.setup_data(batch_size=32)
 268
 269        progress(0.5, desc="Evaluating model...")
 270        val_loader = analyzer.data_module.val_dataloader()
 271        preds, labels, probs = analyzer.evaluate_loader(val_loader)
 272
 273        progress(0.7, desc="Generating confusion matrices...")
 274
 275        from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
 276
 277        num_classes = cfg.NUM_CLASSES
 278        cm_raw = confusion_matrix(
 279            labels.cpu().numpy(),
 280            preds.cpu().numpy(),
 281            labels=range(num_classes)
 282        )
 283        cm_norm = cm_raw.astype("float") / cm_raw.sum(axis=1)[:, np.newaxis]
 284
 285        # Generate RAW matrix figure
 286        fig_raw, ax_raw = plt.subplots(figsize=(10, 8))
 287        disp_raw = ConfusionMatrixDisplay(
 288            confusion_matrix=cm_raw, display_labels=cfg.CLASS_NAMES
 289        )
 290        disp_raw.plot(cmap=plt.cm.Blues, ax=ax_raw)
 291        ax_raw.set_title("Confusion Matrix - Validation Set")
 292        plt.tight_layout()
 293
 294        # Generate NORMALIZED matrix figure
 295        fig_norm, ax_norm = plt.subplots(figsize=(10, 8))
 296        disp_norm = ConfusionMatrixDisplay(
 297            confusion_matrix=cm_norm, display_labels=cfg.CLASS_NAMES
 298        )
 299        disp_norm.plot(cmap=plt.cm.Blues, ax=ax_norm)
 300        ax_norm.set_title("Normalized Confusion Matrix - Validation Set")
 301        plt.tight_layout()
 302
 303        # Save to cache
 304        progress(0.9, desc="Saving to cache...")
 305        with open(cache_file_raw, "wb") as f:
 306            pickle.dump(fig_raw, f)
 307        with open(cache_file_norm, "wb") as f:
 308            pickle.dump(fig_norm, f)
 309
 310        # Update state
 311        confusion_matrices_state["raw"] = fig_raw
 312        confusion_matrices_state["normalized"] = fig_norm
 313        confusion_matrices_state["model_choice"] = model_choice
 314
 315        selected_matrix = fig_norm if show_normalized else fig_raw
 316        matrix_type = "Normalized" if show_normalized else "Raw"
 317
 318        progress(1.0, desc="Done!")
 319        return (
 320            selected_matrix,
 321            f"✅ {matrix_type} confusion matrix generated and cached",
 322            gr.update(visible=True, interactive=True),
 323        )
 324
 325    except Exception as e:
 326        return None, f"❌ Error: {str(e)}", gr.update(visible=False)
 327
 328
 329def toggle_confusion_matrix(show_normalized):
 330    """
 331    Toggle between raw and normalized confusion matrices instantly.
 332
 333    Switches the displayed confusion matrix without regenerating, using
 334    cached versions from memory. This enables instant UI response to the
 335    normalization checkbox.
 336
 337    Parameters
 338    ----------
 339    show_normalized : bool
 340        If True, return normalized matrix; if False, return raw matrix.
 341
 342    Returns
 343    -------
 344    matplotlib.figure.Figure or None
 345        The requested confusion matrix figure, or None if not available
 346        in memory cache.
 347
 348    Notes
 349    -----
 350    This function only accesses the in-memory cache. If matrices are not
 351    yet generated, it returns None. The generate_confusion_matrix function
 352    should be called first to populate the cache.
 353
 354    See Also
 355    --------
 356    generate_confusion_matrix : Generate and cache both matrix versions
 357    """
 358    # confusion_matrices_state is only read here, so no global is needed
 359
 360    if show_normalized:
 361        if confusion_matrices_state["normalized"] is not None:
 362            return confusion_matrices_state["normalized"]
 363    else:
 364        if confusion_matrices_state["raw"] is not None:
 365            return confusion_matrices_state["raw"]
 366
 367    return None
 368
 369
 370def generate_calibration_curves(model_choice, progress=gr.Progress()):
 371    """
 372    Generate and cache calibration curves for the selected model.
 373
 374    Calibration curves show how well predicted probabilities match actual
 375    outcomes. This function generates curves for all classes and caches
 376    the result for faster subsequent access.
 377
 378    Parameters
 379    ----------
 380    model_choice : str
 381        Name of the selected model (from get_available_models()).
 382    progress : gr.Progress, optional
 383        Gradio progress tracker for UI updates.
 384
 385    Returns
 386    -------
 387    tuple of (matplotlib.figure.Figure or None, str)
 388        - Figure: The calibration curves plot
 389        - str: Status message describing the operation result
 390
 391    Notes
 392    -----
 393    The calibration plot is generated using GarbageModelAnalyzer's
 394    plot_calibration_curves method and cached to disk. Cache is
 395    automatically invalidated when the model is retrained.
 396
 397    Well-calibrated models should have curves close to the diagonal,
 398    indicating that predicted probabilities match actual frequencies.
 399
 400    See Also
 401    --------
 402    is_cache_valid : Cache validation logic
 403    GarbageModelAnalyzer.plot_calibration_curves : Core plotting function
 404    """
 405    if not model_choice:
 406        return None, "Please select a model first"
 407
 408    # Get model path
 409    models_dict = get_available_models()
 410    model_path = models_dict.get(model_choice)
 411
 412    # Check disk cache
 413    cache_file = CACHE_DIR / f"calib_{model_choice.replace(' ', '_')}.pkl"
 414
 415    if is_cache_valid(cache_file, model_path):
 416        try:
 417            progress(0.2, desc="Loading from cache...")
 418            with open(cache_file, "rb") as f:
 419                fig = pickle.load(f)
 420            progress(1.0, desc="Done!")
 421            return fig, "✅ Calibration curves loaded from cache"
 422        except Exception as e:
 423            print(f"[WARN] Failed to load cache: {e}")
 424
 425    try:
 426        progress(0.1, desc="Loading model...")
 427        analyzer = GarbageModelAnalyzer()
 428        analyzer.load_model(checkpoint_path=model_path)
 429
 430        progress(0.3, desc="Setting up data...")
 431        analyzer.setup_data(batch_size=32)
 432
 433        progress(0.5, desc="Evaluating model...")
 434        val_loader = analyzer.data_module.val_dataloader()
 435        preds, labels, probs = analyzer.evaluate_loader(val_loader)
 436
 437        progress(0.8, desc="Generating calibration curves...")
 438
 439        analyzer.plot_calibration_curves(labels, probs)
 440        fig = plt.gcf()
 441
 442        # Save to cache
 443        progress(0.9, desc="Saving to cache...")
 444        with open(cache_file, "wb") as f:
 445            pickle.dump(fig, f)
 446
 447        progress(1.0, desc="Done!")
 448        return fig, "✅ Calibration curves generated and cached"
 449
 450    except Exception as e:
 451        return None, f"❌ Error: {str(e)}"
 452
 453
 454def get_metrics_path_for_model(model_choice):
 455    """
 456    Get the correct metrics.json path for the selected model.
 457
 458    Different models store their training metrics in different locations.
 459    This function maps model choices to their corresponding metrics files.
 460
 461    Parameters
 462    ----------
 463    model_choice : str
 464        Name of the selected model (from get_available_models()).
 465
 466    Returns
 467    -------
 468    pathlib.Path
 469        Path to the metrics.json file for the selected model.
 470
 471    Notes
 472    -----
 473    The best provided model has metrics in a fixed location, while the
 474    latest trained model uses the configured loss curves path.
 475    """
 476    if model_choice == "Best Model (Provided)":
 477        return Path("models/best/performance/loss_curves/metrics.json")
 478    else:
 479        return Path(cfg.LOSS_CURVES_PATH) / "metrics.json"
 480
 481
 482def load_loss_curves(model_choice):
 483    """
 484    Load and plot training/validation loss curves from metrics.json.
 485
 486    Parameters
 487    ----------
 488    model_choice : str
 489        Name of the selected model (from get_available_models()).
 490
 491    Returns
 492    -------
 493    tuple of (matplotlib.figure.Figure or None, str)
 494        - Figure: Loss curves plot showing train and validation loss
 495        - str: Status message describing the operation result
 496
 497    Notes
 498    -----
 499    Loss curves show how training and validation loss evolved during
 500    training. Diverging curves may indicate overfitting, while parallel
 501    curves suggest good generalization.
 502
 503    The plot includes:
 504    - Blue line with circles: Training loss
 505    - Red line with squares: Validation loss
 506    - Grid for easier reading
 507
 508    See Also
 509    --------
 510    load_accuracy_curves : Load accuracy metrics instead of loss
 511    get_metrics_path_for_model : Determine metrics file location
 512    """
 513    try:
 514        metrics_path = get_metrics_path_for_model(model_choice)
 515
 516        if not metrics_path.exists():
 517            return (
 518                None,
 519                f"❌ No training metrics found for "
 520                f"{model_choice}. Path: {metrics_path}",
 521            )
 522
 523        with open(metrics_path, "r") as f:
 524            data = json.load(f)
 525
 526        train_losses = data.get("train_losses", [])
 527        val_losses = data.get("val_losses", [])
 528
 529        if not train_losses and not val_losses:
 530            return None, "❌ No loss data available"
 531
 532        fig, ax = plt.subplots(figsize=(10, 6))
 533        epochs = range(1, len(train_losses) + 1)
 534
 535        if train_losses:
 536            ax.plot(
 537                epochs,
 538                train_losses,
 539                "b-o",
 540                label="Train Loss",
 541                linewidth=2
 542            )
 543        if val_losses:
 544            ax.plot(
 545                epochs,
 546                val_losses,
 547                "r-s",
 548                label="Validation Loss",
 549                linewidth=2
 550            )
 551
 552        ax.set_xlabel("Epoch", fontsize=12)
 553        ax.set_ylabel("Loss", fontsize=12)
 554        ax.set_title(f"Loss Curves - {model_choice}",
 555                     fontsize=14, fontweight="bold")
 556        ax.legend(fontsize=10)
 557        ax.grid(True, alpha=0.3)
 558        plt.tight_layout()
 559
 560        return fig, f"✅ Loss curves loaded successfully from {model_choice}"
 561
 562    except Exception as e:
 563        return None, f"❌ Error loading loss curves: {str(e)}"
 564
 565
 566def load_accuracy_curves(model_choice):
 567    """
 568    Load and plot training/validation accuracy curves from metrics.json.
 569
 570    Parameters
 571    ----------
 572    model_choice : str
 573        Name of the selected model (from get_available_models()).
 574
 575    Returns
 576    -------
 577    tuple of (matplotlib.figure.Figure or None, str)
 578        - Figure: Accuracy curves plot showing train and validation accuracy
 579        - str: Status message describing the operation result
 580
 581    Notes
 582    -----
 583    Accuracy curves show how classification accuracy evolved during training.
 584    The plot includes:
 585    - Blue line with circles: Training accuracy
 586    - Red line with squares: Validation accuracy
 587    - Y-axis limited to [0, 1] for consistency
 588    - Grid for easier reading
 589
 590    A large gap between train and validation accuracy indicates overfitting.
 591
 592    See Also
 593    --------
 594    load_loss_curves : Load loss metrics instead of accuracy
 595    get_metrics_path_for_model : Determine metrics file location
 596    """
 597    try:
 598        metrics_path = get_metrics_path_for_model(model_choice)
 599
 600        if not metrics_path.exists():
 601            return (
 602                None,
 603                f"❌ No training metrics found for "
 604                f"{model_choice}. Path: {metrics_path}",
 605            )
 606
 607        with open(metrics_path, "r") as f:
 608            data = json.load(f)
 609
 610        train_accs = data.get("train_accs", [])
 611        val_accs = data.get("val_accs", [])
 612
 613        if not train_accs and not val_accs:
 614            return None, "❌ No accuracy data available"
 615
 616        fig, ax = plt.subplots(figsize=(10, 6))
 617
 618        if train_accs:
 619            ax.plot(
 620                range(1, len(train_accs) + 1),
 621                train_accs,
 622                "b-o",
 623                label="Train Accuracy",
 624                linewidth=2,
 625            )
 626        if val_accs:
 627            ax.plot(
 628                range(1, len(val_accs) + 1),
 629                val_accs,
 630                "r-s",
 631                label="Validation Accuracy",
 632                linewidth=2,
 633            )
 634
 635        ax.set_xlabel("Epoch", fontsize=12)
 636        ax.set_ylabel("Accuracy", fontsize=12)
 637        ax.set_title(
 638            f"Accuracy Curves - {model_choice}",
 639            fontsize=14, fontweight="bold"
 640        )
 641        ax.legend(fontsize=10)
 642        ax.grid(True, alpha=0.3)
 643        ax.set_ylim([0, 1])
 644        plt.tight_layout()
 645
 646        return fig, (f"✅ Accuracy curves loaded "
 647                     f"successfully from {model_choice}")
 648
 649    except Exception as e:
 650        return None, (f"❌ Error loading "
 651                      f"accuracy curves: {str(e)}")
 652
 653
 654# ========================
 655# PREDICTION FUNCTIONS
 656# ========================
 657
 658
 659def predict_single_image_gradio(
 660    model_choice, image, carbon_display_text, track_carbon=True
 661):
 662    """
 663    Gradio wrapper for single image prediction with visualization.
 664
 665    Performs inference on a single image and generates a horizontal bar
 666    chart of class probabilities, highlighting the predicted class.
 667
 668    Parameters
 669    ----------
 670    model_choice : str
 671        Name of the selected model (from get_available_models()).
 672    image : np.ndarray
 673        Input image as numpy array (from gr.Image component).
 674    carbon_display_text : str
 675        Current HTML string for the carbon display (to be updated).
 676    track_carbon : bool, optional
 677        Whether to track carbon emissions for this prediction,
 678        by default True.
 679
 680    Returns
 681    -------
 682    tuple of (matplotlib.figure.Figure or None, str, str)
 683        - Figure: Bar chart of class probabilities
 684        - str: Markdown-formatted prediction results and statistics
 685        - str: Updated HTML for carbon display
 686
 687    Notes
 688    -----
 689    The probability bar chart uses:
 690    - Green bar for the predicted class
 691    - Sky blue bars for other classes
 692    - Percentage labels on each bar
 693
 694    If carbon tracking is enabled, emissions are added to the cumulative
 695    total and the carbon display is updated.
 696
 697    See Also
 698    --------
 699    predict_batch : Batch prediction on multiple images
 700    source.predict.predict_image : Core prediction function
 701    """
 702    if image is None:
 703        return None, "Please upload an image", carbon_display_text
 704
 705    if not model_choice:
 706        return None, "Please select a model first", carbon_display_text
 707
 708    try:
 709        models_dict = get_available_models()
 710        model_path = models_dict.get(model_choice)
 711        model, device, transform = load_model_for_inference(
 712            model_path=model_path
 713        )
 714
 715        result = predict_image(
 716            image,
 717            model=model,
 718            transform=transform,
 719            device=device,
 720            track_carbon=track_carbon,
 721        )
 722
 723        fig, ax = plt.subplots(figsize=(10, 6))
 724        probs_list = [result["probabilities"][cls] for cls in cfg.CLASS_NAMES]
 725        pred_idx = result["predicted_idx"]
 726        aux_list = range(len(cfg.CLASS_NAMES))
 727        colors = [
 728            "green" if i == pred_idx else "skyblue" for i in aux_list
 729        ]
 730        bars = ax.barh(cfg.CLASS_NAMES, probs_list, color=colors)
 731        ax.set_xlabel("Probability", fontsize=12)
 732        ax.set_title(
 733            f'Prediction Probabilities\n'
 734            f'Predicted Class: {result["predicted_class"]}',
 735            fontsize=14,
 736            fontweight="bold",
 737        )
 738        ax.set_xlim([0, 1])
 739
 740        for bar, prob in zip(bars, probs_list):
 741            ax.text(
 742                prob + 0.02,
 743                bar.get_y() + bar.get_height() / 2,
 744                f"{prob*100:.1f}%",
 745                va="center",
 746                fontsize=10,
 747            )
 748
 749        plt.tight_layout()
 750
 751        result_text = (f"### 🎯 Prediction: "
 752                       f"**{result['predicted_class']}**\n\n")
 753        result_text += f"**Confidence:** {result['confidence']*100:.2f}%\n\n"
 754        result_text += "**All Probabilities:**\n"
 755        for class_name, prob in result["probabilities"].items():
 756            emoji = "🏆" if class_name == result["predicted_class"] \
 757                else "  "
 758            result_text += f"{emoji} {class_name}: {prob*100:.2f}%\n"
 759
 760        updated_carbon_display = carbon_display_text
 761        if result["emissions"]:
 762            emissions = result["emissions"]
 763            result_text += "\n\n### 🌍 Carbon Footprint\n"
 764            result_text += (f"- **Emissions:** "
 765                            f"{emissions['emissions_g']:.4f}g CO₂eq\n")
 766            result_text += (f"- **🚗 Car equivalent:** "
 767                            f"{emissions['car_distance_formatted']} driven\n")
 768            updated_carbon_display = format_total_emissions_display(
 769                get_emissions_path()
 770            )
 771
 772        return fig, result_text, updated_carbon_display
 773
 774    except Exception as e:
 775        return None, f"❌ Error: {str(e)}", carbon_display_text
 776
 777
 778def predict_folder_gradio(
 779    model_choice, files, carbon_display_text, track_carbon=True
 780):
 781    """
 782    Gradio wrapper for batch prediction on multiple images.
 783
 784    Performs inference on multiple images simultaneously and returns
 785    results in a tabular format with per-image predictions and overall
 786    statistics.
 787
 788    Parameters
 789    ----------
 790    model_choice : str
 791        Name of the selected model (from get_available_models()).
 792    files : list of gr.File
 793        List of uploaded file objects from Gradio file input.
 794    carbon_display_text : str
 795        Current HTML string for the carbon display (to be updated).
 796    track_carbon : bool, optional
 797        Whether to track carbon emissions for this batch, by default True.
 798
 799    Returns
 800    -------
 801    tuple of (pd.DataFrame or None, str, str)
 802        - DataFrame: Table with filename, predicted class, and confidence
 803          for each image
 804        - str: Markdown-formatted summary statistics and carbon footprint
 805        - str: Updated HTML for carbon display
 806
 807    Notes
 808    -----
 809    The results table includes:
 810    - Filename: Original image filename
 811    - Predicted Class: Predicted garbage category
 812    - Confidence (%): Prediction confidence as percentage
 813
 814    Summary statistics include:
 815    - Total images processed
 816    - Number of successful predictions
 817    - Total carbon emissions (if tracked)
 818    - Average emissions per image
 819    - Car distance equivalent
 820
 821    Failed predictions are marked as "Error" in the table with the
 822    error message in the confidence column.
 823
 824    See Also
 825    --------
 826    predict_single_image_gradio : Single image prediction
 827    source.predict.predict_batch : Core batch prediction function
 828    """
 829    if not files:
 830        return None, "Please upload images", carbon_display_text
 831
 832    if not model_choice:
 833        return None, "Please select a model first", carbon_display_text
 834
 835    try:
 836        models_dict = get_available_models()
 837        model_path = models_dict.get(model_choice)
 838        model, device, transform = load_model_for_inference(
 839            model_path=model_path
 840        )
 841
 842        image_paths = [file.name for file in files]
 843
 844        batch_result = predict_batch(
 845            image_paths,
 846            model=model,
 847            transform=transform,
 848            device=device,
 849            track_carbon=track_carbon,
 850        )
 851
 852        df_data = []
 853        for result in batch_result["results"]:
 854            if result["status"] == "success":
 855                df_data.append(
 856                    {
 857                        "Filename": result["filename"],
 858                        "Predicted Class": result["predicted_class"],
 859                        "Confidence (%)": f"{result['confidence']*100:.2f}",
 860                    }
 861                )
 862            else:
 863                df_data.append(
 864                    {
 865                        "Filename": result["filename"],
 866                        "Predicted Class": "Error",
 867                        "Confidence (%)": result["error"],
 868                    }
 869                )
 870
 871        df_results = pd.DataFrame(df_data)
 872
 873        summary = batch_result["summary"]
 874        result_text = "### 📊 Batch Prediction Results\n\n"
 875        result_text += (f"**Total images processed:** "
 876                        f"{summary['total_images']}\n")
 877        result_text += (f"**Successful predictions:** "
 878                        f"{summary['successful']}\n\n")
 879
 880        updated_carbon_display = carbon_display_text
 881        if batch_result["emissions"]:
 882            emissions = batch_result["emissions"]
 883            result_text += "### 🌍 Carbon Footprint\n"
 884            result_text += (f"- **Emissions:** "
 885                            f"{emissions['emissions_g']:.4f}g CO₂eq\n")
 886            result_text += (f"- **🚗 Car equivalent:** "
 887                            f"{emissions['car_distance_formatted']} driven\n")
 888            result_text += (f"- **Avg per image:** "
 889                            f"{emissions['emissions_per_image_g']:.4f}g CO₂eq\n")
 890            updated_carbon_display = format_total_emissions_display(
 891                get_emissions_path()
 892            )
 893
 894        return df_results, result_text, updated_carbon_display
 895
 896    except Exception as e:
 897        return None, f"❌ Error: {str(e)}", carbon_display_text
 898
 899
 900# ========================
 901# UI LAYOUT
 902# ========================
 903
 904
 905def model_evaluation_tab(carbon_display):
 906    """
 907    Create the Model Evaluation and Inference UI section.
 908
 909    Builds a comprehensive Gradio interface for model evaluation, metrics
 910    visualization, and real-time inference. Organized into three main areas:
 911    model selection, metrics visualization, and image prediction.
 912
 913    Parameters
 914    ----------
 915    carbon_display : gr.HTML
 916        The carbon counter display component to update after inference
 917        operations. Shows cumulative emissions across all tracked operations.
 918
 919    Returns
 920    -------
 921    list
 922        Empty list (kept for API consistency).
 923
 924    Notes
 925    -----
 926    The interface is organized into sections:
 927
 928    1. **Model Selection:**
 929       - Radio buttons to choose between best model and latest trained model
 930
 931    2. **Metrics Visualization Tabs:**
 932       - Confusion Matrix: Raw and normalized versions with instant toggle
 933       - Loss Curves: Training and validation loss over epochs
 934       - Accuracy Curves: Training and validation accuracy over epochs
 935       - Calibration Curves: Per-class calibration analysis
 936
 937    3. **Inference Tabs:**
 938       - Single Image: Upload and classify individual images
 939       - Batch Prediction: Process multiple images simultaneously
 940
 941    All expensive computations (confusion matrices, calibration curves) are
 942    cached and automatically invalidated when models are retrained.
 943
 944    Carbon emissions can be optionally tracked for all inference operations
 945    and are displayed in both absolute terms and car distance equivalents.
 946
 947    Examples
 948    --------
 949    >>> with gr.Blocks() as demo:
 950    ...     carbon_display = gr.HTML()
 951    ...     model_evaluation_tab(carbon_display)
 952    """
 953    with gr.Column():
 954        gr.Markdown("### 🔬 Model Evaluation & Inference")
 955        gr.Markdown(
 956            "Evaluate trained models, visualize metrics, "
 957            "and make predictions on new images."
 958        )
 959
 960        gr.Markdown("#### 🧠 Model Selection")
 961        model_choice = gr.Radio(
 962            choices=list(get_available_models().keys()),
 963            value=list(get_available_models().keys())[0],
 964            label="Select Model",
 965            info="Choose between the best provided model or "
 966                 "your latest trained model",
 967        )
 968
 969        gr.Markdown("---")
 970
 971        gr.Markdown("#### 📈 Model Metrics & Visualizations")
 972
 973        with gr.Tabs():
 974            with gr.Tab("Confusion Matrix"):
 975                show_normalized = gr.Checkbox(
 976                    label="Show Normalized",
 977                    value=False,
 978                    info="Toggle between raw counts and "
 979                         "normalized percentages",
 980                    visible=False,
 981                )
 982                cm_button = gr.Button(
 983                    "Generate Confusion Matrix", variant="primary"
 984                )
 985                cm_plot = gr.Plot(label="Confusion Matrix")
 986                cm_status = gr.Markdown("")
 987
 988                cm_button.click(
 989                    fn=generate_confusion_matrix,
 990                    inputs=[model_choice, show_normalized],
 991                    outputs=[cm_plot, cm_status, show_normalized],
 992                )
 993
 994                show_normalized.change(
 995                    fn=toggle_confusion_matrix,
 996                    inputs=[show_normalized],
 997                    outputs=[cm_plot],
 998                )
 999
1000            with gr.Tab("Loss Curves"):
1001                loss_button = gr.Button("Load Loss Curves", variant="primary")
1002                loss_plot = gr.Plot(label="Loss Curves")
1003                loss_status = gr.Markdown("")
1004
1005                loss_button.click(
1006                    fn=load_loss_curves,
1007                    inputs=[model_choice],
1008                    outputs=[loss_plot, loss_status],
1009                )
1010
1011            with gr.Tab("Accuracy Curves"):
1012                acc_button = gr.Button(
1013                    "Load Accuracy Curves", variant="primary"
1014                )
1015                acc_plot = gr.Plot(label="Accuracy Curves")
1016                acc_status = gr.Markdown("")
1017
1018                acc_button.click(
1019                    fn=load_accuracy_curves,
1020                    inputs=[model_choice],
1021                    outputs=[acc_plot, acc_status],
1022                )
1023
1024            with gr.Tab("Calibration Curves"):
1025                calib_button = gr.Button(
1026                    "Generate Calibration Curves", variant="primary"
1027                )
1028                calib_plot = gr.Plot(label="Calibration Curves")
1029                calib_status = gr.Markdown("")
1030
1031                calib_button.click(
1032                    fn=generate_calibration_curves,
1033                    inputs=[model_choice],
1034                    outputs=[calib_plot, calib_status],
1035                )
1036
1037        gr.Markdown("---")
1038
1039        gr.Markdown("#### 🔍 Image Prediction")
1040
1041        with gr.Tabs():
1042            with gr.Tab("Single Image"):
1043                gr.Markdown(
1044                    "Upload an image to classify it into one of "
1045                    "the garbage categories."
1046                )
1047
1048                with gr.Row():
1049                    with gr.Column(scale=1):
1050                        single_image_input = gr.Image(
1051                            label="Upload Image", type="numpy", height=400
1052                        )
1053                        single_track_carbon = gr.Checkbox(
1054                            label="🌍 Track Carbon Emissions", value=True
1055                        )
1056                        single_predict_button = gr.Button(
1057                            "🔍 Predict", variant="primary", size="lg"
1058                        )
1059
1060                    with gr.Column(scale=1):
1061                        single_probs_plot = gr.Plot(
1062                            label="Class Probabilities"
1063                        )
1064
1065                single_result_text = gr.Markdown(
1066                    "Upload an image and click 'Predict'"
1067                )
1068
1069                single_predict_button.click(
1070                    fn=predict_single_image_gradio,
1071                    inputs=[
1072                        model_choice,
1073                        single_image_input,
1074                        carbon_display,
1075                        single_track_carbon,
1076                    ],
1077                    outputs=[
1078                        single_probs_plot,
1079                        single_result_text,
1080                        carbon_display
1081                    ],
1082                )
1083
1084            with gr.Tab("Batch Prediction"):
1085                gr.Markdown("Upload multiple images "
1086                            "to classify them all at once.")
1087
1088                batch_image_input = gr.File(
1089                    label="Upload Images",
1090                    file_count="multiple", file_types=["image"]
1091                )
1092                batch_track_carbon = gr.Checkbox(
1093                    label="🌍 Track Carbon Emissions", value=True
1094                )
1095                batch_predict_button = gr.Button(
1096                    "🔍 Predict All", variant="primary", size="lg"
1097                )
1098
1099                batch_result_text = gr.Markdown(
1100                    "Upload images and click 'Predict All'"
1101                )
1102                batch_results_table = gr.Dataframe(
1103                    label="Prediction Results",
1104                    interactive=False, wrap=True
1105                )
1106
1107                batch_predict_button.click(
1108                    fn=predict_folder_gradio,
1109                    inputs=[
1110                        model_choice,
1111                        batch_image_input,
1112                        carbon_display,
1113                        batch_track_carbon,
1114                    ],
1115                    outputs=[
1116                        batch_results_table,
1117                        batch_result_text,
1118                        carbon_display
1119                    ],
1120                )
1121
1122        gr.Markdown("---")
1123        gr.Markdown(
1124            "**ℹ️ Info:** Carbon emissions are tracked for "
1125            "inference operations and added to the total "
1126            "carbon footprint."
1127        )
1128
1129    return []
def get_emissions_path():

Get the path to the emissions CSV file.

Returns
  • pathlib.Path: Path object pointing to the emissions.csv file in the model directory.
Notes

The emissions file is stored in the same directory as the trained model checkpoint, as defined in the global configuration.
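
A usage sketch, assuming cfg.MODEL_PATH is models/resnet18_garbage.ckpt as in the get_available_models example below:

>>> get_emissions_path()
PosixPath('models/emissions.csv')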

def get_available_models():

Get dictionary of available trained models.

Returns
  • dict of {str: str}: Dictionary mapping model display names to their checkpoint file paths. Includes both the best provided model and the latest trained model.
Examples
>>> models = get_available_models()
>>> models
{
    'Best Model (Provided)': 'models/best/model_resnet18_garbage.ckpt',
    'Latest Trained Model': 'models/resnet18_garbage.ckpt'
}
CACHE_DIR = PosixPath('app/sections/cached_data')
def is_cache_valid(cache_file, model_path):

Check if cached file is newer than the model checkpoint.

Validates cache by comparing modification timestamps. Cache is considered valid only if it was created/modified after the model checkpoint.

Parameters
  • cache_file (pathlib.Path): Path to the cached file to validate.
  • model_path (str or pathlib.Path): Path to the model checkpoint file.
Returns
  • bool: True if cache exists and is newer than the model, False otherwise.
Notes

This function implements automatic cache invalidation to ensure that evaluation metrics are regenerated when a model is retrained.

Examples
>>> cache_file = Path("cached_data/cm_raw_Latest.pkl")
>>> model_path = "models/resnet18_garbage.ckpt"
>>> is_cache_valid(cache_file, model_path)
False  # Cache doesn't exist or is older than model
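
Conversely, retraining (or simply touching) the checkpoint refreshes its modification time and invalidates any existing cache entry:

>>> Path("models/resnet18_garbage.ckpt").touch()
>>> is_cache_valid(cache_file, "models/resnet18_garbage.ckpt")
False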
confusion_matrices_state = {'raw': None, 'normalized': None, 'model_choice': None}
def generate_confusion_matrix(model_choice, show_normalized, progress=gr.Progress()):

Generate and cache both raw and normalized confusion matrices.

This function generates confusion matrices for the validation set, caching both raw and normalized versions to enable instant toggling without recomputation. Implements three-tier caching: memory, disk, and automatic invalidation.

Parameters
  • model_choice (str): Name of the selected model (from get_available_models()).
  • show_normalized (bool): Whether to display the normalized version initially.
  • progress (gr.Progress, optional): Gradio progress tracker for UI updates.
Returns
  • tuple of (matplotlib.figure.Figure or None, str, gr.update):
    • Figure: The confusion matrix plot (raw or normalized)
    • str: Status message describing the operation result
    • gr.update: Gradio update object to control checkbox visibility
Notes

Caching strategy:

  1. Check memory cache (fastest)
  2. Check disk cache with validation (fast)
  3. Regenerate if cache invalid or missing (slow)

Both raw and normalized matrices are always generated together and stored separately to enable instant toggling via the checkbox.

The cache is automatically invalidated when the model file is modified, ensuring metrics reflect the current model state.

See Also

toggle_confusion_matrix: Switch between raw/normalized without regenerating.
is_cache_valid: Cache validation logic
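
A usage sketch, mirroring how model_evaluation_tab wires this function to the UI in the source above:

    cm_button.click(
        fn=generate_confusion_matrix,
        inputs=[model_choice, show_normalized],
        outputs=[cm_plot, cm_status, show_normalized],
    )

The third output feeds the returned gr.update back into the show_normalized checkbox, which is why the toggle only becomes visible once matrices exist.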

def toggle_confusion_matrix(show_normalized):

371def generate_calibration_curves(model_choice, progress=gr.Progress()):
372    """
373    Generate and cache calibration curves for the selected model.
374
375    Calibration curves show how well predicted probabilities match actual
376    outcomes. This function generates curves for all classes and caches
377    the result for faster subsequent access.
378
379    Parameters
380    ----------
381    model_choice : str
382        Name of the selected model (from get_available_models()).
383    progress : gr.Progress, optional
384        Gradio progress tracker for UI updates.
385
386    Returns
387    -------
388    tuple of (matplotlib.figure.Figure or None, str)
389        - Figure: The calibration curves plot
390        - str: Status message describing the operation result
391
392    Notes
393    -----
394    The calibration plot is generated using GarbageModelAnalyzer's
395    plot_calibration_curves method and cached to disk. Cache is
396    automatically invalidated when the model is retrained.
397
398    Well-calibrated models should have curves close to the diagonal,
399    indicating that predicted probabilities match actual frequencies.
400
401    See Also
402    --------
403    is_cache_valid : Cache validation logic
404    GarbageModelAnalyzer.plot_calibration_curves : Core plotting function
405    """
406    if not model_choice:
407        return None, "Please select a model first"
408
409    # Get model path
410    models_dict = get_available_models()
411    model_path = models_dict.get(model_choice)
412
413    # Check disk cache
414    cache_file = CACHE_DIR / f"calib_{model_choice.replace(' ', '_')}.pkl"
415
416    if is_cache_valid(cache_file, model_path):
417        try:
418            progress(0.2, desc="Loading from cache...")
419            with open(cache_file, "rb") as f:
420                fig = pickle.load(f)
421            progress(1.0, desc="Done!")
422            return fig, "✅ Calibration curves loaded from cache"
423        except Exception as e:
424            print(f"[WARN] Failed to load cache: {e}")
425
426    try:
427        progress(0.1, desc="Loading model...")
428        analyzer = GarbageModelAnalyzer()
429        analyzer.load_model(checkpoint_path=model_path)
430
431        progress(0.3, desc="Setting up data...")
432        analyzer.setup_data(batch_size=32)
433
434        progress(0.5, desc="Evaluating model...")
435        val_loader = analyzer.data_module.val_dataloader()
436        preds, labels, probs = analyzer.evaluate_loader(val_loader)
437
438        progress(0.8, desc="Generating calibration curves...")
439
440        analyzer.plot_calibration_curves(labels, probs)
441        fig = plt.gcf()
442
443        # Save to cache
444        progress(0.9, desc="Saving to cache...")
445        with open(cache_file, "wb") as f:
446            pickle.dump(fig, f)
447
448        progress(1.0, desc="Done!")
449        return fig, "✅ Calibration curves generated and cached"
450
451    except Exception as e:
452        return None, f"❌ Error: {str(e)}"
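
GarbageModelAnalyzer.plot_calibration_curves itself lives outside this module; the underlying idea can be sketched with scikit-learn's calibration_curve in one-vs-rest fashion. Here labels is an (N,) array of class indices and probs an (N, C) array of softmax outputs, as returned by evaluate_loader above; the helper name is invented:

    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn.calibration import calibration_curve

    def sketch_calibration_curves(labels, probs, class_names, n_bins=10):
        labels = np.asarray(labels)
        probs = np.asarray(probs)
        fig, ax = plt.subplots(figsize=(8, 6))
        for idx, name in enumerate(class_names):
            # One-vs-rest: class `idx` is positive, everything else is
            # negative; bin the predicted probabilities for that class.
            frac_pos, mean_pred = calibration_curve(
                (labels == idx).astype(int), probs[:, idx], n_bins=n_bins
            )
            ax.plot(mean_pred, frac_pos, marker="o", label=name)
        ax.plot([0, 1], [0, 1], "k--", label="Perfect calibration")
        ax.set_xlabel("Mean predicted probability")
        ax.set_ylabel("Observed frequency")
        ax.legend(fontsize=8)
        return fig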

455def get_metrics_path_for_model(model_choice):
456    """
457    Get the correct metrics.json path for the selected model.
458
459    Different models store their training metrics in different locations.
460    This function maps model choices to their corresponding metrics files.
461
462    Parameters
463    ----------
464    model_choice : str
465        Name of the selected model (from get_available_models()).
466
467    Returns
468    -------
469    pathlib.Path
470        Path to the metrics.json file for the selected model.
471
472    Notes
473    -----
474    The best provided model has metrics in a fixed location, while the
475    latest trained model uses the configured loss curves path.
476    """
477    if model_choice == "Best Model (Provided)":
478        return Path("models/best/performance/loss_curves/metrics.json")
479    else:
480        return Path(cfg.LOSS_CURVES_PATH) / "metrics.json"

483def load_loss_curves(model_choice):
484    """
485    Load and plot training/validation loss curves from metrics.json.
486
487    Parameters
488    ----------
489    model_choice : str
490        Name of the selected model (from get_available_models()).
491
492    Returns
493    -------
494    tuple of (matplotlib.figure.Figure or None, str)
495        - Figure: Loss curves plot showing train and validation loss
496        - str: Status message describing the operation result
497
498    Notes
499    -----
500    Loss curves show how training and validation loss evolved during
501    training. Diverging curves may indicate overfitting, while parallel
502    curves suggest good generalization.
503
504    The plot includes:
505    - Blue line with circles: Training loss
506    - Red line with squares: Validation loss
507    - Grid for easier reading
508
509    See Also
510    --------
511    load_accuracy_curves : Load accuracy metrics instead of loss
512    get_metrics_path_for_model : Determine metrics file location
513    """
514    try:
515        metrics_path = get_metrics_path_for_model(model_choice)
516
517        if not metrics_path.exists():
518            return (
519                None,
520                f"❌ No training metrics found for "
521                f"{model_choice}. Path: {metrics_path}",
522            )
523
524        with open(metrics_path, "r") as f:
525            data = json.load(f)
526
527        train_losses = data.get("train_losses", [])
528        val_losses = data.get("val_losses", [])
529
530        if not train_losses and not val_losses:
531            return None, "❌ No loss data available"
532
533        fig, ax = plt.subplots(figsize=(10, 6))
534        epochs = range(1, len(train_losses) + 1)
535
536        if train_losses:
537            ax.plot(
538                epochs,
539                train_losses,
540                "b-o",
541                label="Train Loss",
542                linewidth=2
543            )
544        if val_losses:
545            ax.plot(
546                range(1, len(val_losses) + 1),  # match val series length
547                val_losses,
548                "r-s",
549                label="Validation Loss",
550                linewidth=2
551            )
552
553        ax.set_xlabel("Epoch", fontsize=12)
554        ax.set_ylabel("Loss", fontsize=12)
555        ax.set_title(f"Loss Curves - {model_choice}",
556                     fontsize=14, fontweight="bold")
557        ax.legend(fontsize=10)
558        ax.grid(True, alpha=0.3)
559        plt.tight_layout()
560
561        return fig, f"✅ Loss curves loaded successfully from {model_choice}"
562
563    except Exception as e:
564        return None, f"❌ Error loading loss curves: {str(e)}"
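
This loader and load_accuracy_curves below read the same metrics.json; the keys they access imply a schema along these lines, with one entry per epoch (values invented for illustration):

    {
        "train_losses": [1.42, 0.97, 0.74],
        "val_losses": [1.31, 1.02, 0.88],
        "train_accs": [0.55, 0.71, 0.80],
        "val_accs": [0.58, 0.66, 0.72]
    }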

567def load_accuracy_curves(model_choice):
568    """
569    Load and plot training/validation accuracy curves from metrics.json.
570
571    Parameters
572    ----------
573    model_choice : str
574        Name of the selected model (from get_available_models()).
575
576    Returns
577    -------
578    tuple of (matplotlib.figure.Figure or None, str)
579        - Figure: Accuracy curves plot showing train and validation accuracy
580        - str: Status message describing the operation result
581
582    Notes
583    -----
584    Accuracy curves show how classification accuracy evolved during training.
585    The plot includes:
586    - Blue line with circles: Training accuracy
587    - Red line with squares: Validation accuracy
588    - Y-axis limited to [0, 1] for consistency
589    - Grid for easier reading
590
591    A large gap between train and validation accuracy indicates overfitting.
592
593    See Also
594    --------
595    load_loss_curves : Load loss metrics instead of accuracy
596    get_metrics_path_for_model : Determine metrics file location
597    """
598    try:
599        metrics_path = get_metrics_path_for_model(model_choice)
600
601        if not metrics_path.exists():
602            return (
603                None,
604                f"❌ No training metrics found for "
605                f"{model_choice}. Path: {metrics_path}",
606            )
607
608        with open(metrics_path, "r") as f:
609            data = json.load(f)
610
611        train_accs = data.get("train_accs", [])
612        val_accs = data.get("val_accs", [])
613
614        if not train_accs and not val_accs:
615            return None, "❌ No accuracy data available"
616
617        fig, ax = plt.subplots(figsize=(10, 6))
618
619        if train_accs:
620            ax.plot(
621                range(1, len(train_accs) + 1),
622                train_accs,
623                "b-o",
624                label="Train Accuracy",
625                linewidth=2,
626            )
627        if val_accs:
628            ax.plot(
629                range(1, len(val_accs) + 1),
630                val_accs,
631                "r-s",
632                label="Validation Accuracy",
633                linewidth=2,
634            )
635
636        ax.set_xlabel("Epoch", fontsize=12)
637        ax.set_ylabel("Accuracy", fontsize=12)
638        ax.set_title(
639            f"Accuracy Curves - {model_choice}",
640            fontsize=14, fontweight="bold"
641        )
642        ax.legend(fontsize=10)
643        ax.grid(True, alpha=0.3)
644        ax.set_ylim([0, 1])
645        plt.tight_layout()
646
647        return fig, (f"✅ Accuracy curves "
648                     f"loaded successfully from {model_choice}")
649
650    except Exception as e:
651        return None, (f"❌ Error loading "
652                      f"accuracy curves: {str(e)}")

660def predict_single_image_gradio(
661    model_choice, image, carbon_display_text, track_carbon=True
662):
663    """
664    Gradio wrapper for single image prediction with visualization.
665
666    Performs inference on a single image and generates a horizontal bar
667    chart of class probabilities, highlighting the predicted class.
668
669    Parameters
670    ----------
671    model_choice : str
672        Name of the selected model (from get_available_models()).
673    image : np.ndarray
674        Input image as numpy array (from gr.Image component).
675    carbon_display_text : str
676        Current HTML string for the carbon display (to be updated).
677    track_carbon : bool, optional
678        Whether to track carbon emissions for this prediction,
679        by default True.
680
681    Returns
682    -------
683    tuple of (matplotlib.figure.Figure or None, str, str)
684        - Figure: Bar chart of class probabilities
685        - str: Markdown-formatted prediction results and statistics
686        - str: Updated HTML for carbon display
687
688    Notes
689    -----
690    The probability bar chart uses:
691    - Green bar for the predicted class
692    - Sky blue bars for other classes
693    - Percentage labels on each bar
694
695    If carbon tracking is enabled, emissions are added to the cumulative
696    total and the carbon display is updated.
697
698    See Also
699    --------
700    predict_batch : Batch prediction on multiple images
701    source.predict.predict_image : Core prediction function
702    """
703    if image is None:
704        return None, "Please upload an image", carbon_display_text
705
706    if not model_choice:
707        return None, "Please select a model first", carbon_display_text
708
709    try:
710        models_dict = get_available_models()
711        model_path = models_dict.get(model_choice)
712        model, device, transform = load_model_for_inference(
713            model_path=model_path
714        )
715
716        result = predict_image(
717            image,
718            model=model,
719            transform=transform,
720            device=device,
721            track_carbon=track_carbon,
722        )
723
724        fig, ax = plt.subplots(figsize=(10, 6))
725        probs_list = [result["probabilities"][cls] for cls in cfg.CLASS_NAMES]
726        pred_idx = result["predicted_idx"]
727        colors = [
728            "green" if i == pred_idx else "skyblue"
729            for i in range(len(cfg.CLASS_NAMES))
730        ]
731        bars = ax.barh(cfg.CLASS_NAMES, probs_list, color=colors)
732        ax.set_xlabel("Probability", fontsize=12)
733        ax.set_title(
734            "Prediction Probabilities\n"
735            f'Predicted Class: {result["predicted_class"]}',
736            fontsize=14,
737            fontweight="bold",
738        )
739        ax.set_xlim([0, 1])
740
741        for bar, prob in zip(bars, probs_list):
742            ax.text(
743                prob + 0.02,
744                bar.get_y() + bar.get_height() / 2,
745                f"{prob*100:.1f}%",
746                va="center",
747                fontsize=10,
748            )
749
750        plt.tight_layout()
751
752        result_text = (f"### 🎯 Prediction: "
753                       f"**{result['predicted_class']}**\n\n")
754        result_text += f"**Confidence:** {result['confidence']*100:.2f}%\n\n"
755        result_text += "**All Probabilities:**\n"
756        for class_name, prob in result["probabilities"].items():
757            emoji = "🏆" if class_name == result["predicted_class"] \
758                else "  "
759            result_text += f"{emoji} {class_name}: {prob*100:.2f}%\n"
760
761        updated_carbon_display = carbon_display_text
762        if result["emissions"]:
763            emissions = result["emissions"]
764            result_text += "\n\n### 🌍 Carbon Footprint\n"
765            result_text += (f"- **Emissions:** "
766                            f"{emissions['emissions_g']:.4f}g CO₂eq\n")
767            result_text += (f"- **🚗 Car equivalent:** "
768                            f"{emissions['car_distance_formatted']} driven\n")
769            updated_carbon_display = format_total_emissions_display(
770                get_emissions_path()
771            )
772
773        return fig, result_text, updated_carbon_display
774
775    except Exception as e:
776        return None, f"❌ Error: {str(e)}", carbon_display_text
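
The same inference helpers also work outside Gradio. A minimal standalone sketch, assuming a local image file (sample.jpg is hypothetical; the checkpoint path is the 'Latest Trained Model' entry from get_available_models):

    import numpy as np
    from PIL import Image
    from source.predict import load_model_for_inference, predict_image

    # Load the model once and reuse it across predictions.
    model, device, transform = load_model_for_inference(
        model_path="models/resnet18_garbage.ckpt"
    )
    image = np.array(Image.open("sample.jpg").convert("RGB"))
    result = predict_image(
        image, model=model, transform=transform,
        device=device, track_carbon=False,
    )
    print(result["predicted_class"], f"{result['confidence'] * 100:.1f}%")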

779def predict_folder_gradio(
780    model_choice, files, carbon_display_text, track_carbon=True
781):
782    """
783    Gradio wrapper for batch prediction on multiple images.
784
785    Performs inference on multiple images simultaneously and returns
786    results in a tabular format with per-image predictions and overall
787    statistics.
788
789    Parameters
790    ----------
791    model_choice : str
792        Name of the selected model (from get_available_models()).
793    files : list of gr.File
794        List of uploaded file objects from Gradio file input.
795    carbon_display_text : str
796        Current HTML string for the carbon display (to be updated).
797    track_carbon : bool, optional
798        Whether to track carbon emissions for this batch, by default True.
799
800    Returns
801    -------
802    tuple of (pd.DataFrame or None, str, str)
803        - DataFrame: Table with filename, predicted class, and confidence
804          for each image
805        - str: Markdown-formatted summary statistics and carbon footprint
806        - str: Updated HTML for carbon display
807
808    Notes
809    -----
810    The results table includes:
811    - Filename: Original image filename
812    - Predicted Class: Predicted garbage category
813    - Confidence (%): Prediction confidence as percentage
814
815    Summary statistics include:
816    - Total images processed
817    - Number of successful predictions
818    - Total carbon emissions (if tracked)
819    - Average emissions per image
820    - Car distance equivalent
821
822    Failed predictions are marked as "Error" in the table with the
823    error message in the confidence column.
824
825    See Also
826    --------
827    predict_single_image_gradio : Single image prediction
828    source.predict.predict_batch : Core batch prediction function
829    """
830    if not files:
831        return None, "Please upload images", carbon_display_text
832
833    if not model_choice:
834        return None, "Please select a model first", carbon_display_text
835
836    try:
837        models_dict = get_available_models()
838        model_path = models_dict.get(model_choice)
839        model, device, transform = load_model_for_inference(
840            model_path=model_path
841        )
842
843        image_paths = [file.name for file in files]
844
845        batch_result = predict_batch(
846            image_paths,
847            model=model,
848            transform=transform,
849            device=device,
850            track_carbon=track_carbon,
851        )
852
853        df_data = []
854        for result in batch_result["results"]:
855            if result["status"] == "success":
856                df_data.append(
857                    {
858                        "Filename": result["filename"],
859                        "Predicted Class": result["predicted_class"],
860                        "Confidence (%)": f"{result['confidence']*100:.2f}",
861                    }
862                )
863            else:
864                df_data.append(
865                    {
866                        "Filename": result["filename"],
867                        "Predicted Class": "Error",
868                        "Confidence (%)": result["error"],
869                    }
870                )
871
872        df_results = pd.DataFrame(df_data)
873
874        summary = batch_result["summary"]
875        result_text = "### 📊 Batch Prediction Results\n\n"
876        result_text += (f"**Total images processed:** "
877                        f"{summary['total_images']}\n")
878        result_text += (f"**Successful predictions:** "
879                        f"{summary['successful']}\n\n")
880
881        updated_carbon_display = carbon_display_text
882        if batch_result["emissions"]:
883            emissions = batch_result["emissions"]
884            result_text += "### 🌍 Carbon Footprint\n"
885            result_text += (f"- **Emissions:** "
886                            f"{emissions['emissions_g']:.4f}g CO₂eq\n")
887            result_text += (f"- **🚗 Car equivalent:** "
888                            f"{emissions['car_distance_formatted']} driven\n")
889            result_text += (f"- **Avg per image:** "
890                            f"{emissions['emissions_per_image_g']:.4f}g CO₂eq\n")
891            updated_carbon_display = format_total_emissions_display(
892                get_emissions_path()
893            )
894
895        return df_results, result_text, updated_carbon_display
896
897    except Exception as e:
898        return None, f"❌ Error: {str(e)}", carbon_display_text
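
The wrapper above relies on only a handful of fields from predict_batch's return value; reading those accesses back implies a structure like this (filenames, class names, and numbers invented for illustration):

    batch_result = {
        "results": [
            {"filename": "cup.jpg", "status": "success",
             "predicted_class": "plastic", "confidence": 0.93},
            # Any status other than "success" is routed to the error row;
            # the exact failure status string is an assumption here.
            {"filename": "blurry.jpg", "status": "failed",
             "error": "cannot identify image file"},
        ],
        "summary": {"total_images": 2, "successful": 1},
        "emissions": {  # falsy when carbon tracking is disabled
            "emissions_g": 0.0042,
            "emissions_per_image_g": 0.0021,
            "car_distance_formatted": "2.4 m",
        },
    }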

 906def model_evaluation_tab(carbon_display):
 907    """
 908    Create the Model Evaluation and Inference UI section.
 909
 910    Builds a comprehensive Gradio interface for model evaluation, metrics
 911    visualization, and real-time inference. Organized into three main areas:
 912    model selection, metrics visualization, and image prediction.
 913
 914    Parameters
 915    ----------
 916    carbon_display : gr.HTML
 917        The carbon counter display component to update after inference
 918        operations. Shows cumulative emissions across all tracked operations.
 919
 920    Returns
 921    -------
 922    list
 923        Empty list (kept for API consistency).
 924
 925    Notes
 926    -----
 927    The interface is organized into sections:
 928
 929    1. **Model Selection:**
 930       - Radio buttons to choose between best model and latest trained model
 931
 932    2. **Metrics Visualization Tabs:**
 933       - Confusion Matrix: Raw and normalized versions with instant toggle
 934       - Loss Curves: Training and validation loss over epochs
 935       - Accuracy Curves: Training and validation accuracy over epochs
 936       - Calibration Curves: Per-class calibration analysis
 937
 938    3. **Inference Tabs:**
 939       - Single Image: Upload and classify individual images
 940       - Batch Prediction: Process multiple images simultaneously
 941
 942    All expensive computations (confusion matrices, calibration curves) are
 943    cached and automatically invalidated when models are retrained.
 944
 945    Carbon emissions can be optionally tracked for all inference operations
 946    and are displayed in both absolute terms and car distance equivalents.
 947
 948    Examples
 949    --------
 950    >>> with gr.Blocks() as demo:
 951    ...     carbon_display = gr.HTML()
 952    ...     model_evaluation_tab(carbon_display)
 953    """
 954    with gr.Column():
 955        gr.Markdown("### 🔬 Model Evaluation & Inference")
 956        gr.Markdown(
 957            "Evaluate trained models, visualize metrics, "
 958            "and make predictions on new images."
 959        )
 960
 961        gr.Markdown("#### 🧠 Model Selection")
 962        model_choice = gr.Radio(
 963            choices=list(get_available_models().keys()),
 964            value=list(get_available_models().keys())[0],
 965            label="Select Model",
 966            info="Choose between the best provided model or "
 967                 "your latest trained model",
 968        )
 969
 970        gr.Markdown("---")
 971
 972        gr.Markdown("#### 📈 Model Metrics & Visualizations")
 973
 974        with gr.Tabs():
 975            with gr.Tab("Confusion Matrix"):
 976                show_normalized = gr.Checkbox(
 977                    label="Show Normalized",
 978                    value=False,
 979                    info="Toggle between raw counts and normalized "
 980                         "percentages",
 981                    visible=False,
 982                )
 983                cm_button = gr.Button(
 984                    "Generate Confusion Matrix", variant="primary"
 985                )
 986                cm_plot = gr.Plot(label="Confusion Matrix")
 987                cm_status = gr.Markdown("")
 988
 989                cm_button.click(
 990                    fn=generate_confusion_matrix,
 991                    inputs=[model_choice, show_normalized],
 992                    outputs=[cm_plot, cm_status, show_normalized],
 993                )
 994
 995                show_normalized.change(
 996                    fn=toggle_confusion_matrix,
 997                    inputs=[show_normalized],
 998                    outputs=[cm_plot],
 999                )
1000
1001            with gr.Tab("Loss Curves"):
1002                loss_button = gr.Button("Load Loss Curves", variant="primary")
1003                loss_plot = gr.Plot(label="Loss Curves")
1004                loss_status = gr.Markdown("")
1005
1006                loss_button.click(
1007                    fn=load_loss_curves,
1008                    inputs=[model_choice],
1009                    outputs=[loss_plot, loss_status],
1010                )
1011
1012            with gr.Tab("Accuracy Curves"):
1013                acc_button = gr.Button(
1014                    "Load Accuracy Curves", variant="primary"
1015                )
1016                acc_plot = gr.Plot(label="Accuracy Curves")
1017                acc_status = gr.Markdown("")
1018
1019                acc_button.click(
1020                    fn=load_accuracy_curves,
1021                    inputs=[model_choice],
1022                    outputs=[acc_plot, acc_status],
1023                )
1024
1025            with gr.Tab("Calibration Curves"):
1026                calib_button = gr.Button(
1027                    "Generate Calibration Curves", variant="primary"
1028                )
1029                calib_plot = gr.Plot(label="Calibration Curves")
1030                calib_status = gr.Markdown("")
1031
1032                calib_button.click(
1033                    fn=generate_calibration_curves,
1034                    inputs=[model_choice],
1035                    outputs=[calib_plot, calib_status],
1036                )
1037
1038        gr.Markdown("---")
1039
1040        gr.Markdown("#### 🔍 Image Prediction")
1041
1042        with gr.Tabs():
1043            with gr.Tab("Single Image"):
1044                gr.Markdown(
1045                    "Upload an image to classify it into one of "
1046                    "the garbage categories."
1047                )
1048
1049                with gr.Row():
1050                    with gr.Column(scale=1):
1051                        single_image_input = gr.Image(
1052                            label="Upload Image", type="numpy", height=400
1053                        )
1054                        single_track_carbon = gr.Checkbox(
1055                            label="🌍 Track Carbon Emissions", value=True
1056                        )
1057                        single_predict_button = gr.Button(
1058                            "🔍 Predict", variant="primary", size="lg"
1059                        )
1060
1061                    with gr.Column(scale=1):
1062                        single_probs_plot = gr.Plot(
1063                            label="Class Probabilities"
1064                        )
1065
1066                single_result_text = gr.Markdown(
1067                    "Upload an image and click 'Predict'"
1068                )
1069
1070                single_predict_button.click(
1071                    fn=predict_single_image_gradio,
1072                    inputs=[
1073                        model_choice,
1074                        single_image_input,
1075                        carbon_display,
1076                        single_track_carbon,
1077                    ],
1078                    outputs=[
1079                        single_probs_plot,
1080                        single_result_text,
1081                        carbon_display
1082                    ],
1083                )
1084
1085            with gr.Tab("Batch Prediction"):
1086                gr.Markdown("Upload multiple images "
1087                            "to classify them all at once.")
1088
1089                batch_image_input = gr.File(
1090                    label="Upload Images",
1091                    file_count="multiple", file_types=["image"]
1092                )
1093                batch_track_carbon = gr.Checkbox(
1094                    label="🌍 Track Carbon Emissions", value=True
1095                )
1096                batch_predict_button = gr.Button(
1097                    "🔍 Predict All", variant="primary", size="lg"
1098                )
1099
1100                batch_result_text = gr.Markdown(
1101                    "Upload images and click 'Predict All'"
1102                )
1103                batch_results_table = gr.Dataframe(
1104                    label="Prediction Results",
1105                    interactive=False, wrap=True
1106                )
1107
1108                batch_predict_button.click(
1109                    fn=predict_folder_gradio,
1110                    inputs=[
1111                        model_choice,
1112                        batch_image_input,
1113                        carbon_display,
1114                        batch_track_carbon,
1115                    ],
1116                    outputs=[
1117                        batch_results_table,
1118                        batch_result_text,
1119                        carbon_display
1120                    ],
1121                )
1122
1123        gr.Markdown("---")
1124        gr.Markdown(
1125            "**ℹ️ Info:** Carbon emissions are tracked for "
1126            "inference operations and added to the total "
1127            "carbon footprint."
1128        )
1129
1130    return []
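
Extending the docstring example, a hypothetical top-level wiring that supplies the shared carbon counter this tab updates (assumes these names are importable from this module):

    import gradio as gr

    from source.utils.carbon_utils import format_total_emissions_display

    with gr.Blocks() as demo:
        # Shared counter; the inference handlers return updated HTML for it.
        carbon_display = gr.HTML(
            format_total_emissions_display(get_emissions_path())
        )
        with gr.Tab("Model Evaluation"):
            model_evaluation_tab(carbon_display)

    demo.launch()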
