app.sections.model_evaluation
Model Evaluation and Inference Interface for Gradio Application.
This module provides comprehensive model evaluation tools and inference capabilities through a Gradio interface. It includes visualization of model performance metrics, confusion matrices, calibration curves, and both single-image and batch prediction functionality.
Key features include:
- Multi-model support (best model and latest trained model)
- Cached computation of expensive metrics (confusion matrices, calibration)
- Automatic cache invalidation based on model modification time
- Real-time inference with carbon emissions tracking
- Training metrics visualization (loss/accuracy curves)
Notes
Cache files are stored in app/sections/cached_data/ and are automatically
invalidated when the corresponding model file is updated. This ensures
that evaluation metrics reflect the current model state while avoiding
redundant computation.
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Model Evaluation and Inference Interface for Gradio Application.

This module provides comprehensive model evaluation tools and inference
capabilities through a Gradio interface. It includes visualization of
model performance metrics, confusion matrices, calibration curves, and
both single-image and batch prediction functionality.

Key features include:
- Multi-model support (best model and latest trained model)
- Cached computation of expensive metrics (confusion matrices, calibration)
- Automatic cache invalidation based on model modification time
- Real-time inference with carbon emissions tracking
- Training metrics visualization (loss/accuracy curves)

Notes
-----
Cache files are stored in `app/sections/cached_data/` and are automatically
invalidated when the corresponding model file is updated. This ensures
that evaluation metrics reflect the current model state while avoiding
redundant computation.
"""
__docformat__ = "numpy"

import gradio as gr
from source.utils.carbon_utils import format_total_emissions_display
from source.utils import config as cfg
from source.utils.custom_classes.EvalAnalyzer import GarbageModelAnalyzer
from source.predict import (
    predict_image,
    predict_batch,
    load_model_for_inference
)
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import json
import pandas as pd
import pickle


def get_emissions_path():
    """
    Get the path to the emissions CSV file.

    Returns
    -------
    pathlib.Path
        Path object pointing to the emissions.csv file in the model directory.

    Notes
    -----
    The emissions file is stored in the same directory as the trained model
    checkpoint, as defined in the global configuration.
    """
    return Path(cfg.MODEL_PATH).parent / "emissions.csv"


def get_available_models():
    """
    Get dictionary of available trained models.

    Returns
    -------
    dict of {str: str}
        Dictionary mapping model display names to their checkpoint file paths.
        Includes both the best provided model and the latest trained model.

    Examples
    --------
    >>> models = get_available_models()
    >>> models
    {
        'Best Model (Provided)': 'models/best/model_resnet18_garbage.ckpt',
        'Latest Trained Model': 'models/resnet18_garbage.ckpt'
    }
    """
    models_dict = {
        "Best Model (Provided)": str(Path(
            "models/best/model_resnet18_garbage.ckpt"
        )),
        "Latest Trained Model": str(cfg.MODEL_PATH),
    }
    return models_dict


# Setup cache directory
CACHE_DIR = Path("app/sections/cached_data")
CACHE_DIR.mkdir(parents=True, exist_ok=True)


def is_cache_valid(cache_file, model_path):
    """
    Check if cached file is newer than the model checkpoint.

    Validates cache by comparing modification timestamps. Cache is considered
    valid only if it was created/modified after the model checkpoint.

    Parameters
    ----------
    cache_file : pathlib.Path
        Path to the cached file to validate.
    model_path : str or pathlib.Path
        Path to the model checkpoint file.

    Returns
    -------
    bool
        True if cache exists and is newer than the model, False otherwise.

    Notes
    -----
    This function implements automatic cache invalidation to ensure that
    evaluation metrics are regenerated when a model is retrained.

    Examples
    --------
    >>> cache_file = Path("cached_data/cm_raw_Latest.pkl")
    >>> model_path = "models/resnet18_garbage.ckpt"
    >>> is_cache_valid(cache_file, model_path)
    False  # Cache doesn't exist or is older than model
    """
    if not cache_file.exists():
        return False

    if not Path(model_path).exists():
        return False

    cache_time = cache_file.stat().st_mtime
    model_time = Path(model_path).stat().st_mtime

    print(cache_file, model_path)
    print(cache_time, model_time)

    return cache_time > model_time  # Cache is valid if newer than model


# State to hold both confusion matrices
confusion_matrices_state = {
    "raw": None,
    "normalized": None,
    "model_choice": None,  # Track which model generated these
}


# ========================
# METRICS FUNCTIONS
# ========================


def generate_confusion_matrix(
    model_choice,
    show_normalized,
    progress=gr.Progress()
):
    """
    Generate and cache both raw and normalized confusion matrices.

    This function generates confusion matrices for the validation set,
    caching both raw and normalized versions to enable instant toggling
    without recomputation. Implements three-tier caching: memory, disk,
    and automatic invalidation.

    Parameters
    ----------
    model_choice : str
        Name of the selected model (from get_available_models()).
    show_normalized : bool
        Whether to display the normalized version initially.
    progress : gr.Progress, optional
        Gradio progress tracker for UI updates.

    Returns
    -------
    tuple of (matplotlib.figure.Figure or None, str, gr.update)
        - Figure: The confusion matrix plot (raw or normalized)
        - str: Status message describing the operation result
        - gr.update: Gradio update object to control checkbox visibility

    Notes
    -----
    Caching strategy:
    1. Check memory cache (fastest)
    2. Check disk cache with validation (fast)
    3. Regenerate if cache invalid or missing (slow)

    Both raw and normalized matrices are always generated together and
    stored separately to enable instant toggling via the checkbox.

    The cache is automatically invalidated when the model file is modified,
    ensuring metrics reflect the current model state.

    See Also
    --------
    toggle_confusion_matrix : Switch between raw/normalized without
        regenerating.
    is_cache_valid : Cache validation logic
    """
    # global confusion_matrices_state  # TODO: Revert this (uncomment)

    if not model_choice:
        return None, "Please select a model first", gr.update(visible=False)

    # Get model path
    models_dict = get_available_models()
    model_path = models_dict.get(model_choice)

    # Check if we already have matrices for this model in memory
    if (
        confusion_matrices_state["model_choice"] == model_choice
        and confusion_matrices_state["raw"] is not None
        and confusion_matrices_state["normalized"] is not None
    ):

        selected_matrix = (
            confusion_matrices_state["normalized"]
            if show_normalized
            else confusion_matrices_state["raw"]
        )
        matrix_type = "Normalized" if show_normalized else "Raw"
        return (
            selected_matrix,
            f"✅ {matrix_type} confusion matrix (from memory)",
            gr.update(visible=True, interactive=True),
        )

    # Check disk cache
    cache_file_raw = \
        CACHE_DIR / f"cm_raw_{model_choice.replace(' ', '_')}.pkl"
    cache_file_norm = \
        CACHE_DIR / f"cm_norm_{model_choice.replace(' ', '_')}.pkl"

    if is_cache_valid(cache_file_raw, model_path) and is_cache_valid(
        cache_file_norm, model_path
    ):
        try:
            progress(0.2, desc="Loading from cache...")
            with open(cache_file_raw, "rb") as f:
                fig_raw = pickle.load(f)
            with open(cache_file_norm, "rb") as f:
                fig_norm = pickle.load(f)

            confusion_matrices_state["raw"] = fig_raw
            confusion_matrices_state["normalized"] = fig_norm
            confusion_matrices_state["model_choice"] = model_choice

            selected_matrix = fig_norm if show_normalized else fig_raw
            matrix_type = "Normalized" if show_normalized else "Raw"
            progress(1.0, desc="Done!")
            return (
                selected_matrix,
                f"✅ {matrix_type} confusion matrix loaded from cache",
                gr.update(visible=True, interactive=True),
            )
        except Exception as e:
            print(f"[WARN] Failed to load cache: {e}")

    # Generate new matrices
    try:
        progress(0.1, desc="Loading model...")
        analyzer = GarbageModelAnalyzer()
        analyzer.load_model(checkpoint_path=model_path)

        progress(0.3, desc="Setting up data...")
        analyzer.setup_data(batch_size=32)

        progress(0.5, desc="Evaluating model...")
        val_loader = analyzer.data_module.val_dataloader()
        preds, labels, probs = analyzer.evaluate_loader(val_loader)

        progress(0.7, desc="Generating confusion matrices...")

        from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

        num_classes = cfg.NUM_CLASSES
        cm_raw = confusion_matrix(
            labels.cpu().numpy(),
            preds.cpu().numpy(),
            labels=range(num_classes)
        )
        cm_norm = cm_raw.astype("float") / cm_raw.sum(axis=1)[:, np.newaxis]

        # Generate RAW matrix figure
        fig_raw, ax_raw = plt.subplots(figsize=(10, 8))
        disp_raw = ConfusionMatrixDisplay(
            confusion_matrix=cm_raw, display_labels=cfg.CLASS_NAMES
        )
        disp_raw.plot(cmap=plt.cm.Blues, ax=ax_raw)
        ax_raw.set_title("Confusion Matrix - Validation Set")
        plt.tight_layout()

        # Generate NORMALIZED matrix figure
        fig_norm, ax_norm = plt.subplots(figsize=(10, 8))
        disp_norm = ConfusionMatrixDisplay(
            confusion_matrix=cm_norm, display_labels=cfg.CLASS_NAMES
        )
        disp_norm.plot(cmap=plt.cm.Blues, ax=ax_norm)
        ax_norm.set_title("Normalized Confusion Matrix - Validation Set")
        plt.tight_layout()

        # Save to cache
        progress(0.9, desc="Saving to cache...")
        with open(cache_file_raw, "wb") as f:
            pickle.dump(fig_raw, f)
        with open(cache_file_norm, "wb") as f:
            pickle.dump(fig_norm, f)

        # Update state
        confusion_matrices_state["raw"] = fig_raw
        confusion_matrices_state["normalized"] = fig_norm
        confusion_matrices_state["model_choice"] = model_choice

        selected_matrix = fig_norm if show_normalized else fig_raw
        matrix_type = "Normalized" if show_normalized else "Raw"

        progress(1.0, desc="Done!")
        return (
            selected_matrix,
            f"✅ {matrix_type} confusion matrix generated and cached",
            gr.update(visible=True, interactive=True),
        )

    except Exception as e:
        return None, f"❌ Error: {str(e)}", gr.update(visible=False)


def toggle_confusion_matrix(show_normalized):
    """
    Toggle between raw and normalized confusion matrices instantly.

    Switches the displayed confusion matrix without regenerating, using
    cached versions from memory. This enables instant UI response to the
    normalization checkbox.

    Parameters
    ----------
    show_normalized : bool
        If True, return normalized matrix; if False, return raw matrix.

    Returns
    -------
    matplotlib.figure.Figure or None
        The requested confusion matrix figure, or None if not available
        in memory cache.

    Notes
    -----
    This function only accesses the in-memory cache. If matrices are not
    yet generated, it returns None. The generate_confusion_matrix function
    should be called first to populate the cache.

    See Also
    --------
    generate_confusion_matrix : Generate and cache both matrix versions
    """
    # global confusion_matrices_state  # TODO: Revert this (uncomment)

    if show_normalized:
        if confusion_matrices_state["normalized"] is not None:
            return confusion_matrices_state["normalized"]
    else:
        if confusion_matrices_state["raw"] is not None:
            return confusion_matrices_state["raw"]

    return None


def generate_calibration_curves(model_choice, progress=gr.Progress()):
    """
    Generate and cache calibration curves for the selected model.

    Calibration curves show how well predicted probabilities match actual
    outcomes. This function generates curves for all classes and caches
    the result for faster subsequent access.

    Parameters
    ----------
    model_choice : str
        Name of the selected model (from get_available_models()).
    progress : gr.Progress, optional
        Gradio progress tracker for UI updates.

    Returns
    -------
    tuple of (matplotlib.figure.Figure or None, str)
        - Figure: The calibration curves plot
        - str: Status message describing the operation result

    Notes
    -----
    The calibration plot is generated using GarbageModelAnalyzer's
    plot_calibration_curves method and cached to disk. Cache is
    automatically invalidated when the model is retrained.

    Well-calibrated models should have curves close to the diagonal,
    indicating that predicted probabilities match actual frequencies.

    See Also
    --------
    is_cache_valid : Cache validation logic
    GarbageModelAnalyzer.plot_calibration_curves : Core plotting function
    """
    if not model_choice:
        return None, "Please select a model first"

    # Get model path
    models_dict = get_available_models()
    model_path = models_dict.get(model_choice)

    # Check disk cache
    cache_file = CACHE_DIR / f"calib_{model_choice.replace(' ', '_')}.pkl"

    if is_cache_valid(cache_file, model_path):
        try:
            progress(0.2, desc="Loading from cache...")
            with open(cache_file, "rb") as f:
                fig = pickle.load(f)
            progress(1.0, desc="Done!")
            return fig, "✅ Calibration curves loaded from cache"
        except Exception as e:
            print(f"[WARN] Failed to load cache: {e}")

    try:
        progress(0.1, desc="Loading model...")
        analyzer = GarbageModelAnalyzer()
        analyzer.load_model(checkpoint_path=model_path)

        progress(0.3, desc="Setting up data...")
        analyzer.setup_data(batch_size=32)

        progress(0.5, desc="Evaluating model...")
        val_loader = analyzer.data_module.val_dataloader()
        preds, labels, probs = analyzer.evaluate_loader(val_loader)

        progress(0.8, desc="Generating calibration curves...")

        analyzer.plot_calibration_curves(labels, probs)
        fig = plt.gcf()

        # Save to cache
        progress(0.9, desc="Saving to cache...")
        with open(cache_file, "wb") as f:
            pickle.dump(fig, f)

        progress(1.0, desc="Done!")
        return fig, "✅ Calibration curves generated and cached"

    except Exception as e:
        return None, f"❌ Error: {str(e)}"


def get_metrics_path_for_model(model_choice):
    """
    Get the correct metrics.json path for the selected model.

    Different models store their training metrics in different locations.
    This function maps model choices to their corresponding metrics files.

    Parameters
    ----------
    model_choice : str
        Name of the selected model (from get_available_models()).

    Returns
    -------
    pathlib.Path
        Path to the metrics.json file for the selected model.

    Notes
    -----
    The best provided model has metrics in a fixed location, while the
    latest trained model uses the configured loss curves path.
    """
    if model_choice == "Best Model (Provided)":
        return Path("models/best/performance/loss_curves/metrics.json")
    else:
        return Path(cfg.LOSS_CURVES_PATH) / "metrics.json"


def load_loss_curves(model_choice):
    """
    Load and plot training/validation loss curves from metrics.json.

    Parameters
    ----------
    model_choice : str
        Name of the selected model (from get_available_models()).

    Returns
    -------
    tuple of (matplotlib.figure.Figure or None, str)
        - Figure: Loss curves plot showing train and validation loss
        - str: Status message describing the operation result

    Notes
    -----
    Loss curves show how training and validation loss evolved during
    training. Diverging curves may indicate overfitting, while parallel
    curves suggest good generalization.

    The plot includes:
    - Blue line with circles: Training loss
    - Red line with squares: Validation loss
    - Grid for easier reading

    See Also
    --------
    load_accuracy_curves : Load accuracy metrics instead of loss
    get_metrics_path_for_model : Determine metrics file location
    """
    try:
        metrics_path = get_metrics_path_for_model(model_choice)

        if not metrics_path.exists():
            return (
                None,
                f"❌ No training metrics found for \
                    {model_choice}. Path: {metrics_path}",
            )

        with open(metrics_path, "r") as f:
            data = json.load(f)

        train_losses = data.get("train_losses", [])
        val_losses = data.get("val_losses", [])

        if not train_losses and not val_losses:
            return None, "❌ No loss data available"

        fig, ax = plt.subplots(figsize=(10, 6))
        epochs = range(1, len(train_losses) + 1)

        if train_losses:
            ax.plot(
                epochs,
                train_losses,
                "b-o",
                label="Train Loss",
                linewidth=2
            )
        if val_losses:
            ax.plot(
                epochs,
                val_losses,
                "r-s",
                label="Validation Loss",
                linewidth=2
            )

        ax.set_xlabel("Epoch", fontsize=12)
        ax.set_ylabel("Loss", fontsize=12)
        ax.set_title(f"Loss Curves - \
            {model_choice}", fontsize=14, fontweight="bold")
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)
        plt.tight_layout()

        return fig, f"✅ Loss curves loaded successfully from {model_choice}"

    except Exception as e:
        return None, f"❌ Error loading loss curves: {str(e)}"


def load_accuracy_curves(model_choice):
    """
    Load and plot training/validation accuracy curves from metrics.json.

    Parameters
    ----------
    model_choice : str
        Name of the selected model (from get_available_models()).

    Returns
    -------
    tuple of (matplotlib.figure.Figure or None, str)
        - Figure: Accuracy curves plot showing train and validation accuracy
        - str: Status message describing the operation result

    Notes
    -----
    Accuracy curves show how classification accuracy evolved during training.
    The plot includes:
    - Blue line with circles: Training accuracy
    - Red line with squares: Validation accuracy
    - Y-axis limited to [0, 1] for consistency
    - Grid for easier reading

    A large gap between train and validation accuracy indicates overfitting.

    See Also
    --------
    load_loss_curves : Load loss metrics instead of accuracy
    get_metrics_path_for_model : Determine metrics file location
    """
    try:
        metrics_path = get_metrics_path_for_model(model_choice)

        if not metrics_path.exists():
            return (
                None,
                f"❌ No training metrics found for \
                    {model_choice}. Path: {metrics_path}",
            )

        with open(metrics_path, "r") as f:
            data = json.load(f)

        train_accs = data.get("train_accs", [])
        val_accs = data.get("val_accs", [])

        if not train_accs and not val_accs:
            return None, "❌ No accuracy data available"

        fig, ax = plt.subplots(figsize=(10, 6))

        if train_accs:
            ax.plot(
                range(1, len(train_accs) + 1),
                train_accs,
                "b-o",
                label="Train Accuracy",
                linewidth=2,
            )
        if val_accs:
            ax.plot(
                range(1, len(val_accs) + 1),
                val_accs,
                "r-s",
                label="Validation Accuracy",
                linewidth=2,
            )

        ax.set_xlabel("Epoch", fontsize=12)
        ax.set_ylabel("Accuracy", fontsize=12)
        ax.set_title(
            f"Accuracy Curves - \
                {model_choice}", fontsize=14, fontweight="bold"
        )
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)
        ax.set_ylim([0, 1])
        plt.tight_layout()

        return fig, f"✅ Accuracy curves \
            loaded successfully from {model_choice}"

    except Exception as e:
        return None, f"❌ Error loading \
            accuracy curves: {str(e)}"


# ========================
# PREDICTION FUNCTIONS
# ========================


def predict_single_image_gradio(
    model_choice, image, carbon_display_text, track_carbon=True
):
    """
    Gradio wrapper for single image prediction with visualization.

    Performs inference on a single image and generates a horizontal bar
    chart of class probabilities, highlighting the predicted class.

    Parameters
    ----------
    model_choice : str
        Name of the selected model (from get_available_models()).
    image : np.ndarray
        Input image as numpy array (from gr.Image component).
    carbon_display_text : str
        Current HTML string for the carbon display (to be updated).
    track_carbon : bool, optional
        Whether to track carbon emissions for this prediction,
        by default True.

    Returns
    -------
    tuple of (matplotlib.figure.Figure or None, str, str)
        - Figure: Bar chart of class probabilities
        - str: Markdown-formatted prediction results and statistics
        - str: Updated HTML for carbon display

    Notes
    -----
    The probability bar chart uses:
    - Green bar for the predicted class
    - Sky blue bars for other classes
    - Percentage labels on each bar

    If carbon tracking is enabled, emissions are added to the cumulative
    total and the carbon display is updated.

    See Also
    --------
    predict_batch : Batch prediction on multiple images
    source.predict.predict_image : Core prediction function
    """
    if image is None:
        return None, "Please upload an image", carbon_display_text

    if not model_choice:
        return None, "Please select a model first", carbon_display_text

    try:
        models_dict = get_available_models()
        model_path = models_dict.get(model_choice)
        model, device, transform = load_model_for_inference(
            model_path=model_path
        )

        result = predict_image(
            image,
            model=model,
            transform=transform,
            device=device,
            track_carbon=track_carbon,
        )

        fig, ax = plt.subplots(figsize=(10, 6))
        probs_list = [result["probabilities"][cls] for cls in cfg.CLASS_NAMES]
        pred_idx = result["predicted_idx"]
        aux_list = range(len(cfg.CLASS_NAMES))
        colors = [
            "green" if i == pred_idx else "skyblue" for i in aux_list
        ]
        bars = ax.barh(cfg.CLASS_NAMES, probs_list, color=colors)
        ax.set_xlabel("Probability", fontsize=12)
        ax.set_title(
            f'Prediction Probabilities\nPredicted Class: \
                {result["predicted_class"]}',
            fontsize=14,
            fontweight="bold",
        )
        ax.set_xlim([0, 1])

        for i, (bar, prob) in enumerate(zip(bars, probs_list)):
            ax.text(
                prob + 0.02,
                bar.get_y() + bar.get_height() / 2,
                f"{prob*100:.1f}%",
                va="center",
                fontsize=10,
            )

        plt.tight_layout()

        result_text = f"### 🎯 Prediction: \
            **{result['predicted_class']}**\n\n"
        result_text += f"**Confidence:** {result['confidence']*100:.2f}%\n\n"
        result_text += "**All Probabilities:**\n"
        for class_name, prob in result["probabilities"].items():
            emoji = "🏆" if class_name == result["predicted_class"] \
                else " "
            result_text += f"{emoji} {class_name}: {prob*100:.2f}%\n"

        updated_carbon_display = carbon_display_text
        if result["emissions"]:
            emissions = result["emissions"]
            result_text += "\n\n### 🌍 Carbon Footprint\n"
            result_text += f"- **Emissions:** \
                {emissions['emissions_g']:.4f}g CO₂eq\n"
            result_text += f"- **🚗 Car equivalent:** \
                {emissions['car_distance_formatted']} driven\n"
            updated_carbon_display = format_total_emissions_display(
                get_emissions_path()
            )

        return fig, result_text, updated_carbon_display

    except Exception as e:
        return None, f"❌ Error: {str(e)}", carbon_display_text


def predict_folder_gradio(
    model_choice, files, carbon_display_text, track_carbon=True
):
    """
    Gradio wrapper for batch prediction on multiple images.

    Performs inference on multiple images simultaneously and returns
    results in a tabular format with per-image predictions and overall
    statistics.

    Parameters
    ----------
    model_choice : str
        Name of the selected model (from get_available_models()).
    files : list of gr.File
        List of uploaded file objects from Gradio file input.
    carbon_display_text : str
        Current HTML string for the carbon display (to be updated).
    track_carbon : bool, optional
        Whether to track carbon emissions for this batch, by default True.

    Returns
    -------
    tuple of (pd.DataFrame or None, str, str)
        - DataFrame: Table with filename, predicted class, and confidence
          for each image
        - str: Markdown-formatted summary statistics and carbon footprint
        - str: Updated HTML for carbon display

    Notes
    -----
    The results table includes:
    - Filename: Original image filename
    - Predicted Class: Predicted garbage category
    - Confidence (%): Prediction confidence as percentage

    Summary statistics include:
    - Total images processed
    - Number of successful predictions
    - Total carbon emissions (if tracked)
    - Average emissions per image
    - Car distance equivalent

    Failed predictions are marked as "Error" in the table with the
    error message in the confidence column.

    See Also
    --------
    predict_single_image_gradio : Single image prediction
    source.predict.predict_batch : Core batch prediction function
    """
    if not files or len(files) == 0:
        return None, "Please upload images", carbon_display_text

    if not model_choice:
        return None, "Please select a model first", carbon_display_text

    try:
        models_dict = get_available_models()
        model_path = models_dict.get(model_choice)
        model, device, transform = load_model_for_inference(
            model_path=model_path
        )

        image_paths = [file.name for file in files]

        batch_result = predict_batch(
            image_paths,
            model=model,
            transform=transform,
            device=device,
            track_carbon=track_carbon,
        )

        df_data = []
        for result in batch_result["results"]:
            if result["status"] == "success":
                df_data.append(
                    {
                        "Filename": result["filename"],
                        "Predicted Class": result["predicted_class"],
                        "Confidence (%)": f"{result['confidence']*100:.2f}",
                    }
                )
            else:
                df_data.append(
                    {
                        "Filename": result["filename"],
                        "Predicted Class": "Error",
                        "Confidence (%)": result["error"],
                    }
                )

        df_results = pd.DataFrame(df_data)

        summary = batch_result["summary"]
        result_text = "### 📊 Batch Prediction Results\n\n"
        result_text += f"**Total images processed:** \
            {summary['total_images']}\n"
        result_text += f"**Successful predictions:** \
            {summary['successful']}\n\n"

        updated_carbon_display = carbon_display_text
        if batch_result["emissions"]:
            emissions = batch_result["emissions"]
            result_text += "### 🌍 Carbon Footprint\n"
            result_text += f"- **Emissions:** \
                {emissions['emissions_g']:.4f}g CO₂eq\n"
            result_text += f"- **🚗 Car equivalent:** \
                {emissions['car_distance_formatted']} driven\n"
            result_text += f"- **Avg per image:** \
                {emissions['emissions_per_image_g']:.4f}g CO₂eq\n"
            updated_carbon_display = format_total_emissions_display(
                get_emissions_path()
            )

        return df_results, result_text, updated_carbon_display

    except Exception as e:
        return None, f"❌ Error: {str(e)}", carbon_display_text


# ========================
# UI LAYOUT
# ========================


def model_evaluation_tab(carbon_display):
    """
    Create the Model Evaluation and Inference UI section.

    Builds a comprehensive Gradio interface for model evaluation, metrics
    visualization, and real-time inference. Organized into three main areas:
    model selection, metrics visualization, and image prediction.

    Parameters
    ----------
    carbon_display : gr.HTML
        The carbon counter display component to update after inference
        operations. Shows cumulative emissions across all tracked operations.

    Returns
    -------
    list
        Empty list (kept for API consistency).

    Notes
    -----
    The interface is organized into sections:

    1. **Model Selection:**
       - Radio buttons to choose between best model and latest trained model

    2. **Metrics Visualization Tabs:**
       - Confusion Matrix: Raw and normalized versions with instant toggle
       - Loss Curves: Training and validation loss over epochs
       - Accuracy Curves: Training and validation accuracy over epochs
       - Calibration Curves: Per-class calibration analysis

    3. **Inference Tabs:**
       - Single Image: Upload and classify individual images
       - Batch Prediction: Process multiple images simultaneously

    All expensive computations (confusion matrices, calibration curves) are
    cached and automatically invalidated when models are retrained.

    Carbon emissions can be optionally tracked for all inference operations
    and are displayed in both absolute terms and car distance equivalents.

    Examples
    --------
    >>> with gr.Blocks() as demo:
    ...     carbon_display = gr.HTML()
    ...     model_evaluation_tab(carbon_display)
    """
    with gr.Column():
        gr.Markdown("### 🔬 Model Evaluation & Inference")
        gr.Markdown(
            "Evaluate trained models, visualize metrics, \
                and make predictions on new images."
        )

        gr.Markdown("#### 🧠 Model Selection")
        model_choice = gr.Radio(
            choices=list(get_available_models().keys()),
            value=list(get_available_models().keys())[0],
            label="Select Model",
            info="Choose between the best provided model or \
                your latest trained model",
        )

        gr.Markdown("---")

        gr.Markdown("#### 📈 Model Metrics & Visualizations")

        with gr.Tabs():
            with gr.Tab("Confusion Matrix"):
                show_normalized = gr.Checkbox(
                    label="Show Normalized",
                    value=False,
                    info="Toggle between raw counts and normalized \
                        percentages",
                    visible=False,
                )
                cm_button = gr.Button(
                    "Generate Confusion Matrix", variant="primary"
                )
                cm_plot = gr.Plot(label="Confusion Matrix")
                cm_status = gr.Markdown("")

                cm_button.click(
                    fn=generate_confusion_matrix,
                    inputs=[model_choice, show_normalized],
                    outputs=[cm_plot, cm_status, show_normalized],
                )

                show_normalized.change(
                    fn=toggle_confusion_matrix,
                    inputs=[show_normalized],
                    outputs=[cm_plot],
                )

            with gr.Tab("Loss Curves"):
                loss_button = gr.Button("Load Loss Curves", variant="primary")
                loss_plot = gr.Plot(label="Loss Curves")
                loss_status = gr.Markdown("")

                loss_button.click(
                    fn=load_loss_curves,
                    inputs=[model_choice],
                    outputs=[loss_plot, loss_status],
                )

            with gr.Tab("Accuracy Curves"):
                acc_button = gr.Button(
                    "Load Accuracy Curves", variant="primary"
                )
                acc_plot = gr.Plot(label="Accuracy Curves")
                acc_status = gr.Markdown("")

                acc_button.click(
                    fn=load_accuracy_curves,
                    inputs=[model_choice],
                    outputs=[acc_plot, acc_status],
                )

            with gr.Tab("Calibration Curves"):
                calib_button = gr.Button(
                    "Generate Calibration Curves", variant="primary"
                )
                calib_plot = gr.Plot(label="Calibration Curves")
                calib_status = gr.Markdown("")

                calib_button.click(
                    fn=generate_calibration_curves,
                    inputs=[model_choice],
                    outputs=[calib_plot, calib_status],
                )

        gr.Markdown("---")

        gr.Markdown("#### 🔍 Image Prediction")

        with gr.Tabs():
            with gr.Tab("Single Image"):
                gr.Markdown(
                    "Upload an image to classify it into one of \
                        the garbage categories."
                )

                with gr.Row():
                    with gr.Column(scale=1):
                        single_image_input = gr.Image(
                            label="Upload Image", type="numpy", height=400
                        )
                        single_track_carbon = gr.Checkbox(
                            label="🌍 Track Carbon Emissions", value=True
                        )
                        single_predict_button = gr.Button(
                            "🔍 Predict", variant="primary", size="lg"
                        )

                    with gr.Column(scale=1):
                        single_probs_plot = gr.Plot(
                            label="Class Probabilities"
                        )

                single_result_text = gr.Markdown(
                    "Upload an image and click 'Predict'"
                )

                single_predict_button.click(
                    fn=predict_single_image_gradio,
                    inputs=[
                        model_choice,
                        single_image_input,
                        carbon_display,
                        single_track_carbon,
                    ],
                    outputs=[
                        single_probs_plot,
                        single_result_text,
                        carbon_display
                    ],
                )

            with gr.Tab("Batch Prediction"):
                gr.Markdown("Upload multiple images \
                    to classify them all at once.")

                batch_image_input = gr.File(
                    label="Upload Images",
                    file_count="multiple", file_types=["image"]
                )
                batch_track_carbon = gr.Checkbox(
                    label="🌍 Track Carbon Emissions", value=True
                )
                batch_predict_button = gr.Button(
                    "🔍 Predict All", variant="primary", size="lg"
                )

                batch_result_text = gr.Markdown(
                    "Upload images and click 'Predict All'"
                )
                batch_results_table = gr.Dataframe(
                    label="Prediction Results",
                    interactive=False, wrap=True
                )

                batch_predict_button.click(
                    fn=predict_folder_gradio,
                    inputs=[
                        model_choice,
                        batch_image_input,
                        carbon_display,
                        batch_track_carbon,
                    ],
                    outputs=[
                        batch_results_table,
                        batch_result_text,
                        carbon_display
                    ],
                )

        gr.Markdown("---")
        gr.Markdown(
            "**ℹ️ Info:** Carbon emissions are tracked for \
                inference operations and added to the total \
                carbon footprint."
        )

    return []
```
def get_emissions_path()
Get the path to the emissions CSV file.
Returns
- pathlib.Path: Path object pointing to the emissions.csv file in the model directory.
Notes
The emissions file is stored in the same directory as the trained model checkpoint, as defined in the global configuration.
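A doctest-style illustration, assuming cfg.MODEL_PATH is 'models/resnet18_garbage.ckpt' as in the get_available_models example below (repr shown for a POSIX system):
```python
>>> get_emissions_path()  # emissions.csv sits next to the checkpoint
PosixPath('models/emissions.csv')
```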
def get_available_models()
Get dictionary of available trained models.
Returns
- dict of {str: str}: Dictionary mapping model display names to their checkpoint file paths. Includes both the best provided model and the latest trained model.
Examples
>>> models = get_available_models()
>>> models
{
'Best Model (Provided)': 'models/best/model_resnet18_garbage.ckpt',
'Latest Trained Model': 'models/resnet18_garbage.ckpt'
}
def is_cache_valid(cache_file, model_path)
Check if cached file is newer than the model checkpoint.
Validates cache by comparing modification timestamps. Cache is considered valid only if it was created/modified after the model checkpoint.
Parameters
- cache_file (pathlib.Path): Path to the cached file to validate.
- model_path (str or pathlib.Path): Path to the model checkpoint file.
Returns
- bool: True if cache exists and is newer than the model, False otherwise.
Notes
This function implements automatic cache invalidation to ensure that evaluation metrics are regenerated when a model is retrained.
Examples
>>> cache_file = Path("cached_data/cm_raw_Latest.pkl")
>>> model_path = "models/resnet18_garbage.ckpt"
>>> is_cache_valid(cache_file, model_path)
False # Cache doesn't exist or is older than model
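Because validity is purely mtime-based, deleting the cache directory is a blunt but effective way to force every cached figure to regenerate on the next request (illustrative snippet, not part of the module):
```python
import shutil

# Remove all cached figures; they are rebuilt lazily on the next request.
shutil.rmtree("app/sections/cached_data", ignore_errors=True)
```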
def generate_confusion_matrix(model_choice, show_normalized, progress=gr.Progress())
Generate and cache both raw and normalized confusion matrices.
This function generates confusion matrices for the validation set, caching both raw and normalized versions to enable instant toggling without recomputation. Implements three-tier caching: memory, disk, and automatic invalidation.
Parameters
- model_choice (str): Name of the selected model (from get_available_models()).
- show_normalized (bool): Whether to display the normalized version initially.
- progress (gr.Progress, optional): Gradio progress tracker for UI updates.
Returns
- tuple of (matplotlib.figure.Figure or None, str, gr.update):
  - Figure: The confusion matrix plot (raw or normalized)
  - str: Status message describing the operation result
  - gr.update: Gradio update object to control checkbox visibility
Notes
Caching strategy:
1. Check memory cache (fastest)
2. Check disk cache with validation (fast)
3. Regenerate if cache invalid or missing (slow)
Both raw and normalized matrices are always generated together and stored separately to enable instant toggling via the checkbox.
The cache is automatically invalidated when the model file is modified, ensuring metrics reflect the current model state.
See Also
toggle_confusion_matrix: Switch between raw/normalized without regenerating.
is_cache_valid: Cache validation logic
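The three-tier strategy condenses into a short sketch; cached_figure and compute_fn are hypothetical names, and the real function inlines this logic for the raw/normalized figure pair:
```python
import pickle
from pathlib import Path

_memory_cache = {}  # tier 1: per-process cache


def cached_figure(key, cache_file: Path, model_path, compute_fn):
    if key in _memory_cache:                    # tier 1: memory (fastest)
        return _memory_cache[key]
    if is_cache_valid(cache_file, model_path):  # tier 2: disk, mtime-checked
        with open(cache_file, "rb") as f:
            fig = pickle.load(f)
    else:                                       # tier 3: recompute (slow)
        fig = compute_fn()
        with open(cache_file, "wb") as f:
            pickle.dump(fig, f)                 # persist for next time
    _memory_cache[key] = fig
    return fig
```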
def toggle_confusion_matrix(show_normalized)
Toggle between raw and normalized confusion matrices instantly.
Switches the displayed confusion matrix without regenerating, using cached versions from memory. This enables instant UI response to the normalization checkbox.
Parameters
- show_normalized (bool): If True, return normalized matrix; if False, return raw matrix.
Returns
- matplotlib.figure.Figure or None: The requested confusion matrix figure, or None if not available in memory cache.
Notes
This function only accesses the in-memory cache. If matrices are not yet generated, it returns None. The generate_confusion_matrix function should be called first to populate the cache.
See Also
generate_confusion_matrix: Generate and cache both matrix versions
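The intended call order, mirroring how the UI wires the generate button and the checkbox (illustrative sketch):
```python
# Populate memory and disk caches first (slow path only on the first call)...
fig, status, _ = generate_confusion_matrix("Latest Trained Model", False)

# ...after which flipping the checkbox is served instantly from memory.
fig_norm = toggle_confusion_matrix(True)
```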
def generate_calibration_curves(model_choice, progress=gr.Progress())
Generate and cache calibration curves for the selected model.
Calibration curves show how well predicted probabilities match actual outcomes. This function generates curves for all classes and caches the result for faster subsequent access.
Parameters
- model_choice (str): Name of the selected model (from get_available_models()).
- progress (gr.Progress, optional): Gradio progress tracker for UI updates.
Returns
- tuple of (matplotlib.figure.Figure or None, str):
  - Figure: The calibration curves plot
  - str: Status message describing the operation result
Notes
The calibration plot is generated using GarbageModelAnalyzer's plot_calibration_curves method and cached to disk. Cache is automatically invalidated when the model is retrained.
Well-calibrated models should have curves close to the diagonal, indicating that predicted probabilities match actual frequencies.
See Also
is_cache_valid: Cache validation logic
GarbageModelAnalyzer.plot_calibration_curves: Core plotting function
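For reference, a per-class calibration curve is typically computed one-vs-rest as below. This is a hedged sketch with hypothetical helper and argument names, not necessarily the exact logic inside GarbageModelAnalyzer.plot_calibration_curves; labels and probs are assumed to be NumPy arrays of shape (N,) and (N, C):
```python
import matplotlib.pyplot as plt
from sklearn.calibration import calibration_curve


def plot_one_vs_rest_calibration(labels, probs, class_names, n_bins=10):
    fig, ax = plt.subplots(figsize=(8, 6))
    for idx, name in enumerate(class_names):
        # Binarize: "is this sample class idx?" vs. predicted prob of idx.
        prob_true, prob_pred = calibration_curve(
            labels == idx, probs[:, idx], n_bins=n_bins
        )
        ax.plot(prob_pred, prob_true, marker="o", label=name)
    # A perfectly calibrated model lies on the diagonal.
    ax.plot([0, 1], [0, 1], "k--", label="Perfect calibration")
    ax.set_xlabel("Mean predicted probability")
    ax.set_ylabel("Observed frequency")
    ax.legend()
    return fig
```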
def get_metrics_path_for_model(model_choice)
Get the correct metrics.json path for the selected model.
Different models store their training metrics in different locations. This function maps model choices to their corresponding metrics files.
Parameters
- model_choice (str): Name of the selected model (from get_available_models()).
Returns
- pathlib.Path: Path to the metrics.json file for the selected model.
Notes
The best provided model has metrics in a fixed location, while the latest trained model uses the configured loss curves path.
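A doctest-style illustration for the fixed best-model branch (repr shown for a POSIX system); any other choice resolves to Path(cfg.LOSS_CURVES_PATH) / "metrics.json":
```python
>>> get_metrics_path_for_model("Best Model (Provided)")
PosixPath('models/best/performance/loss_curves/metrics.json')
```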
def load_loss_curves(model_choice)
Load and plot training/validation loss curves from metrics.json.
Parameters
- model_choice (str): Name of the selected model (from get_available_models()).
Returns
- tuple of (matplotlib.figure.Figure or None, str):
  - Figure: Loss curves plot showing train and validation loss
  - str: Status message describing the operation result
Notes
Loss curves show how training and validation loss evolved during training. Diverging curves may indicate overfitting, while parallel curves suggest good generalization.
The plot includes:
- Blue line with circles: Training loss
- Red line with squares: Validation loss
- Grid for easier reading
See Also
load_accuracy_curves: Load accuracy metrics instead of loss
get_metrics_path_for_model: Determine metrics file location
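The loaders only require a handful of keys in metrics.json; a minimal compatible file could be produced like this (illustrative values, one entry per epoch):
```python
import json

# Hypothetical example: write a metrics.json that both curve loaders can read.
metrics = {
    "train_losses": [1.42, 0.98, 0.71],  # read by load_loss_curves
    "val_losses": [1.35, 1.01, 0.84],
    "train_accs": [0.55, 0.71, 0.80],    # read by load_accuracy_curves
    "val_accs": [0.58, 0.68, 0.74],
}
with open("metrics.json", "w") as f:
    json.dump(metrics, f, indent=2)
```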
def load_accuracy_curves(model_choice)
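Both curve loaders read the same metrics.json; a sketch of the dictionary shape they expect after json.load(), showing only the keys accessed above (the numbers are invented):

# Illustrative metrics.json content as a Python dict; keys match the
# data.get(...) calls above, values are made up for the example.
data = {
    "train_losses": [1.42, 0.97, 0.71],
    "val_losses": [1.51, 1.10, 0.88],
    "train_accs": [0.48, 0.66, 0.77],
    "val_accs": [0.45, 0.61, 0.71],
}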
def predict_single_image_gradio(
    model_choice, image, carbon_display_text, track_carbon=True
):
    """
    Gradio wrapper for single image prediction with visualization.

    Performs inference on a single image and generates a horizontal bar
    chart of class probabilities, highlighting the predicted class.

    Parameters
    ----------
    model_choice : str
        Name of the selected model (from get_available_models()).
    image : np.ndarray
        Input image as numpy array (from gr.Image component).
    carbon_display_text : str
        Current HTML string for the carbon display (to be updated).
    track_carbon : bool, optional
        Whether to track carbon emissions for this prediction,
        by default True.

    Returns
    -------
    tuple of (matplotlib.figure.Figure or None, str, str)
        - Figure: Bar chart of class probabilities
        - str: Markdown-formatted prediction results and statistics
        - str: Updated HTML for carbon display

    Notes
    -----
    The probability bar chart uses:
    - Green bar for the predicted class
    - Sky blue bars for other classes
    - Percentage labels on each bar

    If carbon tracking is enabled, emissions are added to the cumulative
    total and the carbon display is updated.

    See Also
    --------
    predict_folder_gradio : Batch prediction on multiple images
    source.predict.predict_image : Core prediction function
    """
    if image is None:
        return None, "Please upload an image", carbon_display_text

    if not model_choice:
        return None, "Please select a model first", carbon_display_text

    try:
        models_dict = get_available_models()
        model_path = models_dict.get(model_choice)
        model, device, transform = load_model_for_inference(
            model_path=model_path
        )

        result = predict_image(
            image,
            model=model,
            transform=transform,
            device=device,
            track_carbon=track_carbon,
        )

        fig, ax = plt.subplots(figsize=(10, 6))
        probs_list = [result["probabilities"][cls] for cls in cfg.CLASS_NAMES]
        pred_idx = result["predicted_idx"]
        colors = [
            "green" if i == pred_idx else "skyblue"
            for i in range(len(cfg.CLASS_NAMES))
        ]
        bars = ax.barh(cfg.CLASS_NAMES, probs_list, color=colors)
        ax.set_xlabel("Probability", fontsize=12)
        ax.set_title(
            "Prediction Probabilities\n"
            f"Predicted Class: {result['predicted_class']}",
            fontsize=14,
            fontweight="bold",
        )
        ax.set_xlim([0, 1])

        for bar, prob in zip(bars, probs_list):
            ax.text(
                prob + 0.02,
                bar.get_y() + bar.get_height() / 2,
                f"{prob*100:.1f}%",
                va="center",
                fontsize=10,
            )

        plt.tight_layout()

        result_text = (
            f"### 🎯 Prediction: **{result['predicted_class']}**\n\n"
        )
        result_text += f"**Confidence:** {result['confidence']*100:.2f}%\n\n"
        result_text += "**All Probabilities:**\n"
        for class_name, prob in result["probabilities"].items():
            emoji = "🏆" if class_name == result["predicted_class"] else " "
            result_text += f"{emoji} {class_name}: {prob*100:.2f}%\n"

        updated_carbon_display = carbon_display_text
        if result["emissions"]:
            emissions = result["emissions"]
            result_text += "\n\n### 🌍 Carbon Footprint\n"
            result_text += (
                f"- **Emissions:** {emissions['emissions_g']:.4f}g CO₂eq\n"
            )
            result_text += (
                f"- **🚗 Car equivalent:** "
                f"{emissions['car_distance_formatted']} driven\n"
            )
            updated_carbon_display = format_total_emissions_display(
                get_emissions_path()
            )

        return fig, result_text, updated_carbon_display

    except Exception as e:
        return None, f"❌ Error: {str(e)}", carbon_display_text
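A minimal sketch of driving this wrapper outside the UI, assuming a local image file (test.jpg is a hypothetical path; Pillow is a dependency of Gradio, so it should be available):

# Hedged usage sketch: call the wrapper directly with an HWC uint8 array,
# which is what gr.Image(type="numpy") would hand it.
import numpy as np
from PIL import Image

img = np.array(Image.open("test.jpg").convert("RGB"))
fig, result_md, carbon_html = predict_single_image_gradio(
    "Best Model (Provided)", img, carbon_display_text="", track_carbon=False
)
print(result_md)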
def predict_folder_gradio(
    model_choice, files, carbon_display_text, track_carbon=True
):
    """
    Gradio wrapper for batch prediction on multiple images.

    Performs inference on multiple images simultaneously and returns
    results in a tabular format with per-image predictions and overall
    statistics.

    Parameters
    ----------
    model_choice : str
        Name of the selected model (from get_available_models()).
    files : list of gr.File
        List of uploaded file objects from Gradio file input.
    carbon_display_text : str
        Current HTML string for the carbon display (to be updated).
    track_carbon : bool, optional
        Whether to track carbon emissions for this batch, by default True.

    Returns
    -------
    tuple of (pd.DataFrame or None, str, str)
        - DataFrame: Table with filename, predicted class, and confidence
          for each image
        - str: Markdown-formatted summary statistics and carbon footprint
        - str: Updated HTML for carbon display

    Notes
    -----
    The results table includes:
    - Filename: Original image filename
    - Predicted Class: Predicted garbage category
    - Confidence (%): Prediction confidence as percentage

    Summary statistics include:
    - Total images processed
    - Number of successful predictions
    - Total carbon emissions (if tracked)
    - Average emissions per image
    - Car distance equivalent

    Failed predictions are marked as "Error" in the table with the
    error message in the confidence column.

    See Also
    --------
    predict_single_image_gradio : Single image prediction
    source.predict.predict_batch : Core batch prediction function
    """
    if not files:
        return None, "Please upload images", carbon_display_text

    if not model_choice:
        return None, "Please select a model first", carbon_display_text

    try:
        models_dict = get_available_models()
        model_path = models_dict.get(model_choice)
        model, device, transform = load_model_for_inference(
            model_path=model_path
        )

        image_paths = [file.name for file in files]

        batch_result = predict_batch(
            image_paths,
            model=model,
            transform=transform,
            device=device,
            track_carbon=track_carbon,
        )

        df_data = []
        for result in batch_result["results"]:
            if result["status"] == "success":
                df_data.append(
                    {
                        "Filename": result["filename"],
                        "Predicted Class": result["predicted_class"],
                        "Confidence (%)": f"{result['confidence']*100:.2f}",
                    }
                )
            else:
                df_data.append(
                    {
                        "Filename": result["filename"],
                        "Predicted Class": "Error",
                        "Confidence (%)": result["error"],
                    }
                )

        df_results = pd.DataFrame(df_data)

        summary = batch_result["summary"]
        result_text = "### 📊 Batch Prediction Results\n\n"
        result_text += (
            f"**Total images processed:** {summary['total_images']}\n"
        )
        result_text += (
            f"**Successful predictions:** {summary['successful']}\n\n"
        )

        updated_carbon_display = carbon_display_text
        if batch_result["emissions"]:
            emissions = batch_result["emissions"]
            result_text += "### 🌍 Carbon Footprint\n"
            result_text += (
                f"- **Emissions:** {emissions['emissions_g']:.4f}g CO₂eq\n"
            )
            result_text += (
                f"- **🚗 Car equivalent:** "
                f"{emissions['car_distance_formatted']} driven\n"
            )
            result_text += (
                f"- **Avg per image:** "
                f"{emissions['emissions_per_image_g']:.4f}g CO₂eq\n"
            )
            updated_carbon_display = format_total_emissions_display(
                get_emissions_path()
            )

        return df_results, result_text, updated_carbon_display

    except Exception as e:
        return None, f"❌ Error: {str(e)}", carbon_display_text
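Because the wrapper only reads the .name attribute of each upload, a lightweight stand-in object is enough to exercise it without Gradio; a sketch with hypothetical paths:

# Hedged usage sketch: mimic Gradio file uploads with objects that expose
# only the .name attribute the wrapper actually touches.
from types import SimpleNamespace

uploads = [SimpleNamespace(name=p) for p in ("img1.jpg", "img2.jpg")]
table, summary_md, carbon_html = predict_folder_gradio(
    "Best Model (Provided)",
    uploads,
    carbon_display_text="",
    track_carbon=False,
)
print(summary_md)
print(table)  # DataFrame: Filename / Predicted Class / Confidence (%)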
def model_evaluation_tab(carbon_display):
    """
    Create the Model Evaluation and Inference UI section.

    Builds a comprehensive Gradio interface for model evaluation, metrics
    visualization, and real-time inference. Organized into three main areas:
    model selection, metrics visualization, and image prediction.

    Parameters
    ----------
    carbon_display : gr.HTML
        The carbon counter display component to update after inference
        operations. Shows cumulative emissions across all tracked operations.

    Returns
    -------
    list
        Empty list (kept for API consistency).

    Notes
    -----
    The interface is organized into sections:

    1. **Model Selection:**
       - Radio buttons to choose between best model and latest trained model

    2. **Metrics Visualization Tabs:**
       - Confusion Matrix: Raw and normalized versions with instant toggle
       - Loss Curves: Training and validation loss over epochs
       - Accuracy Curves: Training and validation accuracy over epochs
       - Calibration Curves: Per-class calibration analysis

    3. **Inference Tabs:**
       - Single Image: Upload and classify individual images
       - Batch Prediction: Process multiple images simultaneously

    All expensive computations (confusion matrices, calibration curves) are
    cached and automatically invalidated when models are retrained.

    Carbon emissions can be optionally tracked for all inference operations
    and are displayed in both absolute terms and car distance equivalents.

    Examples
    --------
    >>> with gr.Blocks() as demo:
    ...     carbon_display = gr.HTML()
    ...     model_evaluation_tab(carbon_display)
    """
    with gr.Column():
        gr.Markdown("### 🔬 Model Evaluation & Inference")
        gr.Markdown(
            "Evaluate trained models, visualize metrics, "
            "and make predictions on new images."
        )

        gr.Markdown("#### 🧠 Model Selection")
        model_choice = gr.Radio(
            choices=list(get_available_models().keys()),
            value=list(get_available_models().keys())[0],
            label="Select Model",
            info=(
                "Choose between the best provided model or "
                "your latest trained model"
            ),
        )

        gr.Markdown("---")

        gr.Markdown("#### 📈 Model Metrics & Visualizations")

        with gr.Tabs():
            with gr.Tab("Confusion Matrix"):
                show_normalized = gr.Checkbox(
                    label="Show Normalized",
                    value=False,
                    info=(
                        "Toggle between raw counts and "
                        "normalized percentages"
                    ),
                    visible=False,
                )
                cm_button = gr.Button(
                    "Generate Confusion Matrix", variant="primary"
                )
                cm_plot = gr.Plot(label="Confusion Matrix")
                cm_status = gr.Markdown("")

                cm_button.click(
                    fn=generate_confusion_matrix,
                    inputs=[model_choice, show_normalized],
                    outputs=[cm_plot, cm_status, show_normalized],
                )

                show_normalized.change(
                    fn=toggle_confusion_matrix,
                    inputs=[show_normalized],
                    outputs=[cm_plot],
                )

            with gr.Tab("Loss Curves"):
                loss_button = gr.Button("Load Loss Curves", variant="primary")
                loss_plot = gr.Plot(label="Loss Curves")
                loss_status = gr.Markdown("")

                loss_button.click(
                    fn=load_loss_curves,
                    inputs=[model_choice],
                    outputs=[loss_plot, loss_status],
                )

            with gr.Tab("Accuracy Curves"):
                acc_button = gr.Button(
                    "Load Accuracy Curves", variant="primary"
                )
                acc_plot = gr.Plot(label="Accuracy Curves")
                acc_status = gr.Markdown("")

                acc_button.click(
                    fn=load_accuracy_curves,
                    inputs=[model_choice],
                    outputs=[acc_plot, acc_status],
                )

            with gr.Tab("Calibration Curves"):
                calib_button = gr.Button(
                    "Generate Calibration Curves", variant="primary"
                )
                calib_plot = gr.Plot(label="Calibration Curves")
                calib_status = gr.Markdown("")

                calib_button.click(
                    fn=generate_calibration_curves,
                    inputs=[model_choice],
                    outputs=[calib_plot, calib_status],
                )

        gr.Markdown("---")

        gr.Markdown("#### 🔍 Image Prediction")

        with gr.Tabs():
            with gr.Tab("Single Image"):
                gr.Markdown(
                    "Upload an image to classify it into one of "
                    "the garbage categories."
                )

                with gr.Row():
                    with gr.Column(scale=1):
                        single_image_input = gr.Image(
                            label="Upload Image", type="numpy", height=400
                        )
                        single_track_carbon = gr.Checkbox(
                            label="🌍 Track Carbon Emissions", value=True
                        )
                        single_predict_button = gr.Button(
                            "🔍 Predict", variant="primary", size="lg"
                        )

                    with gr.Column(scale=1):
                        single_probs_plot = gr.Plot(
                            label="Class Probabilities"
                        )

                single_result_text = gr.Markdown(
                    "Upload an image and click 'Predict'"
                )

                single_predict_button.click(
                    fn=predict_single_image_gradio,
                    inputs=[
                        model_choice,
                        single_image_input,
                        carbon_display,
                        single_track_carbon,
                    ],
                    outputs=[
                        single_probs_plot,
                        single_result_text,
                        carbon_display,
                    ],
                )

            with gr.Tab("Batch Prediction"):
                gr.Markdown(
                    "Upload multiple images to classify them all at once."
                )

                batch_image_input = gr.File(
                    label="Upload Images",
                    file_count="multiple",
                    file_types=["image"],
                )
                batch_track_carbon = gr.Checkbox(
                    label="🌍 Track Carbon Emissions", value=True
                )
                batch_predict_button = gr.Button(
                    "🔍 Predict All", variant="primary", size="lg"
                )

                batch_result_text = gr.Markdown(
                    "Upload images and click 'Predict All'"
                )
                batch_results_table = gr.Dataframe(
                    label="Prediction Results",
                    interactive=False,
                    wrap=True,
                )

                batch_predict_button.click(
                    fn=predict_folder_gradio,
                    inputs=[
                        model_choice,
                        batch_image_input,
                        carbon_display,
                        batch_track_carbon,
                    ],
                    outputs=[
                        batch_results_table,
                        batch_result_text,
                        carbon_display,
                    ],
                )

        gr.Markdown("---")
        gr.Markdown(
            "**ℹ️ Info:** Carbon emissions are tracked for "
            "inference operations and added to the total "
            "carbon footprint."
        )

    return []
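Finally, a minimal sketch of mounting the tab in a standalone app. This is a hypothetical entry point, not the application's actual one; it assumes the return value of format_total_emissions_display is suitable as the initial HTML for the counter:

# Hedged launch sketch: compose just this section into a gr.Blocks app.
import gradio as gr

with gr.Blocks() as demo:
    carbon_display = gr.HTML(
        format_total_emissions_display(get_emissions_path())
    )
    model_evaluation_tab(carbon_display)

demo.launch()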