source.utils.custom_classes.EdaAnalyzer

  1import os
  2import zipfile
  3import pandas as pd
  4import numpy as np
  5import seaborn as sns
  6import matplotlib.pyplot as plt
  7from matplotlib.figure import Figure
  8from PIL import Image
  9from typing import Optional
 10import cv2
 11
 12
 13class EdaAnalyzer:
 14    """
 15    A class that encapsulates all Exploratory Data Analysis (EDA) utilities
 16    for the Garbage Classification dataset or similar image datasets.
 17
 18    This class provides methods for dataset management, visualization,
 19    and analysis, including downloading datasets from Kaggle,
 20    loading metadata, plotting class distributions, and computing
 21    prototype mean images.
 22
 23    Attributes
 24    ----------
 25    root_path : str
 26        Path to the raw data folder.
 27    dataset_path : str
 28        Path to the dataset folder.
 29    zip_file : str
 30        Path to the zip file for dataset download.
 31    kaggle_url : str
 32        URL for downloading the Kaggle dataset.
 33    metadata_path : str
 34        Path to the metadata.csv file.
 35    df : pd.DataFrame or None
 36        Metadata DataFrame containing dataset information.
 37    """
 38
 39    def __init__(
 40        self,
 41        root_path: str = "./data/raw",
 42        dataset_name: str = "Garbage_Dataset_Classification",
 43    ):
 44        """
 45        Initialize the EdaAnalyzer instance.
 46
 47        Parameters
 48        ----------
 49        root_path : str, optional
 50            Path to the raw data folder. Default is "./data/raw".
 51        dataset_name : str, optional
 52            Name of the dataset folder. Default is
 53            "Garbage_Dataset_Classification".
 54
 55        Returns
 56        -------
 57        None
 58        """
 59        self.root_path = root_path
 60        self.dataset_path = os.path.join(root_path, dataset_name)
 61        self.zip_file = os.path.join(root_path, "garbage-dataset.zip")
 62        self.kaggle_url = "https://www.kaggle.com/\
 63            api/v1/datasets/download/zlatan599/garbage-dataset-classification"
 64        self.metadata_path = os.path.join(self.dataset_path, "metadata.csv")
 65        self.df = None
 66
 67    # -------------------------------------------------------------------------
 68    # Dataset management
 69    # -------------------------------------------------------------------------
 70    def download_with_curl(self):
 71        """
 72        Download Kaggle dataset using curl and API credentials.
 73
 74        This method downloads the garbage dataset from Kaggle using the
 75        Kaggle API credentials stored in ~/.kaggle/kaggle.json. The
 76        dataset is extracted and the zip file is removed after
 77        extraction.
 78
 79        Parameters
 80        ----------
 81        None
 82
 83        Returns
 84        -------
 85        None
 86
 87        Raises
 88        ------
 89        FileNotFoundError
 90            If Kaggle credentials are not found at ~/.kaggle/kaggle.json.
 91        """
 92        print("Downloading dataset with curl...")
 93
 94        os.makedirs(os.path.expanduser("~/.kaggle"), exist_ok=True)
 95        os.chmod(os.path.expanduser("~/.kaggle"), 0o700)
 96
 97        cmd = f"curl -L -o {self.zip_file} -u \
 98            `jq -r .username ~/.kaggle/kaggle.json`:\
 99                `jq -r .key ~/.kaggle/kaggle.json` {self.kaggle_url}"
