source.utils.custom_classes.EvalAnalyzer

#!/usr/bin/env python
# coding: utf-8

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from pathlib import Path
from PIL import Image
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.calibration import calibration_curve

from source.utils.custom_classes.GarbageClassifier import GarbageClassifier
from source.utils.custom_classes.GarbageDataModule import GarbageDataModule
from source.utils import config as cfg

class GarbageModelAnalyzer:
    """
    Analyzer class for evaluating and visualizing garbage
    classification model performance.

    This class provides comprehensive tools for model evaluation
    including confusion matrices, calibration curves, and misclassified
    sample visualization. It handles both model loading and data
    module setup, with support for GPU acceleration.

    Attributes
    ----------
    dataset_path : str
        Path to the dataset folder containing images organized by class.
    performance_path : str
        Path to the directory where performance figures are saved.
    device : torch.device
        Computational device (CUDA if available, otherwise CPU).
    df : pd.DataFrame or None
        Metadata DataFrame containing dataset information.
    model : GarbageClassifier or None
        Loaded model instance for inference.
    data_module : GarbageDataModule or None
        Data module for dataset loading and preprocessing.
    """

    def __init__(self, dataset_path=None, performance_path=None):
        """
        Initialize the GarbageModelAnalyzer instance.

        Loads metadata from the dataset path and sets up paths
        for storing performance figures. Detects CUDA availability
        for GPU acceleration.

        Parameters
        ----------
        dataset_path : str, optional
            Path to the dataset folder. Default is
            "../data/raw/sample_dataset".
        performance_path : str, optional
            Path to the directory for saving performance figures. Default is
            "../reports/figures/performance/".

        Returns
        -------
        None
        """
        self.dataset_path = dataset_path or os.path.join(
            "..", "data", "raw", "sample_dataset"
        )
        self.performance_path = (
            performance_path or "../reports/figures/performance/"
        )
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        metadata_path = Path(cfg.DATASET_PATH).parent / "metadata.csv"

        if metadata_path.exists():
            self.df = pd.read_csv(metadata_path)
        else:
            print(f"Warning: metadata.csv not found at {metadata_path}")
            self.df = None

        self.model = None
        self.data_module = None

    def load_model(self, checkpoint_path=None, num_classes=None):
        """
        Load a trained GarbageClassifier model from a checkpoint.

        Loads a model from a PyTorch Lightning checkpoint and moves it to the
        specified device in evaluation mode.

        Parameters
        ----------
        checkpoint_path : str, optional
            Path to the model checkpoint file. If None, uses cfg.MODEL_PATH.
            Default is None.
        num_classes : int, optional
            Number of output classes. If None, uses cfg.NUM_CLASSES.
            Default is None.

        Returns
        -------
        None

        Notes
        -----
        Model is automatically set to evaluation mode and moved
        to the configured device.
        """
        checkpoint_path = checkpoint_path or cfg.MODEL_PATH
        num_classes = num_classes or cfg.NUM_CLASSES
        print("Loading model...")
        self.model = GarbageClassifier.load_from_checkpoint(
            checkpoint_path, num_classes=num_classes
        )
        self.model.to(self.device).eval()
        print("Model loaded.")

    def setup_data(self, batch_size=32):
        """
        Set up the data module and filter metadata for available samples.

        Initializes the GarbageDataModule and creates a filtered subset
        of metadata containing only files present in the dataset directory.

        Parameters
        ----------
        batch_size : int, optional
            Batch size for data loading. Default is 32.

        Returns
        -------
        None

        Notes
        -----
        Creates self.df_subset containing only samples actually
        present in the dataset. Requires metadata to have been loaded
        successfully in __init__.
        """
        if self.df is None:
            raise ValueError(
                "Metadata is not loaded; metadata.csv was not found."
            )
        self.data_module = GarbageDataModule(batch_size=batch_size)
        self.data_module.setup()
        # Collect every filename present under the dataset directory.
        file_names = []
        for _root, _dirs, files in os.walk(cfg.DATASET_PATH):
            file_names.extend(files)
        self.df_subset = (
            self.df[self.df["filename"].isin(file_names)]
            .reset_index(drop=True)
            .copy()
        )

    def evaluate_loader(self, loader):
        """
        Evaluate model on a data loader and collect predictions
        and probabilities.

        Iterates through a PyTorch DataLoader, performs inference,
        and collects predictions, true labels, and confidence scores.

        Parameters
        ----------
        loader : torch.utils.data.DataLoader
            DataLoader containing batches of (images, labels).

        Returns
        -------
        tuple
            A tuple containing:
            - all_preds : torch.Tensor
                Predicted class indices.
            - all_labels : torch.Tensor
                True class labels.
            - all_probs : np.ndarray
                Confidence scores for each class (shape: [N, num_classes]).

        Notes
        -----
        Model inference is performed without gradient computation
        for efficiency.
        Probabilities are computed using softmax activation.
        Results are returned on the CPU so they interoperate with NumPy
        and scikit-learn downstream.
        """
        all_preds, all_labels, all_probs = [], [], []
        with torch.no_grad():
            for xb, yb in loader:
                xb = xb.to(self.device)
                out = self.model(xb)
                preds = out.argmax(dim=1)
                probs = torch.softmax(out, dim=1)
                # Move results to the CPU; keeping them on the GPU would
                # break sklearn consumers such as confusion_matrix.
                all_preds.append(preds.cpu())
                all_probs.append(probs.cpu())
                all_labels.append(yb)
        all_preds = torch.cat(all_preds)
        all_labels = torch.cat(all_labels)
        all_probs = torch.cat(all_probs).numpy()
        return all_preds, all_labels, all_probs

    def plot_confusion_matrix(self, labels, preds, set_name="Train"):
        """
        Plot and save confusion matrices (raw and normalized)
        with class metrics.

        Generates both raw and normalized confusion matrices,
        computes TP, FP, FN, TN per class, and saves visualizations
        as PDF files.

        Parameters
        ----------
        labels : array-like
            True class labels.
        preds : array-like
            Predicted class labels.
        set_name : str, optional
            Name of the dataset split (e.g., "Train", "Val", "Test")
            for plot titles and filenames. Default is "Train".

        Returns
        -------
        None

        Side Effects
        ------------
        - Saves confusion_mat_{set_name.lower()}.pdf to performance_path.
        - Saves confusion_mat_{set_name.lower()}_norm.pdf to performance_path.
        - Prints TP, FP, FN, TN statistics for each class.
        - Displays matplotlib figures.

        Notes
        -----
        The normalized confusion matrix divides by row sums to show
        per-class proportions.
        """
        num_classes = self.data_module.num_classes
        cm = confusion_matrix(labels, preds, labels=range(num_classes))
        disp = ConfusionMatrixDisplay(
            confusion_matrix=cm, display_labels=cfg.CLASS_NAMES
        )
        disp.plot(cmap=plt.cm.Blues)
        plt.title(f"Confusion Matrix - {set_name} set")
        plt.savefig(
            os.path.join(
                self.performance_path, f"confusion_mat_{set_name.lower()}.pdf"
            ),
            dpi=80,
        )
        plt.show()

        # Normalized: divide each row by its total so cells show
        # per-class proportions.
        cm_norm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]
        disp_norm = ConfusionMatrixDisplay(
            confusion_matrix=cm_norm, display_labels=cfg.CLASS_NAMES
        )
        disp_norm.plot(cmap=plt.cm.Blues)
        plt.title(f"Normalized Confusion Matrix - {set_name} set")
        plt.savefig(
            os.path.join(
                self.performance_path,
                f"confusion_mat_{set_name.lower()}_norm.pdf"
            ),
            dpi=80,
        )
        plt.show()

        # Per-class TP, FP, FN, TN derived from the confusion matrix.
        TP = np.diag(cm)
        FP = cm.sum(axis=0) - TP
        FN = cm.sum(axis=1) - TP
        TN = cm.sum() - (TP + FP + FN)
        for i in range(num_classes):
            print(f"Class {i}: TP={TP[i]}, FP={FP[i]}, FN={FN[i]}, TN={TN[i]}")

    def plot_top_misclassified(
        self, df_set, y_true, y_pred, y_proba, N=10, filename=None
    ):
        """
        Visualize the top N misclassified samples with lowest confidence.

        Identifies misclassified samples, sorts them by confidence
        on the true class, and displays the least confident (worst)
        predictions with their images.

        Parameters
        ----------
        df_set : pd.DataFrame
            DataFrame containing sample metadata with 'label' and 'filename'
            columns.
        y_true : array-like
            True class labels (can be integers or strings).
        y_pred : array-like
            Predicted class labels (integers).
        y_proba : np.ndarray
            Confidence scores matrix (shape: [N, num_classes]).
        N : int, optional
            Number of top misclassified samples to display. Default is 10.
        filename : str, optional
            Filename (without extension) to save the figure. If provided,
            saves as PDF to performance_path. Default is None (no save).

        Returns
        -------
        None

        Side Effects
        ------------
        - Displays matplotlib figure with misclassified samples.
        - Saves figure to performance_path/{filename}.pdf
          if filename is provided.

        Notes
        -----
        Samples are sorted by confidence on the true class in ascending order,
        showing the most uncertain misclassifications first.
        Images are loaded from the dataset_path/{label}/{filename} structure.
        """
        y_true = np.array(y_true)
        y_pred = np.array(y_pred)

        # Map labels to class indices, whether they are integers or strings.
        if np.issubdtype(y_true.dtype, np.integer):
            true_indices = y_true
            classes = sorted(df_set["label"].unique())
        else:
            classes = sorted(np.unique(y_true))
            class_to_idx = {cls: i for i, cls in enumerate(classes)}
            true_indices = np.array([class_to_idx[label] for label in y_true])

        true_confidences = y_proba[np.arange(len(y_true)), true_indices]
        misclassified_idx = np.where(y_true != y_pred)[0]

        if len(misclassified_idx) == 0:
            print("No misclassified samples found!")
            return

        # Least confident (on the true class) misclassifications first.
        sorted_idx = misclassified_idx[
            np.argsort(true_confidences[misclassified_idx])
        ]
        selected_idx = sorted_idx[:N]

        n_rows = int(np.ceil(N / 5))
        plt.figure(figsize=(15, 3 * n_rows))
        for i, idx in enumerate(selected_idx, 1):
            row = df_set.iloc[idx]
            img_path = os.path.join(
                self.dataset_path, row["label"], row["filename"]
            )
            if not os.path.exists(img_path):
                continue
            try:
                img = Image.open(img_path).convert("RGB")
            except Exception:
                continue
            plt.subplot(n_rows, 5, i)
            plt.imshow(img)
            plt.axis("off")
            plt.title(
                f"True: {row['label']}\n"
                f"Pred: {classes[y_pred[idx]]}\n"
                f"Conf True Class: {true_confidences[idx]:.2f}",
                fontsize=9,
                color="red",
            )
        plt.tight_layout()
        if filename:
            plt.savefig(
                os.path.join(self.performance_path, f"{filename}.pdf"),
                dpi=80,
            )
        plt.show()

    def plot_calibration_curves(self, y_true, y_probs):
        """
        Plot calibration curves for all classes using a one-vs-rest approach.

        Generates calibration curves showing the relationship between
        predicted probabilities and actual positive fractions for each class.

        Parameters
        ----------
        y_true : array-like or torch.Tensor
            True class labels (integers). Shape: [N,].
        y_probs : np.ndarray or torch.Tensor
            Predicted probability matrix. Shape: [N, num_classes].

        Returns
        -------
        None

        Side Effects
        ------------
        - Displays a matplotlib figure with a grid of calibration curves.
        - One subplot per class showing the calibration curve and the
          reference diagonal.

        Notes
        -----
        Uses sklearn's calibration_curve function with 10 bins for each class.
        Each class is converted to a binary classification problem (OVR).
        Curves below the diagonal indicate overconfident predictions;
        curves above the diagonal indicate underconfident predictions.
        """
        num_classes = self.data_module.num_classes
        # Size the grid from the class count instead of hard-coding 2x3.
        n_cols = 3
        n_rows = int(np.ceil(num_classes / n_cols))
        fig, axes = plt.subplots(
            n_rows, n_cols, figsize=(6 * n_cols, 5 * n_rows)
        )
        axes = np.asarray(axes).flatten()

        if isinstance(y_true, torch.Tensor):
            y_true_np = y_true.cpu().numpy()
        else:
            y_true_np = y_true

        if isinstance(y_probs, torch.Tensor):
            y_probs = y_probs.cpu().numpy()

        for c in range(num_classes):
            ax = axes[c]
            # One-vs-rest: treat class c as positive, all others as negative.
            y_true_c = (y_true_np == c).astype(int)
            y_prob_c = y_probs[:, c]

            frac_pos, mean_pred = calibration_curve(
                y_true_c, y_prob_c, n_bins=10
            )
            ax.plot(mean_pred, frac_pos, marker="o", label=f"Class {c}")
            ax.plot(
                [0, 1], [0, 1], linestyle="--", color="gray", label="Reference"
            )
            ax.set_xlabel("Mean predicted probability")
            ax.set_ylabel("Fraction of positives")
            ax.set_title(f"Calibration Curve: {cfg.CLASS_NAMES[c]}")
            ax.set_xticks(np.arange(0, 1.1, 0.1))
            ax.set_yticks(np.arange(0, 1.1, 0.1))
            ax.set_xlim(-0.05, 1.05)
            ax.set_ylim(-0.05, 1.05)
            ax.grid(True)
            ax.legend(fontsize=8)

        # Hide any unused axes in the grid.
        for ax in axes[num_classes:]:
            ax.axis("off")

        plt.tight_layout()
        plt.show()
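
A minimal end-to-end sketch of the intended workflow follows. It is illustrative only: it assumes cfg.MODEL_PATH, cfg.NUM_CLASSES, and cfg.DATASET_PATH are configured, and that GarbageDataModule exposes a Lightning-style val_dataloader() accessor, which this module does not show.

# Hedged usage sketch; val_dataloader() is an assumption about
# GarbageDataModule, not something defined in this module.
analyzer = GarbageModelAnalyzer()
analyzer.load_model()               # falls back to cfg.MODEL_PATH / cfg.NUM_CLASSES
analyzer.setup_data(batch_size=32)

val_loader = analyzer.data_module.val_dataloader()  # assumed accessor
preds, labels, probs = analyzer.evaluate_loader(val_loader)

analyzer.plot_confusion_matrix(labels, preds, set_name="Val")
analyzer.plot_calibration_curves(labels, probs)
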
class GarbageModelAnalyzer:

Analyzer class for evaluating and visualizing garbage classification model performance.

This class provides comprehensive tools for model evaluation including confusion matrices, calibration curves, and misclassified sample visualization. It handles both model loading and data module setup, with support for GPU acceleration.

Attributes
  • dataset_path (str): Path to the dataset folder containing images organized by class.
  • performance_path (str): Path to the directory where performance figures are saved.
  • device (torch.device): Computational device (CUDA if available, otherwise CPU).
  • df (pd.DataFrame or None): Metadata DataFrame containing dataset information.
  • model (GarbageClassifier or None): Loaded model instance for inference.
  • data_module (GarbageDataModule or None): Data module for dataset loading and preprocessing.
GarbageModelAnalyzer(dataset_path=None, performance_path=None)

Initialize the GarbageModelAnalyzer instance.

Loads metadata from the dataset path and sets up paths for storing performance figures. Detects CUDA availability for GPU acceleration.

Parameters
  • dataset_path (str, optional): Path to the dataset folder. Default is "../data/raw/sample_dataset".
  • performance_path (str, optional): Path to the directory for saving performance figures. Default is "../reports/figures/performance/".
Returns
  • None
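
For example, with explicit paths (illustrative values only, not defaults taken from this module):

analyzer = GarbageModelAnalyzer(
    dataset_path="data/raw/sample_dataset",          # illustrative path
    performance_path="reports/figures/performance/",  # illustrative path
)
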
def load_model(self, checkpoint_path=None, num_classes=None):

Load a trained GarbageClassifier model from a checkpoint.

Loads a model from a PyTorch Lightning checkpoint and moves it to the specified device in evaluation mode.

Parameters
  • checkpoint_path (str, optional): Path to the model checkpoint file. If None, uses cfg.MODEL_PATH. Default is None.
  • num_classes (int, optional): Number of output classes. If None, uses cfg.NUM_CLASSES. Default is None.
Returns
  • None
Notes

Model is automatically set to evaluation mode and moved to the configured device.
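
For example, overriding the configured defaults (the checkpoint path below is hypothetical):

analyzer.load_model(
    checkpoint_path="models/garbage_classifier.ckpt",  # hypothetical path
    num_classes=cfg.NUM_CLASSES,
)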

def setup_data(self, batch_size=32):

Set up the data module and filter metadata for available samples.

Initializes the GarbageDataModule and creates a filtered subset of metadata containing only files present in the dataset directory.

Parameters
  • batch_size (int, optional): Batch size for data loading. Default is 32.
Returns
  • None
Notes

Creates self.df_subset containing only samples actually present in the dataset.
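
For example, assuming metadata.csv was found at construction time:

analyzer.setup_data(batch_size=64)
print(len(analyzer.df_subset))  # metadata rows whose files exist on disk
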

def evaluate_loader(self, loader):

Evaluate model on a data loader and collect predictions and probabilities.

Iterates through a PyTorch DataLoader, performs inference, and collects predictions, true labels, and confidence scores.

Parameters
  • loader (torch.utils.data.DataLoader): DataLoader containing batches of (images, labels).
Returns
  • tuple: A tuple containing:
    • all_preds (torch.Tensor): Predicted class indices.
    • all_labels (torch.Tensor): True class labels.
    • all_probs (np.ndarray): Confidence scores for each class (shape: [N, num_classes]).
Notes

Model inference is performed without gradient computation for efficiency. Probabilities are computed using softmax activation.
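
A sketch of computing accuracy from the returned values, assuming val_loader was obtained from the data module as in the earlier example:

preds, labels, probs = analyzer.evaluate_loader(val_loader)
accuracy = (preds == labels).float().mean().item()  # CPU tensors
print(f"Accuracy: {accuracy:.3f}")
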

def plot_confusion_matrix(self, labels, preds, set_name='Train'):

Plot and save confusion matrices (raw and normalized) with class metrics.

Generates both raw and normalized confusion matrices, computes TP, FP, FN, TN per class, and saves visualizations as PDF files.

Parameters
  • labels (array-like): True class labels.
  • preds (array-like): Predicted class labels.
  • set_name (str, optional): Name of the dataset split (e.g., "Train", "Val", "Test") for plot titles and filenames. Default is "Train".
Returns
  • None
Side Effects
  • Saves confusion_mat_{set_name.lower()}.pdf to performance_path.
  • Saves confusion_mat_{set_name.lower()}_norm.pdf to performance_path.
  • Prints TP, FP, FN, TN statistics for each class.
  • Displays matplotlib figures.
Notes

The normalized confusion matrix divides by row sums to show per-class proportions.
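
For example, on test-set results (continuing the sketch above):

analyzer.plot_confusion_matrix(labels, preds, set_name="Test")
# writes confusion_mat_test.pdf and confusion_mat_test_norm.pdf
# to performance_path
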

def plot_top_misclassified(self, df_set, y_true, y_pred, y_proba, N=10, filename=None):

Visualize the top N misclassified samples with lowest confidence.

Identifies misclassified samples, sorts them by confidence on the true class, and displays the least confident (worst) predictions with their images.

Parameters
  • df_set (pd.DataFrame): DataFrame containing sample metadata with 'label' and 'filename' columns.
  • y_true (array-like): True class labels (can be integers or strings).
  • y_pred (array-like): Predicted class labels (integers).
  • y_proba (np.ndarray): Confidence scores matrix (shape: [N, num_classes]).
  • N (int, optional): Number of top misclassified samples to display. Default is 10.
  • filename (str, optional): Filename (without extension) to save the figure. If provided, saves as PDF to performance_path. Default is None (no save).
Returns
  • None
Side Effects
  • Displays matplotlib figure with misclassified samples.
  • Saves figure to performance_path/{filename}.pdf if filename is provided.
Notes

Samples are sorted by confidence on the true class in ascending order, showing the most uncertain misclassifications first. Images are loaded from dataset_path/{label}/{filename} structure.
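
A sketch of inspecting the ten least confident errors; it assumes the rows of df_subset are ordered consistently with the evaluated loader, which this module does not enforce:

analyzer.plot_top_misclassified(
    analyzer.df_subset,
    labels.numpy(),   # integer labels from evaluate_loader
    preds.numpy(),
    probs,
    N=10,
    filename="top_misclassified_val",  # hypothetical output name
)
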

def plot_calibration_curves(self, y_true, y_probs):

Plot calibration curves for all classes using one-vs-rest approach.

Generates calibration curves showing the relationship between predicted probabilities and actual positive fractions for each class.

Parameters
  • y_true (array-like or torch.Tensor): True class labels (integers). Shape: [N,].
  • y_probs (np.ndarray or torch.Tensor): Predicted probability matrix. Shape: [N, num_classes].
Returns
  • None
Side Effects
  • Displays a matplotlib figure with a grid of calibration curves.
  • One subplot per class showing the calibration curve and the reference diagonal.
Notes

Uses sklearn's calibration_curve function with 10 bins for each class. Each class is converted to a binary classification problem (one-vs-rest). Curves below the diagonal indicate overconfident predictions; curves above the diagonal indicate underconfident predictions.
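
For example, using the outputs of evaluate_loader directly (tensors are converted to NumPy internally):

analyzer.plot_calibration_curves(labels, probs)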