100        os.system(cmd)
101
102        print("Extracting dataset...")
103        with zipfile.ZipFile(self.zip_file, "r") as zip_ref:
104            zip_ref.extractall(self.root_path)
105
106        os.remove(self.zip_file)
107        print("Dataset downloaded and extracted successfully.")
108
109    def ensure_dataset(self):
110        """
111        Check if dataset exists; otherwise, download it.
112
113        Verifies the presence of the dataset at the expected path.
114        If not found, triggers the download process.
115        If already present, prints a confirmation message.
116
117        Parameters
118        ----------
119        None
120
121        Returns
122        -------
123        None
124        """
125        if not os.path.exists(self.dataset_path):
126            self.download_with_curl()
127        else:
128            print(f"{self.dataset_path} already exists, nothing to do.")
129
130    def load_metadata(self):
131        """
132        Load metadata.csv into a pandas DataFrame.
133
134        Reads the metadata CSV file from the dataset path and stores it as
135        self.df. Prints summary statistics about the loaded data.
136
137        Parameters
138        ----------
139        None
140
141        Returns
142        -------
143        pd.DataFrame (The loaded metadata DataFrame containing image filenames
144        and labels).
145
146        Raises
147        ------
148        FileNotFoundError
149            If metadata.csv is not found at the expected path.
150        """
151        if not os.path.exists(self.metadata_path):
152            raise FileNotFoundError(f"Metadata file not found at \
153                {self.metadata_path}")
154        self.df = pd.read_csv(self.metadata_path)
155        print(
156            f"Loaded metadata: {len(self.df)} entries, \
157                {self.df['label'].nunique()} classes."
158        )
159        return self.df
160
161    # -------------------------------------------------------------------------
162    # Visualization utilities
163    # -------------------------------------------------------------------------
164    def plot_random_examples_per_class(
165        self,
166        filename: Optional[str] = None
167    ) -> Figure:
168        """
169        Plot a random image from each class and return the figure.
170
171        Selects one random image per class and displays them in a grid layout.
172        Each subplot is bordered with a color corresponding to its class.
173
174        Parameters
175        ----------
176        filename : str, optional
177            Path to save the generated figure as an image file. If provided,
178            the figure is saved with 150 dpi. Default is None (no save).
179
180        Returns
181        -------
182        matplotlib.figure.Figure
183            The generated figure object containing the plotted images.
184        """
185        df = self.df
186        classes = df["label"].unique()
187        palette = sns.color_palette("tab10", len(classes))
188        class_colors = {cls: palette[i] for i, cls in enumerate(classes)}
189
190        cols, rows = 3, (len(classes) + 2) // 3
191        fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4))
192        axes = axes.flatten()
193
194        for i, cls in enumerate(classes):
195            img_filename = df[df["label"] == cls].sample(1).iloc[0]["filename"]
196            img_path = os.path.join(
197                self.dataset_path,
198                "images",
199                cls,
200                img_filename
201            )
202            img = Image.open(img_path)
203
204            ax = axes[i]
205            ax.imshow(img)
206            ax.set_title(cls, fontsize=14, color=class_colors[cls])
207            ax.axis("off")
208            for spine in ax.spines.values():
209                spine.set_edgecolor(class_colors[cls])
210                spine.set_linewidth(4)
211
212        for j in range(i + 1, len(axes)):
213            axes[j].axis("off")
214
215        plt.tight_layout()
216
217        if filename:
218            plt.savefig(filename, dpi=150)
219
220        return fig
221
222    def plot_class_distribution(
223        self,
224        filename: Optional[str] = None
225    ) -> Figure:
226        """
227        Plot class distribution using seaborn and return the figure.
228
229        Creates a countplot showing the number of samples per class, ordered
230        by frequency and using a color palette for visual distinction.
231
232        Parameters
233        ----------
234        filename : str, optional
235            Path to save the generated figure as an image file. If provided,
236            the figure is saved with 150 dpi. Default is None (no save).
237
238        Returns
239        -------
240        matplotlib.figure.Figure
241            The generated figure object containing the class distribution plot.
242        """
243        fig, ax = plt.subplots(figsize=(8, 5))
244        sns.countplot(
245            data=self.df,
246            x="label",
247            order=self.df["label"].value_counts().index,
248            palette="tab10",
249            ax=ax,
250        )
251        ax.set_title("Class Distribution", fontsize=16)
252        ax.set_xlabel("Class")
253        ax.set_ylabel("Count")
254        plt.setp(ax.get_xticklabels(), rotation=45)
255        plt.tight_layout()
256
257        if filename:
258            fig.savefig(filename, dpi=150)
259
260        return fig
261
262    # -------------------------------------------------------------------------
263    # Prototypes
264    # -------------------------------------------------------------------------
265
266    def _compute_mean_images_per_batch(self, batch_size=32):
267        """
268        Compute mean image per class using batch processing.
269
270        Processes images in batches to compute the mean image for each class,
271        reducing memory overhead for large datasets.
272        Images are converted to RGB and normalized to float32.
273
274        Parameters
275        ----------
276        batch_size : int, optional
277            Number of images to process per batch. Default is 32.
278
279        Returns
280        -------
281        dict
282            Dictionary with class names as keys and normalized mean images
283            (values in range [0, 1]) as values.
284
285        Notes
286        -----
287        Images are normalized by dividing by 255.0. Invalid or corrupted images
288        are skipped during processing.
289        """
290
291        classes = self.df["label"].unique()
292        result = {}
293
294        for cls in classes:
295            subset = self.df[self.df["label"] == cls]
296            count = 0
297            mean_acc = None
298
299            for batch_start in range(0, len(subset), batch_size):
300                batch_end = min(batch_start + batch_size, len(subset))
301                batch_rows = subset.iloc[batch_start:batch_end]
302
303                imgs = []
304                for _, row in batch_rows.iterrows():
305                    img_path = os.path.join(
306                        self.dataset_path,
307                        "images",
308                        row["label"],
309                        row["filename"]
310                    )
311                    try:
312                        img = Image.open(img_path).convert("RGB")
313                        imgs.append(np.array(img, dtype=np.float32))
314                    except Exception:
315                        continue
316
317                if imgs:
318                    imgs_stack = np.stack(imgs, axis=0)
319                    batch_mean = np.mean(imgs_stack, axis=0)
320
321                    # Actualizar media acumulada
322                    if mean_acc is None:
323                        mean_acc = batch_mean
324                    else:
325                        aux1 = mean_acc * count
326                        aux2 = batch_mean * len(imgs)
327                        mean_acc = (aux1 + aux2) / (
328                            count + len(imgs)
329                        )
330
331                    count += len(imgs)
332
333            if mean_acc is not None:
334                result[cls] = mean_acc / 255.0
335
336        return result
337
338    def plot_mean_images_per_class(
339        self,
340        filename: Optional[str] = None
341    ) -> Figure:
342        """
343        Compute or load and plot mean images per class, returning the figure.
344
345        Attempts to load pre-computed mean images from a .npy file.
346        If not found, computes them using batch processing and
347        optionally saves the result.
348        Displays all mean images in a grid layout.
349
350        Parameters
351        ----------
352        filename : str, optional
353            Path to the .npy file containing pre-computed mean images,
354            or destination path for saving newly computed mean images.
355            Default is None (no caching).
356
357        Returns
358        -------
359        matplotlib.figure.Figure
360            The generated figure object containing the plotted mean images.
361
362        Notes
363        -----
364        If filename is provided and the file does not exist, computed
365        mean images will be saved to this path for future use.
366        """
367
368        mean_images = None
369
370        if filename and os.path.exists(filename):
371            try:
372                print(f"[INFO] Loading mean images from {filename}")
373                mean_images = np.load(filename, allow_pickle=True).item()
374            except Exception as e:
375                print(f"[WARN] Could not load \
376                    mean images from {filename}: {e}")
377
378        if mean_images is None:
379            print("[INFO] Computing mean images...")
380            mean_images = self._compute_mean_images_per_batch()
381            if filename:
382                np.save(filename, mean_images)
383                print(f"[INFO] Saved mean images to {filename}")
384
385        # --- Plot ---
386        cols, rows = 3, (len(mean_images) + 2) // 3
387        fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4))
388        axes = axes.flatten()
389
390        for i, (cls, img) in enumerate(mean_images.items()):
391            ax = axes[i]
392            ax.imshow(img)
393            ax.set_title(f"Mean {cls}")
394            ax.axis("off")
395
396        for j in range(i + 1, len(axes)):
397            axes[j].axis("off")
398
399        plt.tight_layout()
400
401        return fig
402
403    def plot_mean_images_per_class_with_otsu(
404        self, threshold: float = 0.0, filename: Optional[str] = None
405    ) -> Figure:
406        """
407        Plot mean images per class applying an adjustable Otsu threshold.
408
409        Loads pre-computed mean images and applies a custom thresholding
410        strategy based on Otsu's method with user-defined adjustments.
411        Generates binary masks and overlays them on the original mean
412        images with contour visualization.
413
414        Parameters
415        ----------
416        threshold : float, optional
417            Threshold adjustment parameter. Range: [-1, 1].
418            - -1: Maximum threshold (255, minimal foreground).
419            - 0: Otsu threshold (default).
420            - 1: Minimum threshold (0, maximal foreground).
421            Default is 0.0.
422        filename : str, optional
423            Path to the .npy file containing pre-computed mean images.
424            Must end with ".npy" extension. Default is None.
425
426        Returns
427        -------
428        matplotlib.figure.Figure or None
429            The generated figure object containing the thresholded mean images.
430            Returns None if mean images cannot be loaded or invalid parameters
431            are provided.
432
433        Raises
434        ------
435        None
436
437        Notes
438        -----
439        Red overlays indicate pixels below the threshold (potential
440        foreground objects). Contours are traced around connected
441        components in the binary mask.
442        """
443
444        mean_images = None
445
446        if filename and os.path.exists(filename) and filename.endswith(".npy"):
447            try:
448                print(f"[INFO] Loading mean images from {filename}")
449                mean_images = np.load(filename, allow_pickle=True).item()
450            except Exception as e:
451                print(f"[WARN] Could not load \
452                    mean images from {filename}: {e}")
453                return None
454        else:
455            print("[WARN] No mean images found or invalid file path.")
456            return None
457
458        n_classes = len(mean_images)
459        n_cols = min(3, n_classes)
460        n_rows = int(np.ceil(n_classes / n_cols))
461        fig, axes = plt.subplots(
462            n_rows,
463            n_cols,
464            figsize=(5 * n_cols, 5 * n_rows)
465        )
466        axes = np.array(axes).flatten()
467
468        for i, (cls, mean_image) in enumerate(mean_images.items()):
469            ax = axes[i]
470            gray = cv2.cvtColor(mean_image, cv2.COLOR_RGB2GRAY)
471
472            if gray.dtype != np.uint8:
473                gray = cv2.normalize(
474                    gray, None, 0, 255, cv2.NORM_MINMAX
475                ).astype(
476                    np.uint8
477                )
478
479            otsu_thresh, _ = cv2.threshold(
480                gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
481            )
482
483            adj = np.clip(threshold, -1, 1)
484            if adj == -1:
485                final_thresh = 255
486            elif adj == 1:
487                final_thresh = 0
488            else:
489                if adj < 0:
490                    final_thresh = otsu_thresh + (255 - otsu_thresh) * (-adj)
491                else:
492                    final_thresh = otsu_thresh - (otsu_thresh - 0) * adj
493
494            _, binary = cv2.threshold(
495                gray,
496                final_thresh,
497                255,
498                cv2.THRESH_BINARY
499            )
500
501            mask = (binary == 0).astype(np.uint8)
502            kernel = np.ones((3, 3), np.uint8)
503            mask_dilated = cv2.dilate(mask, kernel, iterations=1)
504            red_overlay = np.zeros((*mask.shape, 4))
505            red_overlay[mask_dilated == 1] = [1, 0, 0, 0.25]
506
507            ax.imshow(mean_image)
508            ax.imshow(red_overlay)
509            ax.set_title(f"{cls}\nOtsu adj={threshold:.2f} \
510                (thr={final_thresh:.1f})")
511            ax.axis("off")
512
513            contours, _ = cv2.findContours(
514                mask_dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
515            )
516            for contour in contours:
517                contour = contour.squeeze()
518                if contour.ndim == 2:
519                    ax.plot(
520                        contour[:, 0],
521                        contour[:, 1],
522                        color="red",
523                        linewidth=2
524                    )
525
526        plt.tight_layout()
527
528        return fig
class EdaAnalyzer:
 14class EdaAnalyzer:
 15    """
 16    A class that encapsulates all Exploratory Data Analysis (EDA) utilities
 17    for the Garbage Classification dataset or similar image datasets.
 18
 19    This class provides methods for dataset management, visualization,
 20    and analysis, including downloading datasets from Kaggle,
 21    loading metadata, plotting class distributions, and computing
 22    prototype mean images.
 23
 24    Attributes
 25    ----------
 26    root_path : str
 27        Path to the raw data folder.
 28    dataset_path : str
 29        Path to the dataset folder.
 30    zip_file : str
 31        Path to the zip file for dataset download.
 32    kaggle_url : str
 33        URL for downloading the Kaggle dataset.
 34    metadata_path : str
 35        Path to the metadata.csv file.
 36    df : pd.DataFrame or None
 37        Metadata DataFrame containing dataset information.
 38    """
 39
 40    def __init__(
 41        self,
 42        root_path: str = "./data/raw",
 43        dataset_name: str = "Garbage_Dataset_Classification",
 44    ):
 45        """
 46        Initialize the EdaAnalyzer instance.
 47
 48        Parameters
 49        ----------
 50        root_path : str, optional
 51            Path to the raw data folder. Default is "./data/raw".
 52        dataset_name : str, optional
 53            Name of the dataset folder. Default is
 54            "Garbage_Dataset_Classification".
 55
 56        Returns
 57        -------
 58        None
 59        """
 60        self.root_path = root_path
 61        self.dataset_path = os.path.join(root_path, dataset_name)
 62        self.zip_file = os.path.join(root_path, "garbage-dataset.zip")
 63        self.kaggle_url = "https://www.kaggle.com/\
 64            api/v1/datasets/download/zlatan599/garbage-dataset-classification"
 65        self.metadata_path = os.path.join(self.dataset_path, "metadata.csv")
 66        self.df = None
 67
 68    # -------------------------------------------------------------------------
 69    # Dataset management
 70    # -------------------------------------------------------------------------
 71    def download_with_curl(self):
 72        """
 73        Download Kaggle dataset using curl and API credentials.
 74
 75        This method downloads the garbage dataset from Kaggle using the
 76        Kaggle API credentials stored in ~/.kaggle/kaggle.json. The
 77        dataset is extracted and the zip file is removed after
 78        extraction.
 79
 80        Parameters
 81        ----------
 82        None
 83
 84        Returns
 85        -------
 86        None
 87
 88        Raises
 89        ------
 90        FileNotFoundError
 91            If Kaggle credentials are not found at ~/.kaggle/kaggle.json.
 92        """
 93        print("Downloading dataset with curl...")
 94
 95        os.makedirs(os.path.expanduser("~/.kaggle"), exist_ok=True)
 96        os.chmod(os.path.expanduser("~/.kaggle"), 0o700)
 97
 98        cmd = f"curl -L -o {self.zip_file} -u \
 99            `jq -r .username ~/.kaggle/kaggle.json`:\
100                `jq -r .key ~/.kaggle/kaggle.json` {self.kaggle_url}"
101        os.system(cmd)
102
103        print("Extracting dataset...")
104        with zipfile.ZipFile(self.zip_file, "r") as zip_ref:
105            zip_ref.extractall(self.root_path)
106
107        os.remove(self.zip_file)
108        print("Dataset downloaded and extracted successfully.")
109
110    def ensure_dataset(self):
111        """
112        Check if dataset exists; otherwise, download it.
113
114        Verifies the presence of the dataset at the expected path.
115        If not found, triggers the download process.
116        If already present, prints a confirmation message.
117
118        Parameters
119        ----------
120        None
121
122        Returns
123        -------
124        None
125        """
126        if not os.path.exists(self.dataset_path):
127            self.download_with_curl()
128        else:
129            print(f"{self.dataset_path} already exists, nothing to do.")
130
131    def load_metadata(self):
132        """
133        Load metadata.csv into a pandas DataFrame.
134
135        Reads the metadata CSV file from the dataset path and stores it as
136        self.df. Prints summary statistics about the loaded data.
137
138        Parameters
139        ----------
140        None
141
142        Returns
143        -------
144        pd.DataFrame (The loaded metadata DataFrame containing image filenames
145        and labels).
146
147        Raises
148        ------
149        FileNotFoundError
150            If metadata.csv is not found at the expected path.
151        """
152        if not os.path.exists(self.metadata_path):
153            raise FileNotFoundError(f"Metadata file not found at \
154                {self.metadata_path}")
155        self.df = pd.read_csv(self.metadata_path)
156        print(
157            f"Loaded metadata: {len(self.df)} entries, \
158                {self.df['label'].nunique()} classes."
159        )
160        return self.df
161
162    # -------------------------------------------------------------------------
163    # Visualization utilities
164    # -------------------------------------------------------------------------
165    def plot_random_examples_per_class(
166        self,
167        filename: Optional[str] = None
168    ) -> Figure:
169        """
170        Plot a random image from each class and return the figure.
171
172        Selects one random image per class and displays them in a grid layout.
173        Each subplot is bordered with a color corresponding to its class.
174
175        Parameters
176        ----------
177        filename : str, optional
178            Path to save the generated figure as an image file. If provided,
179            the figure is saved with 150 dpi. Default is None (no save).
180
181        Returns
182        -------
183        matplotlib.figure.Figure
184            The generated figure object containing the plotted images.
185        """
186        df = self.df
187        classes = df["label"].unique()
188        palette = sns.color_palette("tab10", len(classes))
189        class_colors = {cls: palette[i] for i, cls in enumerate(classes)}
190
191        cols, rows = 3, (len(classes) + 2) // 3
192        fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4))
193        axes = axes.flatten()
194
195        for i, cls in enumerate(classes):
196            img_filename = df[df["label"] == cls].sample(1).iloc[0]["filename"]
197            img_path = os.path.join(
198                self.dataset_path,
199                "images",
200                cls,
201                img_filename
202            )
203            img = Image.open(img_path)
204
205            ax = axes[i]
206            ax.imshow(img)
207            ax.set_title(cls, fontsize=14, color=class_colors[cls])
208            ax.axis("off")
209            for spine in ax.spines.values():
210                spine.set_edgecolor(class_colors[cls])
211                spine.set_linewidth(4)
212
213        for j in range(i + 1, len(axes)):
214            axes[j].axis("off")
215
216        plt.tight_layout()
217
218        if filename:
219            plt.savefig(filename, dpi=150)
220
221        return fig
222
223    def plot_class_distribution(
224        self,
225        filename: Optional[str] = None
226    ) -> Figure:
227        """
228        Plot class distribution using seaborn and return the figure.
229
230        Creates a countplot showing the number of samples per class, ordered
231        by frequency and using a color palette for visual distinction.
232
233        Parameters
234        ----------
235        filename : str, optional
236            Path to save the generated figure as an image file. If provided,
237            the figure is saved with 150 dpi. Default is None (no save).
238
239        Returns
240        -------
241        matplotlib.figure.Figure
242            The generated figure object containing the class distribution plot.
243        """
244        fig, ax = plt.subplots(figsize=(8, 5))
245        sns.countplot(
246            data=self.df,
247            x="label",
248            order=self.df["label"].value_counts().index,
249            palette="tab10",
250            ax=ax,
251        )
252        ax.set_title("Class Distribution", fontsize=16)
253        ax.set_xlabel("Class")
254        ax.set_ylabel("Count")
255        plt.setp(ax.get_xticklabels(), rotation=45)
256        plt.tight_layout()
257
258        if filename:
259            fig.savefig(filename, dpi=150)
260
261        return fig
262
263    # -------------------------------------------------------------------------
264    # Prototypes
265    # -------------------------------------------------------------------------
266
267    def _compute_mean_images_per_batch(self, batch_size=32):
268        """
269        Compute mean image per class using batch processing.
270
271        Processes images in batches to compute the mean image for each class,
272        reducing memory overhead for large datasets.
273        Images are converted to RGB and normalized to float32.
274
275        Parameters
276        ----------
277        batch_size : int, optional
278            Number of images to process per batch. Default is 32.
279
280        Returns
281        -------
282        dict
283            Dictionary with class names as keys and normalized mean images
284            (values in range [0, 1]) as values.
285
286        Notes
287        -----
288        Images are normalized by dividing by 255.0. Invalid or corrupted images
289        are skipped during processing.
290        """
291
292        classes = self.df["label"].unique()
293        result = {}
294
295        for cls in classes:
296            subset = self.df[self.df["label"] == cls]
297            count = 0
298            mean_acc = None
299
300            for batch_start in range(0, len(subset), batch_size):
301                batch_end = min(batch_start + batch_size, len(subset))
302                batch_rows = subset.iloc[batch_start:batch_end]
303
304                imgs = []
305                for _, row in batch_rows.iterrows():
306                    img_path = os.path.join(
307                        self.dataset_path,
308                        "images",
309                        row["label"],
310                        row["filename"]
311                    )
312                    try:
313                        img = Image.open(img_path).convert("RGB")
314                        imgs.append(np.array(img, dtype=np.float32))
315                    except Exception:
316                        continue
317
318                if imgs:
319                    imgs_stack = np.stack(imgs, axis=0)
320                    batch_mean = np.mean(imgs_stack, axis=0)
321
322                    # Actualizar media acumulada
323                    if mean_acc is None:
324                        mean_acc = batch_mean
325                    else:
326                        aux1 = mean_acc * count
327                        aux2 = batch_mean * len(imgs)
328                        mean_acc = (aux1 + aux2) / (
329                            count + len(imgs)
330                        )
331
332                    count += len(imgs)
333
334            if mean_acc is not None:
335                result[cls] = mean_acc / 255.0
336
337        return result
338
def plot_mean_images_per_class(
    self,
    filename: Optional[str] = None
) -> Figure:
    """
    Compute or load and plot mean images per class, returning the figure.

    Attempts to load pre-computed mean images from ``filename``.
    If the file is missing, unreadable, or does not contain a dict,
    the mean images are computed with
    ``_compute_mean_images_per_batch`` and, when ``filename`` is given,
    written to exactly that path. All mean images are then displayed
    in a three-column grid.

    Parameters
    ----------
    filename : str, optional
        Path to the file containing pre-computed mean images,
        or destination path for saving newly computed mean images.
        Default is None (no caching).

    Returns
    -------
    matplotlib.figure.Figure
        The generated figure object containing the plotted mean images.

    Raises
    ------
    ValueError
        If there are no mean images to plot.

    Notes
    -----
    The cache is a pickled dict stored in NumPy format. Loading it
    uses ``allow_pickle=True``, which can execute arbitrary code, so
    only point ``filename`` at files written by this method.
    """
    mean_images = None

    if filename and os.path.exists(filename):
        try:
            print(f"[INFO] Loading mean images from {filename}")
            # NOTE: unpickling is unsafe on untrusted files (see Notes).
            loaded = np.load(filename, allow_pickle=True).item()
            if isinstance(loaded, dict):
                mean_images = loaded
            else:
                print(
                    f"[WARN] Could not load mean images from {filename}: "
                    f"expected a dict, got {type(loaded).__name__}"
                )
        except Exception as e:
            # Best effort: any load failure falls back to recomputing.
            print(
                f"[WARN] Could not load mean images from {filename}: {e}"
            )

    computed = mean_images is None
    if computed:
        print("[INFO] Computing mean images...")
        mean_images = self._compute_mean_images_per_batch()

    if not mean_images:
        # Fail before plt.subplots() is asked for a zero-row grid, and
        # before an empty result is written to the cache.
        raise ValueError("No mean images available to plot.")

    if computed and filename:
        # np.save() appends ".npy" to string paths that lack it, so the
        # existence check above would never find the cache again.
        # Writing through an open handle keeps the name as given.
        with open(filename, "wb") as fh:
            np.save(fh, mean_images)
        print(f"[INFO] Saved mean images to {filename}")

    # --- Plot ---
    cols = 3
    rows = (len(mean_images) + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4))
    axes = np.atleast_1d(axes).ravel()

    for ax, (cls, img) in zip(axes, mean_images.items()):
        ax.imshow(img)
        ax.set_title(f"Mean {cls}")
        ax.axis("off")

    # Blank out the grid cells that have no class to show.
    for ax in axes[len(mean_images):]:
        ax.axis("off")

    fig.tight_layout()

    return fig
403
def plot_mean_images_per_class_with_otsu(
    self, threshold: float = 0.0, filename: Optional[str] = None
) -> Optional[Figure]:
    """
    Plot mean images per class applying an adjustable Otsu threshold.

    Loads pre-computed mean images from ``filename``, derives a
    per-image threshold from Otsu's method shifted by ``threshold``,
    and overlays the resulting mask (dilated with a 3x3 kernel) and
    its outer contours in red on each mean image.

    Parameters
    ----------
    threshold : float, optional
        Threshold adjustment, clipped to [-1, 1]. Default is 0.0.
        - -1: threshold 255; every pixel is at or below it, so the
          whole image is masked.
        - 0: plain Otsu threshold.
        - 1: threshold 0; only pixels with gray value 0 are masked.
        Values in between interpolate linearly between the Otsu
        threshold and the respective extreme.
    filename : str, optional
        Path to the .npy file containing pre-computed mean images.
        Must end with ".npy" and exist. Default is None.

    Returns
    -------
    matplotlib.figure.Figure or None
        The figure with one subplot per class, or None (after printing
        a warning) when the file is missing, does not end with ".npy",
        cannot be loaded, or does not hold a non-empty dict.

    Notes
    -----
    The red overlay marks pixels whose gray value is less than or
    equal to the final threshold. Gray images that are not uint8 are
    min-max normalized to 0..255 before thresholding. The file is
    loaded with ``allow_pickle=True``, which can execute arbitrary
    code; only load files produced by a trusted source.
    """
    if not (
        filename
        and os.path.exists(filename)
        and filename.endswith(".npy")
    ):
        print("[WARN] No mean images found or invalid file path.")
        return None

    try:
        print(f"[INFO] Loading mean images from {filename}")
        # NOTE: unpickling is unsafe on untrusted files (see Notes).
        mean_images = np.load(filename, allow_pickle=True).item()
    except Exception as e:
        print(f"[WARN] Could not load mean images from {filename}: {e}")
        return None

    if not isinstance(mean_images, dict) or not mean_images:
        # An empty dict would make n_cols zero and divide by zero below.
        print("[WARN] No mean images found or invalid file path.")
        return None

    # Loop-invariant values.
    adj = float(np.clip(threshold, -1, 1))
    kernel = np.ones((3, 3), np.uint8)

    n_classes = len(mean_images)
    n_cols = min(3, n_classes)
    n_rows = int(np.ceil(n_classes / n_cols))
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(5 * n_cols, 5 * n_rows)
    )
    axes = np.array(axes).flatten()

    for ax, (cls, mean_image) in zip(axes, mean_images.items()):
        mean_image = np.asarray(mean_image)
        # cv2.cvtColor does not accept float64 input, which is what a
        # plain NumPy division produces; down-cast only in that case.
        if mean_image.dtype == np.float64:
            color_src = mean_image.astype(np.float32)
        else:
            color_src = mean_image
        gray = cv2.cvtColor(color_src, cv2.COLOR_RGB2GRAY)

        if gray.dtype != np.uint8:
            gray = cv2.normalize(
                gray, None, 0, 255, cv2.NORM_MINMAX
            ).astype(np.uint8)

        otsu_thresh, _ = cv2.threshold(
            gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
        )

        # Linear interpolation: adj = -1 gives 255, adj = 1 gives 0.
        if adj < 0:
            final_thresh = otsu_thresh + (255 - otsu_thresh) * (-adj)
        else:
            final_thresh = otsu_thresh - otsu_thresh * adj

        _, binary = cv2.threshold(
            gray,
            final_thresh,
            255,
            cv2.THRESH_BINARY
        )

        # THRESH_BINARY zeroes pixels <= threshold; those form the mask.
        mask = (binary == 0).astype(np.uint8)
        mask_dilated = cv2.dilate(mask, kernel, iterations=1)
        red_overlay = np.zeros((*mask.shape, 4))
        red_overlay[mask_dilated == 1] = [1, 0, 0, 0.25]

        ax.imshow(mean_image)
        ax.imshow(red_overlay)
        # Show the clipped adjustment, i.e. the value actually applied.
        ax.set_title(
            f"{cls}\nOtsu adj={adj:.2f} (thr={final_thresh:.1f})"
        )
        ax.axis("off")

        contours, _ = cv2.findContours(
            mask_dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        for contour in contours:
            contour = contour.squeeze()
            # Single-point contours squeeze to 1-D and are skipped.
            if contour.ndim == 2:
                ax.plot(
                    contour[:, 0],
                    contour[:, 1],
                    color="red",
                    linewidth=2
                )

    # Blank out the grid cells that have no class to show.
    for ax in axes[n_classes:]:
        ax.axis("off")

    fig.tight_layout()

    return fig

A class that encapsulates all Exploratory Data Analysis (EDA) utilities for the Garbage Classification dataset or similar image datasets.

This class provides methods for dataset management, visualization, and analysis, including downloading datasets from Kaggle, loading metadata, plotting class distributions, and computing prototype mean images.

Attributes
  • root_path (str): Path to the raw data folder.
  • dataset_path (str): Path to the dataset folder.
  • zip_file (str): Path to the zip file for dataset download.
  • kaggle_url (str): URL for downloading the Kaggle dataset.
  • metadata_path (str): Path to the metadata.csv file.
  • df (pd.DataFrame or None): Metadata DataFrame containing dataset information.
EdaAnalyzer( root_path: str = './data/raw', dataset_name: str = 'Garbage_Dataset_Classification')
def __init__(
    self,
    root_path: str = "./data/raw",
    dataset_name: str = "Garbage_Dataset_Classification",
) -> None:
    """
    Initialize the EdaAnalyzer instance.

    Only derives paths and stores them; nothing is read from or
    written to disk here.

    Parameters
    ----------
    root_path : str, optional
        Path to the raw data folder. Default is "./data/raw".
    dataset_name : str, optional
        Name of the dataset folder. Default is
        "Garbage_Dataset_Classification".

    Returns
    -------
    None
    """
    self.root_path = root_path
    self.dataset_path = os.path.join(root_path, dataset_name)
    self.zip_file = os.path.join(root_path, "garbage-dataset.zip")
    # Adjacent string literals are joined at compile time. A backslash
    # continuation inside the quotes would embed the next line's
    # indentation in the URL.
    self.kaggle_url = (
        "https://www.kaggle.com/api/v1/datasets/download/"
        "zlatan599/garbage-dataset-classification"
    )
    self.metadata_path = os.path.join(self.dataset_path, "metadata.csv")
    # Filled in by load_metadata().
    self.df = None

Initialize the EdaAnalyzer instance.

Parameters
  • root_path (str, optional): Path to the raw data folder. Default is "./data/raw".
  • dataset_name (str, optional): Name of the dataset folder. Default is "Garbage_Dataset_Classification".
Returns
  • None
root_path
dataset_path
zip_file
kaggle_url
metadata_path
df
def download_with_curl(self):
 def download_with_curl(self):
     """
     Download Kaggle dataset using curl and API credentials.

     Reads the username and key from ~/.kaggle/kaggle.json, downloads
     ``self.kaggle_url`` to ``self.zip_file`` with curl, extracts the
     archive into ``self.root_path`` and removes the zip file after
     extraction.

     Parameters
     ----------
     None

     Returns
     -------
     None

     Raises
     ------
     FileNotFoundError
         If Kaggle credentials are not found at ~/.kaggle/kaggle.json,
         or if the curl executable is not available.
     KeyError
         If kaggle.json lacks the "username" or "key" entry.
     subprocess.CalledProcessError
         If curl exits with a non-zero status (including HTTP errors,
         because of ``--fail``).
     """
     import json
     import subprocess

     print("Downloading dataset with curl...")

     kaggle_dir = os.path.expanduser("~/.kaggle")
     os.makedirs(kaggle_dir, exist_ok=True)
     os.chmod(kaggle_dir, 0o700)

     credentials_path = os.path.join(kaggle_dir, "kaggle.json")
     if not os.path.isfile(credentials_path):
         raise FileNotFoundError(
             f"Kaggle credentials not found at {credentials_path}"
         )
     with open(credentials_path, encoding="utf-8") as fh:
         credentials = json.load(fh)
     username = credentials["username"]
     key = credentials["key"]

     # curl does not create the directory of its output file.
     os.makedirs(os.path.dirname(self.zip_file) or ".", exist_ok=True)

     # An argument list (no shell) keeps paths with spaces intact and
     # passes "user:key" as one token. The previous shell string had
     # continuation whitespace after the colon, which separated the key
     # from the username.
     subprocess.run(
         [
             "curl",
             "--fail",
             "-L",
             "-o",
             self.zip_file,
             "-u",
             f"{username}:{key}",
             self.kaggle_url,
         ],
         check=True,
     )

     print("Extracting dataset...")
     with zipfile.ZipFile(self.zip_file, "r") as zip_ref:
         zip_ref.extractall(self.root_path)

     os.remove(self.zip_file)
     print("Dataset downloaded and extracted successfully.")

Download Kaggle dataset using curl and API credentials.

This method downloads the garbage dataset from Kaggle using the Kaggle API credentials stored in ~/.kaggle/kaggle.json. The dataset is extracted and the zip file is removed after extraction.

Parameters
  • None
Returns
  • None
Raises
  • FileNotFoundError: If Kaggle credentials are not found at ~/.kaggle/kaggle.json.
def ensure_dataset(self):
def ensure_dataset(self):
    """
    Make sure the dataset folder is present, downloading it if needed.

    When ``self.dataset_path`` already exists a confirmation message
    is printed and nothing else happens; otherwise
    ``download_with_curl`` is invoked.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    if os.path.exists(self.dataset_path):
        # Already on disk: report and leave.
        print(f"{self.dataset_path} already exists, nothing to do.")
        return

    self.download_with_curl()

Check if dataset exists; otherwise, download it.

Verifies the presence of the dataset at the expected path. If not found, triggers the download process. If already present, prints a confirmation message.

Parameters
  • None
Returns
  • None
def load_metadata(self):
def load_metadata(self) -> pd.DataFrame:
    """
    Load metadata.csv into a pandas DataFrame.

    Reads the CSV file at ``self.metadata_path``, stores it as
    ``self.df`` and prints the number of rows and of distinct values
    in the "label" column.

    Parameters
    ----------
    None

    Returns
    -------
    pd.DataFrame
        The loaded metadata DataFrame (the same object as ``self.df``).

    Raises
    ------
    FileNotFoundError
        If metadata.csv is not found at the expected path.
    KeyError
        If the CSV has no "label" column.
    """
    if not os.path.exists(self.metadata_path):
        raise FileNotFoundError(
            f"Metadata file not found at {self.metadata_path}"
        )
    self.df = pd.read_csv(self.metadata_path)
    # Adjacent f-strings avoid the indentation that a backslash
    # continuation inside the literal would put into the message.
    print(
        f"Loaded metadata: {len(self.df)} entries, "
        f"{self.df['label'].nunique()} classes."
    )
    return self.df

Load metadata.csv into a pandas DataFrame.

Reads the metadata CSV file from the dataset path and stores it as self.df. Prints summary statistics about the loaded data.

Parameters
  • None
Returns
  • pd.DataFrame: The loaded metadata DataFrame containing image filenames and labels.
Raises
  • FileNotFoundError: If metadata.csv is not found at the expected path.
def plot_random_examples_per_class(self, filename: Optional[str] = None) -> matplotlib.figure.Figure:
def plot_random_examples_per_class(
    self,
    filename: Optional[str] = None
) -> Figure:
    """
    Plot a random image from each class and return the figure.

    Picks one random row per distinct "label" in ``self.df``, opens
    the image at ``<dataset_path>/images/<label>/<filename>`` and
    shows the images in a three-column grid. Each subplot gets a
    title and a border in its class colour.

    Parameters
    ----------
    filename : str, optional
        Path to save the generated figure as an image file. If provided,
        the figure is saved with 150 dpi. Default is None (no save).

    Returns
    -------
    matplotlib.figure.Figure
        The generated figure object containing the plotted images.

    Raises
    ------
    ValueError
        If the metadata has not been loaded or contains no classes.
    """
    df = self.df
    if df is None:
        raise ValueError(
            "Metadata is not loaded; call load_metadata() first."
        )

    classes = df["label"].unique()
    if len(classes) == 0:
        raise ValueError("Metadata contains no classes to plot.")

    palette = sns.color_palette("tab10", len(classes))
    class_colors = dict(zip(classes, palette))

    cols = 3
    rows = (len(classes) + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4))
    axes = axes.flatten()

    for ax, cls in zip(axes, classes):
        img_filename = df[df["label"] == cls].sample(1).iloc[0]["filename"]
        img_path = os.path.join(
            self.dataset_path,
            "images",
            cls,
            img_filename
        )
        # imshow() copies the pixel data, so the file handle can be
        # released as soon as the call returns.
        with Image.open(img_path) as img:
            ax.imshow(img)

        color = class_colors[cls]
        ax.set_title(cls, fontsize=14, color=color)
        # Hide only the ticks: ax.axis("off") would also hide the
        # spines and the coloured border would never be drawn.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_edgecolor(color)
            spine.set_linewidth(4)

    # Blank out the grid cells that have no class to show.
    for ax in axes[len(classes):]:
        ax.axis("off")

    fig.tight_layout()

    if filename:
        fig.savefig(filename, dpi=150)

    return fig

Plot a random image from each class and return the figure.

Selects one random image per class and displays them in a grid layout. Each subplot is bordered with a color corresponding to its class.

Parameters
  • filename (str, optional): Path to save the generated figure as an image file. If provided, the figure is saved with 150 dpi. Default is None (no save).
Returns
  • matplotlib.figure.Figure: The generated figure object containing the plotted images.
def plot_class_distribution(self, filename: Optional[str] = None) -> matplotlib.figure.Figure:
def plot_class_distribution(
    self,
    filename: Optional[str] = None
) -> Figure:
    """
    Plot class distribution using seaborn and return the figure.

    Creates a countplot of the "label" column of ``self.df``, with
    bars ordered by descending frequency and coloured from the
    "tab10" palette.

    Parameters
    ----------
    filename : str, optional
        Path to save the generated figure as an image file. If provided,
        the figure is saved with 150 dpi. Default is None (no save).

    Returns
    -------
    matplotlib.figure.Figure
        The generated figure object containing the class distribution plot.

    Raises
    ------
    ValueError
        If the metadata has not been loaded.
    """
    if self.df is None:
        raise ValueError(
            "Metadata is not loaded; call load_metadata() first."
        )

    fig, ax = plt.subplots(figsize=(8, 5))
    # NOTE(review): recent seaborn releases warn when `palette` is
    # given without `hue`; confirm the pinned seaborn version before
    # switching to hue="label".
    sns.countplot(
        data=self.df,
        x="label",
        order=self.df["label"].value_counts().index,
        palette="tab10",
        ax=ax,
    )
    ax.set_title("Class Distribution", fontsize=16)
    ax.set_xlabel("Class")
    ax.set_ylabel("Count")
    plt.setp(ax.get_xticklabels(), rotation=45)
    # Lay out this figure explicitly rather than the current one.
    fig.tight_layout()

    if filename:
        fig.savefig(filename, dpi=150)

    return fig

Plot class distribution using seaborn and return the figure.

Creates a countplot showing the number of samples per class, ordered by frequency and using a color palette for visual distinction.

Parameters
  • filename (str, optional): Path to save the generated figure as an image file. If provided, the figure is saved with 150 dpi. Default is None (no save).
Returns
  • matplotlib.figure.Figure: The generated figure object containing the class distribution plot.
def plot_mean_images_per_class(self, filename: Optional[str] = None) -> matplotlib.figure.Figure:
def plot_mean_images_per_class(
    self,
    filename: Optional[str] = None
) -> Figure:
    """
    Compute or load and plot mean images per class, returning the figure.

    Attempts to load pre-computed mean images from ``filename``.
    If the file is missing, unreadable, or does not contain a dict,
    the mean images are computed with
    ``_compute_mean_images_per_batch`` and, when ``filename`` is given,
    written to exactly that path. All mean images are then displayed
    in a three-column grid.

    Parameters
    ----------
    filename : str, optional
        Path to the file containing pre-computed mean images,
        or destination path for saving newly computed mean images.
        Default is None (no caching).

    Returns
    -------
    matplotlib.figure.Figure
        The generated figure object containing the plotted mean images.

    Raises
    ------
    ValueError
        If there are no mean images to plot.

    Notes
    -----
    The cache is a pickled dict stored in NumPy format. Loading it
    uses ``allow_pickle=True``, which can execute arbitrary code, so
    only point ``filename`` at files written by this method.
    """
    mean_images = None

    if filename and os.path.exists(filename):
        try:
            print(f"[INFO] Loading mean images from {filename}")
            # NOTE: unpickling is unsafe on untrusted files (see Notes).
            loaded = np.load(filename, allow_pickle=True).item()
            if isinstance(loaded, dict):
                mean_images = loaded
            else:
                print(
                    f"[WARN] Could not load mean images from {filename}: "
                    f"expected a dict, got {type(loaded).__name__}"
                )
        except Exception as e:
            # Best effort: any load failure falls back to recomputing.
            print(
                f"[WARN] Could not load mean images from {filename}: {e}"
            )

    computed = mean_images is None
    if computed:
        print("[INFO] Computing mean images...")
        mean_images = self._compute_mean_images_per_batch()

    if not mean_images:
        # Fail before plt.subplots() is asked for a zero-row grid, and
        # before an empty result is written to the cache.
        raise ValueError("No mean images available to plot.")

    if computed and filename:
        # np.save() appends ".npy" to string paths that lack it, so the
        # existence check above would never find the cache again.
        # Writing through an open handle keeps the name as given.
        with open(filename, "wb") as fh:
            np.save(fh, mean_images)
        print(f"[INFO] Saved mean images to {filename}")

    # --- Plot ---
    cols = 3
    rows = (len(mean_images) + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4))
    axes = np.atleast_1d(axes).ravel()

    for ax, (cls, img) in zip(axes, mean_images.items()):
        ax.imshow(img)
        ax.set_title(f"Mean {cls}")
        ax.axis("off")

    # Blank out the grid cells that have no class to show.
    for ax in axes[len(mean_images):]:
        ax.axis("off")

    fig.tight_layout()

    return fig

Compute or load and plot mean images per class, returning the figure.

Attempts to load pre-computed mean images from a .npy file. If not found, computes them using batch processing and optionally saves the result. Displays all mean images in a grid layout.

Parameters
  • filename (str, optional): Path to the .npy file containing pre-computed mean images, or destination path for saving newly computed mean images. Default is None (no caching).
Returns
  • matplotlib.figure.Figure: The generated figure object containing the plotted mean images.
Notes

If filename is provided and the file does not exist, computed mean images will be saved to this path for future use.

def plot_mean_images_per_class_with_otsu( self, threshold: float = 0.0, filename: Optional[str] = None) -> matplotlib.figure.Figure:
def plot_mean_images_per_class_with_otsu(
    self, threshold: float = 0.0, filename: Optional[str] = None
) -> Optional[Figure]:
    """
    Plot mean images per class applying an adjustable Otsu threshold.

    Loads pre-computed mean images from ``filename``, derives a
    per-image threshold from Otsu's method shifted by ``threshold``,
    and overlays the resulting mask (dilated with a 3x3 kernel) and
    its outer contours in red on each mean image.

    Parameters
    ----------
    threshold : float, optional
        Threshold adjustment, clipped to [-1, 1]. Default is 0.0.
        - -1: threshold 255; every pixel is at or below it, so the
          whole image is masked.
        - 0: plain Otsu threshold.
        - 1: threshold 0; only pixels with gray value 0 are masked.
        Values in between interpolate linearly between the Otsu
        threshold and the respective extreme.
    filename : str, optional
        Path to the .npy file containing pre-computed mean images.
        Must end with ".npy" and exist. Default is None.

    Returns
    -------
    matplotlib.figure.Figure or None
        The figure with one subplot per class, or None (after printing
        a warning) when the file is missing, does not end with ".npy",
        cannot be loaded, or does not hold a non-empty dict.

    Notes
    -----
    The red overlay marks pixels whose gray value is less than or
    equal to the final threshold. Gray images that are not uint8 are
    min-max normalized to 0..255 before thresholding. The file is
    loaded with ``allow_pickle=True``, which can execute arbitrary
    code; only load files produced by a trusted source.
    """
    if not (
        filename
        and os.path.exists(filename)
        and filename.endswith(".npy")
    ):
        print("[WARN] No mean images found or invalid file path.")
        return None

    try:
        print(f"[INFO] Loading mean images from {filename}")
        # NOTE: unpickling is unsafe on untrusted files (see Notes).
        mean_images = np.load(filename, allow_pickle=True).item()
    except Exception as e:
        print(f"[WARN] Could not load mean images from {filename}: {e}")
        return None

    if not isinstance(mean_images, dict) or not mean_images:
        # An empty dict would make n_cols zero and divide by zero below.
        print("[WARN] No mean images found or invalid file path.")
        return None

    # Loop-invariant values.
    adj = float(np.clip(threshold, -1, 1))
    kernel = np.ones((3, 3), np.uint8)

    n_classes = len(mean_images)
    n_cols = min(3, n_classes)
    n_rows = int(np.ceil(n_classes / n_cols))
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(5 * n_cols, 5 * n_rows)
    )
    axes = np.array(axes).flatten()

    for ax, (cls, mean_image) in zip(axes, mean_images.items()):
        mean_image = np.asarray(mean_image)
        # cv2.cvtColor does not accept float64 input, which is what a
        # plain NumPy division produces; down-cast only in that case.
        if mean_image.dtype == np.float64:
            color_src = mean_image.astype(np.float32)
        else:
            color_src = mean_image
        gray = cv2.cvtColor(color_src, cv2.COLOR_RGB2GRAY)

        if gray.dtype != np.uint8:
            gray = cv2.normalize(
                gray, None, 0, 255, cv2.NORM_MINMAX
            ).astype(np.uint8)

        otsu_thresh, _ = cv2.threshold(
            gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
        )

        # Linear interpolation: adj = -1 gives 255, adj = 1 gives 0.
        if adj < 0:
            final_thresh = otsu_thresh + (255 - otsu_thresh) * (-adj)
        else:
            final_thresh = otsu_thresh - otsu_thresh * adj

        _, binary = cv2.threshold(
            gray,
            final_thresh,
            255,
            cv2.THRESH_BINARY
        )

        # THRESH_BINARY zeroes pixels <= threshold; those form the mask.
        mask = (binary == 0).astype(np.uint8)
        mask_dilated = cv2.dilate(mask, kernel, iterations=1)
        red_overlay = np.zeros((*mask.shape, 4))
        red_overlay[mask_dilated == 1] = [1, 0, 0, 0.25]

        ax.imshow(mean_image)
        ax.imshow(red_overlay)
        # Show the clipped adjustment, i.e. the value actually applied.
        ax.set_title(
            f"{cls}\nOtsu adj={adj:.2f} (thr={final_thresh:.1f})"
        )
        ax.axis("off")

        contours, _ = cv2.findContours(
            mask_dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        for contour in contours:
            contour = contour.squeeze()
            # Single-point contours squeeze to 1-D and are skipped.
            if contour.ndim == 2:
                ax.plot(
                    contour[:, 0],
                    contour[:, 1],
                    color="red",
                    linewidth=2
                )

    # Blank out the grid cells that have no class to show.
    for ax in axes[n_classes:]:
        ax.axis("off")

    fig.tight_layout()

    return fig

Plot mean images per class applying an adjustable Otsu threshold.

Loads pre-computed mean images and applies a custom thresholding strategy based on Otsu's method with user-defined adjustments. Generates binary masks and overlays them on the original mean images with contour visualization.

Parameters
  • threshold (float, optional): Threshold adjustment parameter. Range: [-1, 1].
    • -1: Maximum threshold (255); every pixel is at or below it, so the red mask covers the whole image.
    • 0: Otsu threshold (default).
    • 1: Minimum threshold (0); only pixels whose gray value is 0 are masked. Default is 0.0.
  • filename (str, optional): Path to the .npy file containing pre-computed mean images. Must end with ".npy" extension. Default is None.
Returns
  • matplotlib.figure.Figure or None: The generated figure object containing the thresholded mean images. Returns None if mean images cannot be loaded or invalid parameters are provided.
Raises
  • None
Notes

Red overlays indicate pixels below the threshold (potential foreground objects). Contours are traced around connected components in the binary mask.