source.utils.custom_classes.EdaAnalyzer
import os
import zipfile
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from PIL import Image
from typing import Optional
import cv2


class EdaAnalyzer:
    """
    A class that encapsulates all Exploratory Data Analysis (EDA) utilities
    for the Garbage Classification dataset or similar image datasets.

    This class provides methods for dataset management, visualization,
    and analysis, including downloading datasets from Kaggle,
    loading metadata, plotting class distributions, and computing
    prototype mean images.

    Images are read from ``<dataset_path>/images/<label>/<filename>``, where
    ``label`` and ``filename`` are columns of ``metadata.csv``.

    Attributes
    ----------
    root_path : str
        Path to the raw data folder.
    dataset_path : str
        Path to the dataset folder.
    zip_file : str
        Path to the zip file for dataset download.
    kaggle_url : str
        URL for downloading the Kaggle dataset.
    metadata_path : str
        Path to the metadata.csv file.
    df : pd.DataFrame or None
        Metadata DataFrame containing dataset information.
    """

    def __init__(
        self,
        root_path: str = "./data/raw",
        dataset_name: str = "Garbage_Dataset_Classification",
    ):
        """
        Initialize the EdaAnalyzer instance.

        Parameters
        ----------
        root_path : str, optional
            Path to the raw data folder. Default is "./data/raw".
        dataset_name : str, optional
            Name of the dataset folder inside ``root_path``. Default is
            "Garbage_Dataset_Classification".

        Returns
        -------
        None
        """
        self.root_path = root_path
        self.dataset_path = os.path.join(root_path, dataset_name)
        self.zip_file = os.path.join(root_path, "garbage-dataset.zip")
        # Implicit string concatenation: a backslash continuation inside the
        # literal would splice the next line's indentation into the URL.
        self.kaggle_url = (
            "https://www.kaggle.com/api/v1/datasets/download/"
            "zlatan599/garbage-dataset-classification"
        )
        self.metadata_path = os.path.join(self.dataset_path, "metadata.csv")
        # Populated by load_metadata(); the plotting methods require it.
        self.df: Optional[pd.DataFrame] = None

    # -------------------------------------------------------------------------
    # Internal helpers
    # -------------------------------------------------------------------------
    def _require_metadata(self) -> pd.DataFrame:
        """
        Return the loaded metadata, failing with a clear message if absent.

        Raises
        ------
        RuntimeError
            If ``load_metadata()`` has not been called yet.
        """
        if self.df is None:
            raise RuntimeError(
                "Metadata not loaded; call load_metadata() first."
            )
        return self.df

    @staticmethod
    def _npy_path(filename: str) -> str:
        """
        Return ``filename`` with the ``.npy`` suffix that ``np.save`` adds.

        ``np.save`` silently appends ``.npy`` when it is missing, so cache
        lookups must use the same normalized name or they never hit.
        """
        return filename if filename.endswith(".npy") else filename + ".npy"

    # -------------------------------------------------------------------------
    # Dataset management
    # -------------------------------------------------------------------------
    def download_with_curl(self):
        """
        Download Kaggle dataset using curl and API credentials.

        This method downloads the garbage dataset from Kaggle using the
        Kaggle API credentials stored in ~/.kaggle/kaggle.json. The
        dataset is extracted into ``root_path`` and the zip file is removed
        after extraction.

        Parameters
        ----------
        None

        Returns
        -------
        None

        Raises
        ------
        FileNotFoundError
            If Kaggle credentials are not found at ~/.kaggle/kaggle.json.
        KeyError
            If kaggle.json lacks the "username" or "key" field.
        subprocess.CalledProcessError
            If curl exits with a non-zero status (including HTTP errors).
        """
        import json
        import subprocess

        print("Downloading dataset with curl...")

        kaggle_dir = os.path.expanduser("~/.kaggle")
        os.makedirs(kaggle_dir, exist_ok=True)
        os.chmod(kaggle_dir, 0o700)

        credentials_path = os.path.join(kaggle_dir, "kaggle.json")
        if not os.path.isfile(credentials_path):
            raise FileNotFoundError(
                f"Kaggle credentials not found at {credentials_path}"
            )
        with open(credentials_path, encoding="utf-8") as fh:
            credentials = json.load(fh)

        # curl -o cannot create missing parent directories.
        os.makedirs(self.root_path, exist_ok=True)

        # Argument list instead of a shell string: paths containing spaces or
        # shell metacharacters are passed verbatim. --fail turns HTTP errors
        # (401/404, ...) into a non-zero exit instead of saving the error
        # page as the "zip".
        subprocess.run(
            [
                "curl",
                "-L",
                "--fail",
                "-o",
                self.zip_file,
                "-u",
                f"{credentials['username']}:{credentials['key']}",
                self.kaggle_url,
            ],
            check=True,
        )

        print("Extracting dataset...")
        with zipfile.ZipFile(self.zip_file, "r") as zip_ref:
            zip_ref.extractall(self.root_path)

        os.remove(self.zip_file)
        print("Dataset downloaded and extracted successfully.")

    def ensure_dataset(self):
        """
        Check if dataset exists; otherwise, download it.

        If ``dataset_path`` is missing, ``download_with_curl()`` is called;
        otherwise a confirmation message is printed.

        Parameters
        ----------
        None

        Returns
        -------
        None
        """
        if os.path.exists(self.dataset_path):
            print(f"{self.dataset_path} already exists, nothing to do.")
        else:
            self.download_with_curl()

    def load_metadata(self):
        """
        Load metadata.csv into a pandas DataFrame.

        Reads the metadata CSV file from the dataset path, stores it as
        ``self.df`` and prints the number of entries and distinct labels.

        Parameters
        ----------
        None

        Returns
        -------
        pd.DataFrame
            The loaded metadata DataFrame containing image filenames and
            labels.

        Raises
        ------
        FileNotFoundError
            If metadata.csv is not found at the expected path.
        """
        if not os.path.exists(self.metadata_path):
            raise FileNotFoundError(
                f"Metadata file not found at {self.metadata_path}"
            )
        self.df = pd.read_csv(self.metadata_path)
        print(
            f"Loaded metadata: {len(self.df)} entries, "
            f"{self.df['label'].nunique()} classes."
        )
        return self.df

    # -------------------------------------------------------------------------
    # Visualization utilities
    # -------------------------------------------------------------------------
    def plot_random_examples_per_class(
        self,
        filename: Optional[str] = None,
        random_state: Optional[int] = None,
    ) -> Figure:
        """
        Plot a random image from each class and return the figure.

        Selects one random image per class and displays them in a grid with
        three columns. Each subplot is framed with a border whose color
        matches its class title.

        Parameters
        ----------
        filename : str, optional
            Path to save the generated figure as an image file. If provided,
            the figure is saved with 150 dpi. Default is None (no save).
        random_state : int, optional
            Seed forwarded to ``DataFrame.sample`` for a reproducible
            selection. Default is None (new random images on every call).

        Returns
        -------
        matplotlib.figure.Figure
            The generated figure object containing the plotted images.

        Raises
        ------
        RuntimeError
            If the metadata has not been loaded.
        ValueError
            If the metadata contains no labels.
        """
        df = self._require_metadata()
        classes = df["label"].unique()
        if len(classes) == 0:
            raise ValueError("Metadata contains no labelled images to plot.")

        palette = sns.color_palette("tab10", len(classes))
        class_colors = dict(zip(classes, palette))

        cols = 3
        rows = (len(classes) + cols - 1) // cols
        fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4))
        axes = np.asarray(axes).ravel()

        for ax, cls in zip(axes, classes):
            sample = df[df["label"] == cls].sample(
                1, random_state=random_state
            )
            img_path = os.path.join(
                self.dataset_path, "images", cls, sample.iloc[0]["filename"]
            )
            # imshow converts the PIL image to an array right away, so the
            # file handle can be closed as soon as it returns.
            with Image.open(img_path) as img:
                ax.imshow(img)
            ax.set_title(cls, fontsize=14, color=class_colors[cls])
            # Hide ticks only: ax.axis("off") would also hide the spines and
            # make the colored class border invisible.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_edgecolor(class_colors[cls])
                spine.set_linewidth(4)

        # Blank out grid cells that received no class.
        for ax in axes[len(classes):]:
            ax.axis("off")

        fig.tight_layout()

        if filename:
            fig.savefig(filename, dpi=150)

        return fig

    def plot_class_distribution(
        self,
        filename: Optional[str] = None
    ) -> Figure:
        """
        Plot class distribution using seaborn and return the figure.

        Creates a countplot showing the number of samples per class, ordered
        by decreasing frequency and colored with the "tab10" palette.

        Parameters
        ----------
        filename : str, optional
            Path to save the generated figure as an image file. If provided,
            the figure is saved with 150 dpi. Default is None (no save).

        Returns
        -------
        matplotlib.figure.Figure
            The generated figure object containing the class distribution plot.

        Raises
        ------
        RuntimeError
            If the metadata has not been loaded.
        """
        df = self._require_metadata()
        fig, ax = plt.subplots(figsize=(8, 5))
        sns.countplot(
            data=df,
            x="label",
            order=df["label"].value_counts().index,
            palette="tab10",
            ax=ax,
        )
        ax.set_title("Class Distribution", fontsize=16)
        ax.set_xlabel("Class")
        ax.set_ylabel("Count")
        # Right-align rotated labels so each one ends under its own bar.
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
        fig.tight_layout()

        if filename:
            fig.savefig(filename, dpi=150)

        return fig

    # -------------------------------------------------------------------------
    # Prototypes
    # -------------------------------------------------------------------------

    def _compute_mean_images_per_batch(self, batch_size=32):
        """
        Compute mean image per class using batch processing.

        Images of each class are loaded ``batch_size`` at a time, converted
        to RGB and added to a float64 per-pixel sum. Only one batch of
        decoded images is held in memory at a time.

        Parameters
        ----------
        batch_size : int, optional
            Number of images to process per batch. Default is 32.

        Returns
        -------
        dict
            Dictionary with class names as keys and float32 mean images
            (H, W, 3) with values in [0, 1] as values. Classes without any
            readable image are omitted.

        Raises
        ------
        RuntimeError
            If the metadata has not been loaded.
        ValueError
            If ``batch_size`` is smaller than 1.

        Notes
        -----
        Images whose size differs from the first readable image of the same
        class are resized to that size, so mixed resolutions can be
        averaged. Unreadable or corrupted images are skipped, and the
        number skipped per class is reported.
        """
        df = self._require_metadata()
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")

        result = {}
        for cls in df["label"].unique():
            filenames = df.loc[df["label"] == cls, "filename"]
            pixel_sum = None  # float64 running sum over loaded images
            target_size = None  # PIL (width, height) of first readable image
            count = 0
            skipped = 0

            for start in range(0, len(filenames), batch_size):
                batch = []
                for img_name in filenames.iloc[start:start + batch_size]:
                    img_path = os.path.join(
                        self.dataset_path, "images", cls, img_name
                    )
                    try:
                        with Image.open(img_path) as img:
                            rgb = img.convert("RGB")
                    except (OSError, ValueError, Image.DecompressionBombError):
                        skipped += 1
                        continue
                    if target_size is None:
                        target_size = rgb.size
                    elif rgb.size != target_size:
                        rgb = rgb.resize(target_size)
                    batch.append(np.asarray(rgb, dtype=np.float32))

                if batch:
                    batch_sum = np.stack(batch).sum(axis=0, dtype=np.float64)
                    if pixel_sum is None:
                        pixel_sum = batch_sum
                    else:
                        pixel_sum += batch_sum
                    count += len(batch)

            if skipped:
                print(
                    f"[WARN] Skipped {skipped} unreadable image(s) "
                    f"for class {cls}"
                )

            if count:
                # float32 keeps the result usable by cv2.cvtColor later on.
                result[cls] = (pixel_sum / count / 255.0).astype(np.float32)

        return result

    def plot_mean_images_per_class(
        self,
        filename: Optional[str] = None
    ) -> Figure:
        """
        Compute or load and plot mean images per class, returning the figure.

        Attempts to load pre-computed mean images from a .npy file.
        If not found or unreadable, computes them using batch processing and
        optionally saves the result. Displays all mean images in a grid with
        three columns.

        Parameters
        ----------
        filename : str, optional
            Path to the .npy file containing pre-computed mean images,
            or destination path for saving newly computed mean images.
            The ".npy" suffix is appended if missing, mirroring ``np.save``.
            Default is None (no caching).

        Returns
        -------
        matplotlib.figure.Figure
            The generated figure object containing the plotted mean images.

        Raises
        ------
        RuntimeError
            If mean images must be computed and the metadata is not loaded.
        ValueError
            If no mean image could be loaded or computed.

        Notes
        -----
        If filename is provided and the file does not exist, computed
        mean images will be saved to this path for future use.
        """
        mean_images = None
        cache_path = self._npy_path(filename) if filename else None

        if cache_path and os.path.exists(cache_path):
            try:
                print(f"[INFO] Loading mean images from {cache_path}")
                # allow_pickle is needed for the dict written by np.save; only
                # load caches produced by this tool, never untrusted files.
                loaded = np.load(cache_path, allow_pickle=True).item()
            except Exception as e:
                # Best effort: a broken cache falls back to recomputation.
                print(
                    f"[WARN] Could not load mean images from "
                    f"{cache_path}: {e}"
                )
            else:
                if isinstance(loaded, dict):
                    mean_images = loaded
                else:
                    print(
                        f"[WARN] Ignoring {cache_path}: it does not "
                        f"contain a dict of mean images."
                    )

        if mean_images is None:
            print("[INFO] Computing mean images...")
            mean_images = self._compute_mean_images_per_batch()
            if cache_path and mean_images:
                np.save(cache_path, mean_images)
                print(f"[INFO] Saved mean images to {cache_path}")

        if not mean_images:
            raise ValueError("No mean images available to plot.")

        # --- Plot ---
        cols = 3
        rows = (len(mean_images) + cols - 1) // cols
        fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4))
        axes = np.asarray(axes).ravel()

        for ax, (cls, img) in zip(axes, mean_images.items()):
            ax.imshow(img)
            ax.set_title(f"Mean {cls}")
            ax.axis("off")

        for ax in axes[len(mean_images):]:
            ax.axis("off")

        fig.tight_layout()

        return fig

    def plot_mean_images_per_class_with_otsu(
        self, threshold: float = 0.0, filename: Optional[str] = None
    ) -> Optional[Figure]:
        """
        Plot mean images per class applying an adjustable Otsu threshold.

        Loads pre-computed mean images and thresholds a grayscale version
        of each one. The threshold starts at Otsu's value and is shifted
        linearly by ``threshold``. Pixels at or below the final threshold
        form a mask. The mask is dilated, drawn as a translucent red overlay
        and outlined with its external contours.

        Parameters
        ----------
        threshold : float, optional
            Threshold adjustment, clipped to [-1, 1]:
            - -1: threshold 255 (every pixel is masked, maximal overlay).
            - 0: Otsu threshold (default).
            - 1: threshold 0 (only pixels equal to 0 are masked).
            Intermediate values interpolate linearly between these points.
        filename : str, optional
            Path to the .npy file containing pre-computed mean images. The
            ".npy" suffix is appended if missing, mirroring ``np.save``.
            Default is None.

        Returns
        -------
        matplotlib.figure.Figure or None
            The figure with the thresholded mean images, or None if the
            mean images cannot be found or loaded.

        Notes
        -----
        Grayscale images that are not uint8 are min-max stretched to
        [0, 255] before thresholding. The threshold is therefore relative
        to each image's own intensity range.
        """
        cache_path = self._npy_path(filename) if filename else None
        if not cache_path or not os.path.exists(cache_path):
            print("[WARN] No mean images found or invalid file path.")
            return None

        try:
            print(f"[INFO] Loading mean images from {cache_path}")
            # allow_pickle is needed for the dict written by np.save; only
            # load caches produced by this tool, never untrusted files.
            mean_images = np.load(cache_path, allow_pickle=True).item()
        except Exception as e:
            print(
                f"[WARN] Could not load mean images from {cache_path}: {e}"
            )
            return None

        if not isinstance(mean_images, dict) or not mean_images:
            print(f"[WARN] {cache_path} does not contain any mean images.")
            return None

        adj = float(np.clip(threshold, -1.0, 1.0))
        n_classes = len(mean_images)
        n_cols = min(3, n_classes)
        n_rows = int(np.ceil(n_classes / n_cols))
        fig, axes = plt.subplots(
            n_rows,
            n_cols,
            figsize=(5 * n_cols, 5 * n_rows)
        )
        axes = np.asarray(axes).ravel()
        kernel = np.ones((3, 3), np.uint8)

        for ax, (cls, mean_image) in zip(axes, mean_images.items()):
            img = np.asarray(mean_image)
            # cv2.cvtColor rejects float64 input; cast unsupported depths.
            if img.dtype not in (np.uint8, np.uint16, np.float32):
                img = img.astype(np.float32)
            if img.ndim == 2:
                gray = img
            else:
                gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

            if gray.dtype != np.uint8:
                gray = cv2.normalize(
                    gray, None, 0, 255, cv2.NORM_MINMAX
                ).astype(np.uint8)

            otsu_thresh, _ = cv2.threshold(
                gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
            )

            # Linear shift: adj=-1 -> 255, adj=0 -> Otsu, adj=+1 -> 0.
            if adj < 0:
                final_thresh = otsu_thresh + (255 - otsu_thresh) * -adj
            else:
                final_thresh = otsu_thresh * (1 - adj)

            _, binary = cv2.threshold(
                gray, final_thresh, 255, cv2.THRESH_BINARY
            )

            # THRESH_BINARY zeroes pixels <= final_thresh; they form the mask.
            mask = (binary == 0).astype(np.uint8)
            mask_dilated = cv2.dilate(mask, kernel, iterations=1)
            red_overlay = np.zeros((*mask.shape, 4))
            red_overlay[mask_dilated == 1] = [1, 0, 0, 0.25]

            ax.imshow(mean_image)
            ax.imshow(red_overlay)
            ax.set_title(
                f"{cls}\nOtsu adj={threshold:.2f} (thr={final_thresh:.1f})"
            )
            ax.axis("off")

            contours, _ = cv2.findContours(
                mask_dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )
            for contour in contours:
                points = contour.squeeze()
                # Single-point contours squeeze to 1-D and cannot be drawn.
                if points.ndim == 2:
                    ax.plot(
                        points[:, 0],
                        points[:, 1],
                        color="red",
                        linewidth=2
                    )

        # Blank out grid cells that received no class.
        for ax in axes[n_classes:]:
            ax.axis("off")

        fig.tight_layout()

        return fig
class EdaAnalyzer:
    """
    A class that encapsulates all Exploratory Data Analysis (EDA) utilities
    for the Garbage Classification dataset or similar image datasets.

    This class provides methods for dataset management, visualization,
    and analysis, including downloading datasets from Kaggle,
    loading metadata, plotting class distributions, and computing
    prototype mean images.

    Images are read from ``<dataset_path>/images/<label>/<filename>``, where
    ``label`` and ``filename`` are columns of ``metadata.csv``.

    Attributes
    ----------
    root_path : str
        Path to the raw data folder.
    dataset_path : str
        Path to the dataset folder.
    zip_file : str
        Path to the zip file for dataset download.
    kaggle_url : str
        URL for downloading the Kaggle dataset.
    metadata_path : str
        Path to the metadata.csv file.
    df : pd.DataFrame or None
        Metadata DataFrame containing dataset information.
    """

    def __init__(
        self,
        root_path: str = "./data/raw",
        dataset_name: str = "Garbage_Dataset_Classification",
    ):
        """
        Initialize the EdaAnalyzer instance.

        Parameters
        ----------
        root_path : str, optional
            Path to the raw data folder. Default is "./data/raw".
        dataset_name : str, optional
            Name of the dataset folder inside ``root_path``. Default is
            "Garbage_Dataset_Classification".

        Returns
        -------
        None
        """
        self.root_path = root_path
        self.dataset_path = os.path.join(root_path, dataset_name)
        self.zip_file = os.path.join(root_path, "garbage-dataset.zip")
        # Implicit string concatenation: a backslash continuation inside the
        # literal would splice the next line's indentation into the URL.
        self.kaggle_url = (
            "https://www.kaggle.com/api/v1/datasets/download/"
            "zlatan599/garbage-dataset-classification"
        )
        self.metadata_path = os.path.join(self.dataset_path, "metadata.csv")
        # Populated by load_metadata(); the plotting methods require it.
        self.df: Optional[pd.DataFrame] = None

    # -------------------------------------------------------------------------
    # Internal helpers
    # -------------------------------------------------------------------------
    def _require_metadata(self) -> pd.DataFrame:
        """
        Return the loaded metadata, failing with a clear message if absent.

        Raises
        ------
        RuntimeError
            If ``load_metadata()`` has not been called yet.
        """
        if self.df is None:
            raise RuntimeError(
                "Metadata not loaded; call load_metadata() first."
            )
        return self.df

    @staticmethod
    def _npy_path(filename: str) -> str:
        """
        Return ``filename`` with the ``.npy`` suffix that ``np.save`` adds.

        ``np.save`` silently appends ``.npy`` when it is missing, so cache
        lookups must use the same normalized name or they never hit.
        """
        return filename if filename.endswith(".npy") else filename + ".npy"

    # -------------------------------------------------------------------------
    # Dataset management
    # -------------------------------------------------------------------------
    def download_with_curl(self):
        """
        Download Kaggle dataset using curl and API credentials.

        This method downloads the garbage dataset from Kaggle using the
        Kaggle API credentials stored in ~/.kaggle/kaggle.json. The
        dataset is extracted into ``root_path`` and the zip file is removed
        after extraction.

        Parameters
        ----------
        None

        Returns
        -------
        None

        Raises
        ------
        FileNotFoundError
            If Kaggle credentials are not found at ~/.kaggle/kaggle.json.
        KeyError
            If kaggle.json lacks the "username" or "key" field.
        subprocess.CalledProcessError
            If curl exits with a non-zero status (including HTTP errors).
        """
        import json
        import subprocess

        print("Downloading dataset with curl...")

        kaggle_dir = os.path.expanduser("~/.kaggle")
        os.makedirs(kaggle_dir, exist_ok=True)
        os.chmod(kaggle_dir, 0o700)

        credentials_path = os.path.join(kaggle_dir, "kaggle.json")
        if not os.path.isfile(credentials_path):
            raise FileNotFoundError(
                f"Kaggle credentials not found at {credentials_path}"
            )
        with open(credentials_path, encoding="utf-8") as fh:
            credentials = json.load(fh)

        # curl -o cannot create missing parent directories.
        os.makedirs(self.root_path, exist_ok=True)

        # Argument list instead of a shell string: paths containing spaces or
        # shell metacharacters are passed verbatim. --fail turns HTTP errors
        # (401/404, ...) into a non-zero exit instead of saving the error
        # page as the "zip".
        subprocess.run(
            [
                "curl",
                "-L",
                "--fail",
                "-o",
                self.zip_file,
                "-u",
                f"{credentials['username']}:{credentials['key']}",
                self.kaggle_url,
            ],
            check=True,
        )

        print("Extracting dataset...")
        with zipfile.ZipFile(self.zip_file, "r") as zip_ref:
            zip_ref.extractall(self.root_path)

        os.remove(self.zip_file)
        print("Dataset downloaded and extracted successfully.")

    def ensure_dataset(self):
        """
        Check if dataset exists; otherwise, download it.

        If ``dataset_path`` is missing, ``download_with_curl()`` is called;
        otherwise a confirmation message is printed.

        Parameters
        ----------
        None

        Returns
        -------
        None
        """
        if os.path.exists(self.dataset_path):
            print(f"{self.dataset_path} already exists, nothing to do.")
        else:
            self.download_with_curl()

    def load_metadata(self):
        """
        Load metadata.csv into a pandas DataFrame.

        Reads the metadata CSV file from the dataset path, stores it as
        ``self.df`` and prints the number of entries and distinct labels.

        Parameters
        ----------
        None

        Returns
        -------
        pd.DataFrame
            The loaded metadata DataFrame containing image filenames and
            labels.

        Raises
        ------
        FileNotFoundError
            If metadata.csv is not found at the expected path.
        """
        if not os.path.exists(self.metadata_path):
            raise FileNotFoundError(
                f"Metadata file not found at {self.metadata_path}"
            )
        self.df = pd.read_csv(self.metadata_path)
        print(
            f"Loaded metadata: {len(self.df)} entries, "
            f"{self.df['label'].nunique()} classes."
        )
        return self.df

    # -------------------------------------------------------------------------
    # Visualization utilities
    # -------------------------------------------------------------------------
    def plot_random_examples_per_class(
        self,
        filename: Optional[str] = None,
        random_state: Optional[int] = None,
    ) -> Figure:
        """
        Plot a random image from each class and return the figure.

        Selects one random image per class and displays them in a grid with
        three columns. Each subplot is framed with a border whose color
        matches its class title.

        Parameters
        ----------
        filename : str, optional
            Path to save the generated figure as an image file. If provided,
            the figure is saved with 150 dpi. Default is None (no save).
        random_state : int, optional
            Seed forwarded to ``DataFrame.sample`` for a reproducible
            selection. Default is None (new random images on every call).

        Returns
        -------
        matplotlib.figure.Figure
            The generated figure object containing the plotted images.

        Raises
        ------
        RuntimeError
            If the metadata has not been loaded.
        ValueError
            If the metadata contains no labels.
        """
        df = self._require_metadata()
        classes = df["label"].unique()
        if len(classes) == 0:
            raise ValueError("Metadata contains no labelled images to plot.")

        palette = sns.color_palette("tab10", len(classes))
        class_colors = dict(zip(classes, palette))

        cols = 3
        rows = (len(classes) + cols - 1) // cols
        fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4))
        axes = np.asarray(axes).ravel()

        for ax, cls in zip(axes, classes):
            sample = df[df["label"] == cls].sample(
                1, random_state=random_state
            )
            img_path = os.path.join(
                self.dataset_path, "images", cls, sample.iloc[0]["filename"]
            )
            # imshow converts the PIL image to an array right away, so the
            # file handle can be closed as soon as it returns.
            with Image.open(img_path) as img:
                ax.imshow(img)
            ax.set_title(cls, fontsize=14, color=class_colors[cls])
            # Hide ticks only: ax.axis("off") would also hide the spines and
            # make the colored class border invisible.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_edgecolor(class_colors[cls])
                spine.set_linewidth(4)

        # Blank out grid cells that received no class.
        for ax in axes[len(classes):]:
            ax.axis("off")

        fig.tight_layout()

        if filename:
            fig.savefig(filename, dpi=150)

        return fig

    def plot_class_distribution(
        self,
        filename: Optional[str] = None
    ) -> Figure:
        """
        Plot class distribution using seaborn and return the figure.

        Creates a countplot showing the number of samples per class, ordered
        by decreasing frequency and colored with the "tab10" palette.

        Parameters
        ----------
        filename : str, optional
            Path to save the generated figure as an image file. If provided,
            the figure is saved with 150 dpi. Default is None (no save).

        Returns
        -------
        matplotlib.figure.Figure
            The generated figure object containing the class distribution plot.

        Raises
        ------
        RuntimeError
            If the metadata has not been loaded.
        """
        df = self._require_metadata()
        fig, ax = plt.subplots(figsize=(8, 5))
        sns.countplot(
            data=df,
            x="label",
            order=df["label"].value_counts().index,
            palette="tab10",
            ax=ax,
        )
        ax.set_title("Class Distribution", fontsize=16)
        ax.set_xlabel("Class")
        ax.set_ylabel("Count")
        # Right-align rotated labels so each one ends under its own bar.
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
        fig.tight_layout()

        if filename:
            fig.savefig(filename, dpi=150)

        return fig

    # -------------------------------------------------------------------------
    # Prototypes
    # -------------------------------------------------------------------------

    def _compute_mean_images_per_batch(self, batch_size=32):
        """
        Compute mean image per class using batch processing.

        Images of each class are loaded ``batch_size`` at a time, converted
        to RGB and added to a float64 per-pixel sum. Only one batch of
        decoded images is held in memory at a time.

        Parameters
        ----------
        batch_size : int, optional
            Number of images to process per batch. Default is 32.

        Returns
        -------
        dict
            Dictionary with class names as keys and float32 mean images
            (H, W, 3) with values in [0, 1] as values. Classes without any
            readable image are omitted.

        Raises
        ------
        RuntimeError
            If the metadata has not been loaded.
        ValueError
            If ``batch_size`` is smaller than 1.

        Notes
        -----
        Images whose size differs from the first readable image of the same
        class are resized to that size, so mixed resolutions can be
        averaged. Unreadable or corrupted images are skipped, and the
        number skipped per class is reported.
        """
        df = self._require_metadata()
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")

        result = {}
        for cls in df["label"].unique():
            filenames = df.loc[df["label"] == cls, "filename"]
            pixel_sum = None  # float64 running sum over loaded images
            target_size = None  # PIL (width, height) of first readable image
            count = 0
            skipped = 0

            for start in range(0, len(filenames), batch_size):
                batch = []
                for img_name in filenames.iloc[start:start + batch_size]:
                    img_path = os.path.join(
                        self.dataset_path, "images", cls, img_name
                    )
                    try:
                        with Image.open(img_path) as img:
                            rgb = img.convert("RGB")
                    except (OSError, ValueError, Image.DecompressionBombError):
                        skipped += 1
                        continue
                    if target_size is None:
                        target_size = rgb.size
                    elif rgb.size != target_size:
                        rgb = rgb.resize(target_size)
                    batch.append(np.asarray(rgb, dtype=np.float32))

                if batch:
                    batch_sum = np.stack(batch).sum(axis=0, dtype=np.float64)
                    if pixel_sum is None:
                        pixel_sum = batch_sum
                    else:
                        pixel_sum += batch_sum
                    count += len(batch)

            if skipped:
                print(
                    f"[WARN] Skipped {skipped} unreadable image(s) "
                    f"for class {cls}"
                )

            if count:
                # float32 keeps the result usable by cv2.cvtColor later on.
                result[cls] = (pixel_sum / count / 255.0).astype(np.float32)

        return result

    def plot_mean_images_per_class(
        self,
        filename: Optional[str] = None
    ) -> Figure:
        """
        Compute or load and plot mean images per class, returning the figure.

        Attempts to load pre-computed mean images from a .npy file.
        If not found or unreadable, computes them using batch processing and
        optionally saves the result. Displays all mean images in a grid with
        three columns.

        Parameters
        ----------
        filename : str, optional
            Path to the .npy file containing pre-computed mean images,
            or destination path for saving newly computed mean images.
            The ".npy" suffix is appended if missing, mirroring ``np.save``.
            Default is None (no caching).

        Returns
        -------
        matplotlib.figure.Figure
            The generated figure object containing the plotted mean images.

        Raises
        ------
        RuntimeError
            If mean images must be computed and the metadata is not loaded.
        ValueError
            If no mean image could be loaded or computed.

        Notes
        -----
        If filename is provided and the file does not exist, computed
        mean images will be saved to this path for future use.
        """
        mean_images = None
        cache_path = self._npy_path(filename) if filename else None

        if cache_path and os.path.exists(cache_path):
            try:
                print(f"[INFO] Loading mean images from {cache_path}")
                # allow_pickle is needed for the dict written by np.save; only
                # load caches produced by this tool, never untrusted files.
                loaded = np.load(cache_path, allow_pickle=True).item()
            except Exception as e:
                # Best effort: a broken cache falls back to recomputation.
                print(
                    f"[WARN] Could not load mean images from "
                    f"{cache_path}: {e}"
                )
            else:
                if isinstance(loaded, dict):
                    mean_images = loaded
                else:
                    print(
                        f"[WARN] Ignoring {cache_path}: it does not "
                        f"contain a dict of mean images."
                    )

        if mean_images is None:
            print("[INFO] Computing mean images...")
            mean_images = self._compute_mean_images_per_batch()
            if cache_path and mean_images:
                np.save(cache_path, mean_images)
                print(f"[INFO] Saved mean images to {cache_path}")

        if not mean_images:
            raise ValueError("No mean images available to plot.")

        # --- Plot ---
        cols = 3
        rows = (len(mean_images) + cols - 1) // cols
        fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4))
        axes = np.asarray(axes).ravel()

        for ax, (cls, img) in zip(axes, mean_images.items()):
            ax.imshow(img)
            ax.set_title(f"Mean {cls}")
            ax.axis("off")

        for ax in axes[len(mean_images):]:
            ax.axis("off")

        fig.tight_layout()

        return fig

    def plot_mean_images_per_class_with_otsu(
        self, threshold: float = 0.0, filename: Optional[str] = None
    ) -> Optional[Figure]:
        """
        Plot mean images per class applying an adjustable Otsu threshold.

        Loads pre-computed mean images and thresholds a grayscale version
        of each one. The threshold starts at Otsu's value and is shifted
        linearly by ``threshold``. Pixels at or below the final threshold
        form a mask. The mask is dilated, drawn as a translucent red overlay
        and outlined with its external contours.

        Parameters
        ----------
        threshold : float, optional
            Threshold adjustment, clipped to [-1, 1]:
            - -1: threshold 255 (every pixel is masked, maximal overlay).
            - 0: Otsu threshold (default).
            - 1: threshold 0 (only pixels equal to 0 are masked).
            Intermediate values interpolate linearly between these points.
        filename : str, optional
            Path to the .npy file containing pre-computed mean images. The
            ".npy" suffix is appended if missing, mirroring ``np.save``.
            Default is None.

        Returns
        -------
        matplotlib.figure.Figure or None
            The figure with the thresholded mean images, or None if the
            mean images cannot be found or loaded.

        Notes
        -----
        Grayscale images that are not uint8 are min-max stretched to
        [0, 255] before thresholding. The threshold is therefore relative
        to each image's own intensity range.
        """
        cache_path = self._npy_path(filename) if filename else None
        if not cache_path or not os.path.exists(cache_path):
            print("[WARN] No mean images found or invalid file path.")
            return None

        try:
            print(f"[INFO] Loading mean images from {cache_path}")
            # allow_pickle is needed for the dict written by np.save; only
            # load caches produced by this tool, never untrusted files.
            mean_images = np.load(cache_path, allow_pickle=True).item()
        except Exception as e:
            print(
                f"[WARN] Could not load mean images from {cache_path}: {e}"
            )
            return None

        if not isinstance(mean_images, dict) or not mean_images:
            print(f"[WARN] {cache_path} does not contain any mean images.")
            return None

        adj = float(np.clip(threshold, -1.0, 1.0))
        n_classes = len(mean_images)
        n_cols = min(3, n_classes)
        n_rows = int(np.ceil(n_classes / n_cols))
        fig, axes = plt.subplots(
            n_rows,
            n_cols,
            figsize=(5 * n_cols, 5 * n_rows)
        )
        axes = np.asarray(axes).ravel()
        kernel = np.ones((3, 3), np.uint8)

        for ax, (cls, mean_image) in zip(axes, mean_images.items()):
            img = np.asarray(mean_image)
            # cv2.cvtColor rejects float64 input; cast unsupported depths.
            if img.dtype not in (np.uint8, np.uint16, np.float32):
                img = img.astype(np.float32)
            if img.ndim == 2:
                gray = img
            else:
                gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

            if gray.dtype != np.uint8:
                gray = cv2.normalize(
                    gray, None, 0, 255, cv2.NORM_MINMAX
                ).astype(np.uint8)

            otsu_thresh, _ = cv2.threshold(
                gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
            )

            # Linear shift: adj=-1 -> 255, adj=0 -> Otsu, adj=+1 -> 0.
            if adj < 0:
                final_thresh = otsu_thresh + (255 - otsu_thresh) * -adj
            else:
                final_thresh = otsu_thresh * (1 - adj)

            _, binary = cv2.threshold(
                gray, final_thresh, 255, cv2.THRESH_BINARY
            )

            # THRESH_BINARY zeroes pixels <= final_thresh; they form the mask.
            mask = (binary == 0).astype(np.uint8)
            mask_dilated = cv2.dilate(mask, kernel, iterations=1)
            red_overlay = np.zeros((*mask.shape, 4))
            red_overlay[mask_dilated == 1] = [1, 0, 0, 0.25]

            ax.imshow(mean_image)
            ax.imshow(red_overlay)
            ax.set_title(
                f"{cls}\nOtsu adj={threshold:.2f} (thr={final_thresh:.1f})"
            )
            ax.axis("off")

            contours, _ = cv2.findContours(
                mask_dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )
            for contour in contours:
                points = contour.squeeze()
                # Single-point contours squeeze to 1-D and cannot be drawn.
                if points.ndim == 2:
                    ax.plot(
                        points[:, 0],
                        points[:, 1],
                        color="red",
                        linewidth=2
                    )

        # Blank out grid cells that received no class.
        for ax in axes[n_classes:]:
            ax.axis("off")

        fig.tight_layout()

        return fig
A class that encapsulates all Exploratory Data Analysis (EDA) utilities for the Garbage Classification dataset or similar image datasets.
This class provides methods for dataset management, visualization, and analysis, including downloading datasets from Kaggle, loading metadata, plotting class distributions, and computing prototype mean images.
Attributes
- root_path (str): Path to the raw data folder.
- dataset_path (str): Path to the dataset folder.
- zip_file (str): Path to the zip file for dataset download.
- kaggle_url (str): URL for downloading the Kaggle dataset.
- metadata_path (str): Path to the metadata.csv file.
- df (pd.DataFrame or None): Metadata DataFrame containing dataset information.
def __init__(
    self,
    root_path: str = "./data/raw",
    dataset_name: str = "Garbage_Dataset_Classification",
):
    """
    Initialize the EdaAnalyzer instance.

    Parameters
    ----------
    root_path : str, optional
        Path to the raw data folder. Default is "./data/raw".
    dataset_name : str, optional
        Name of the dataset folder inside ``root_path``. Default is
        "Garbage_Dataset_Classification".

    Returns
    -------
    None
    """
    self.root_path = root_path
    self.dataset_path = os.path.join(root_path, dataset_name)
    self.zip_file = os.path.join(root_path, "garbage-dataset.zip")
    # Implicit string concatenation: a backslash continuation inside the
    # literal would splice the next line's indentation into the URL.
    self.kaggle_url = (
        "https://www.kaggle.com/api/v1/datasets/download/"
        "zlatan599/garbage-dataset-classification"
    )
    self.metadata_path = os.path.join(self.dataset_path, "metadata.csv")
    # Populated later by load_metadata().
    self.df = None
Initialize the EdaAnalyzer instance.
Parameters
- root_path (str, optional): Path to the raw data folder. Default is "./data/raw".
- dataset_name (str, optional): Name of the dataset folder. Default is "Garbage_Dataset_Classification".
Returns
- None
def download_with_curl(self):
    """
    Download Kaggle dataset using curl and API credentials.

    This method downloads the garbage dataset from Kaggle using the
    Kaggle API credentials stored in ~/.kaggle/kaggle.json. The
    dataset is extracted into ``root_path`` and the zip file is removed
    after extraction.

    Parameters
    ----------
    None

    Returns
    -------
    None

    Raises
    ------
    FileNotFoundError
        If Kaggle credentials are not found at ~/.kaggle/kaggle.json.
    KeyError
        If kaggle.json lacks the "username" or "key" field.
    subprocess.CalledProcessError
        If curl exits with a non-zero status (including HTTP errors).
    """
    import json
    import subprocess

    print("Downloading dataset with curl...")

    kaggle_dir = os.path.expanduser("~/.kaggle")
    os.makedirs(kaggle_dir, exist_ok=True)
    os.chmod(kaggle_dir, 0o700)

    credentials_path = os.path.join(kaggle_dir, "kaggle.json")
    if not os.path.isfile(credentials_path):
        raise FileNotFoundError(
            f"Kaggle credentials not found at {credentials_path}"
        )
    with open(credentials_path, encoding="utf-8") as fh:
        credentials = json.load(fh)

    # curl -o cannot create missing parent directories.
    os.makedirs(self.root_path, exist_ok=True)

    # Argument list instead of a shell string: paths containing spaces or
    # shell metacharacters are passed verbatim. --fail turns HTTP errors
    # (401/404, ...) into a non-zero exit instead of saving the error page
    # as the "zip".
    subprocess.run(
        [
            "curl",
            "-L",
            "--fail",
            "-o",
            self.zip_file,
            "-u",
            f"{credentials['username']}:{credentials['key']}",
            self.kaggle_url,
        ],
        check=True,
    )

    print("Extracting dataset...")
    with zipfile.ZipFile(self.zip_file, "r") as zip_ref:
        zip_ref.extractall(self.root_path)

    os.remove(self.zip_file)
    print("Dataset downloaded and extracted successfully.")
Download Kaggle dataset using curl and API credentials.
This method downloads the garbage dataset from Kaggle using the Kaggle API credentials stored in ~/.kaggle/kaggle.json. The dataset is extracted and the zip file is removed after extraction.
Parameters
- None
Returns
- None
Raises
- FileNotFoundError: If Kaggle credentials are not found at ~/.kaggle/kaggle.json.
def ensure_dataset(self):
    """
    Make sure the dataset is available locally.

    When ``self.dataset_path`` already exists, a notice is printed and
    nothing else happens; otherwise ``self.download_with_curl()`` is
    called to fetch and unpack it.

    Parameters
    ----------
    None

    Returns
    -------
    None
    """
    already_present = os.path.exists(self.dataset_path)
    if already_present:
        print(f"{self.dataset_path} already exists, nothing to do.")
        return None
    self.download_with_curl()
    return None
Check if dataset exists; otherwise, download it.
Verifies the presence of the dataset at the expected path. If not found, triggers the download process. If already present, prints a confirmation message.
Parameters
- None
Returns
- None
def load_metadata(self):
    """
    Load metadata.csv into a pandas DataFrame.

    Reads ``self.metadata_path``, stores the result as ``self.df`` and
    prints the number of rows and of distinct values in the "label"
    column.

    Parameters
    ----------
    None

    Returns
    -------
    pd.DataFrame
        The loaded metadata DataFrame (also stored as ``self.df``).

    Raises
    ------
    FileNotFoundError
        If metadata.csv is not found at the expected path.
    KeyError
        If the CSV has no "label" column.
    """
    if not os.path.exists(self.metadata_path):
        raise FileNotFoundError(
            f"Metadata file not found at {self.metadata_path}"
        )
    self.df = pd.read_csv(self.metadata_path)
    # Adjacent f-strings rather than a backslash inside the literal, which
    # would copy the continuation line's indentation into the message.
    print(
        f"Loaded metadata: {len(self.df)} entries, "
        f"{self.df['label'].nunique()} classes."
    )
    return self.df
Load metadata.csv into a pandas DataFrame.
Reads the metadata CSV file from the dataset path and stores it as self.df. Prints summary statistics about the loaded data.
Parameters
- None
Returns
- pd.DataFrame: The loaded metadata DataFrame containing image filenames and labels.
Raises
- FileNotFoundError: If metadata.csv is not found at the expected path.
def plot_random_examples_per_class(
    self,
    filename: Optional[str] = None
) -> Figure:
    """
    Plot a random image from each class and return the figure.

    Selects one random image per class from ``self.df`` and displays the
    images in a 3-column grid. Each subplot title and border is drawn in
    a color assigned to its class. If metadata has not been loaded yet,
    ``self.load_metadata()`` is called first.

    Parameters
    ----------
    filename : str, optional
        Path to save the generated figure as an image file. If provided,
        the figure is saved with 150 dpi. Default is None (no save).

    Returns
    -------
    matplotlib.figure.Figure
        The generated figure object containing the plotted images.

    Raises
    ------
    ValueError
        If the metadata contains no classes.
    """
    if self.df is None:
        self.load_metadata()
    df = self.df

    classes = df["label"].unique()
    if len(classes) == 0:
        raise ValueError("Metadata contains no classes to plot.")

    palette = sns.color_palette("tab10", len(classes))
    class_colors = {cls: palette[i] for i, cls in enumerate(classes)}

    cols = 3
    rows = (len(classes) + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4))
    axes = np.asarray(axes).flatten()

    for ax, cls in zip(axes, classes):
        img_filename = df[df["label"] == cls].sample(1).iloc[0]["filename"]
        img_path = os.path.join(
            self.dataset_path,
            "images",
            cls,
            img_filename
        )
        # imshow converts the PIL image to an array immediately, so the
        # file can be closed as soon as the call returns.
        with Image.open(img_path) as img:
            ax.imshow(img)

        color = class_colors[cls]
        ax.set_title(cls, fontsize=14, color=color)
        # Remove ticks but keep the frame: ax.axis("off") would also hide
        # the spines and with them the colored class border.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_edgecolor(color)
            spine.set_linewidth(4)

    # Hide grid cells left over after the last class.
    for ax in axes[len(classes):]:
        ax.axis("off")

    fig.tight_layout()

    if filename:
        fig.savefig(filename, dpi=150)

    return fig
Plot a random image from each class and return the figure.
Selects one random image per class and displays them in a grid layout. Each subplot is bordered with a color corresponding to its class.
Parameters
- filename (str, optional): Path to save the generated figure as an image file. If provided, the figure is saved with 150 dpi. Default is None (no save).
Returns
- matplotlib.figure.Figure: The generated figure object containing the plotted images.
def plot_class_distribution(
    self,
    filename: Optional[str] = None
) -> Figure:
    """
    Plot the number of samples per class and return the figure.

    Draws a seaborn countplot of the "label" column of ``self.df``. Bars
    are sorted from most to least frequent class and colored with the
    "tab10" palette.

    Parameters
    ----------
    filename : str, optional
        If given, the figure is written to this path at 150 dpi.
        Default is None (the figure is not saved).

    Returns
    -------
    matplotlib.figure.Figure
        Figure holding the class distribution plot.
    """
    fig, ax = plt.subplots(figsize=(8, 5))

    metadata = self.df
    labels_by_frequency = metadata["label"].value_counts().index
    sns.countplot(
        data=metadata,
        x="label",
        order=labels_by_frequency,
        palette="tab10",
        ax=ax,
    )

    ax.set_title("Class Distribution", fontsize=16)
    ax.set(xlabel="Class", ylabel="Count")
    plt.setp(ax.get_xticklabels(), rotation=45)
    fig.tight_layout()

    if filename:
        fig.savefig(filename, dpi=150)
    return fig
Plot class distribution using seaborn and return the figure.
Creates a countplot showing the number of samples per class, ordered by frequency and using a color palette for visual distinction.
Parameters
- filename (str, optional): Path to save the generated figure as an image file. If provided, the figure is saved with 150 dpi. Default is None (no save).
Returns
- matplotlib.figure.Figure: The generated figure object containing the class distribution plot.
def plot_mean_images_per_class(
    self,
    filename: Optional[str] = None
) -> Figure:
    """
    Compute or load and plot mean images per class, returning the figure.

    If ``filename`` points to an existing .npy cache, the mean images are
    loaded from it. Otherwise they are computed with
    ``self._compute_mean_images_per_batch()`` and, when ``filename`` is
    given, saved there. All mean images are then shown in a 3-column
    grid.

    Parameters
    ----------
    filename : str, optional
        Path of the .npy cache holding a dict of mean images. A ".npy"
        suffix is appended when missing, because ``np.save`` writes to
        that name anyway. Default is None (no caching).

    Returns
    -------
    matplotlib.figure.Figure
        The generated figure object containing the plotted mean images.

    Raises
    ------
    ValueError
        If there are no mean images to plot.

    Notes
    -----
    If filename is provided and the file does not exist, computed mean
    images are saved to it for future use. The cache is read with
    ``allow_pickle=True`` (it stores a pickled dict), so only point
    ``filename`` at trusted files.
    """
    if filename and not filename.endswith(".npy"):
        # np.save appends ".npy" itself; without this, the existence check
        # below would never find the file that was saved.
        filename += ".npy"

    mean_images = None

    if filename and os.path.exists(filename):
        try:
            print(f"[INFO] Loading mean images from {filename}")
            mean_images = np.load(filename, allow_pickle=True).item()
        except Exception as e:
            # Best effort: an unreadable cache falls back to recomputing.
            print(f"[WARN] Could not load mean images from {filename}: {e}")

    if mean_images is None:
        print("[INFO] Computing mean images...")
        mean_images = self._compute_mean_images_per_batch()
        if filename:
            np.save(filename, mean_images)
            print(f"[INFO] Saved mean images to {filename}")

    if not mean_images:
        raise ValueError("No mean images available to plot.")

    # --- Plot ---
    cols = 3
    rows = (len(mean_images) + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4))
    axes = np.asarray(axes).flatten()

    for ax, (cls, img) in zip(axes, mean_images.items()):
        ax.imshow(img)
        ax.set_title(f"Mean {cls}")
        ax.axis("off")

    # Hide grid cells left over after the last class.
    for ax in axes[len(mean_images):]:
        ax.axis("off")

    fig.tight_layout()

    return fig
Compute or load and plot mean images per class, returning the figure.
Attempts to load pre-computed mean images from a .npy file. If not found, computes them using batch processing and optionally saves the result. Displays all mean images in a grid layout.
Parameters
- filename (str, optional): Path to the .npy file containing pre-computed mean images, or destination path for saving newly computed mean images. Default is None (no caching).
Returns
- matplotlib.figure.Figure: The generated figure object containing the plotted mean images.
Notes
If filename is provided and the file does not exist, the computed mean images are saved to this path for future use (note that np.save appends ".npy" when the path does not already end with it).
def plot_mean_images_per_class_with_otsu(
    self, threshold: float = 0.0, filename: Optional[str] = None
) -> Figure:
    """
    Plot mean images per class applying an adjustable Otsu threshold.

    Loads pre-computed mean images from a .npy file. For each class it:
    - converts the mean image to 8-bit grayscale;
    - computes Otsu's threshold and shifts it according to ``threshold``;
    - overlays pixels at or below the resulting threshold in translucent
      red, after a 3x3 dilation;
    - outlines those regions with red contours.

    Parameters
    ----------
    threshold : float, optional
        Threshold adjustment, clipped to [-1, 1]. Negative values move the
        threshold linearly from Otsu's value towards 255, and positive
        values move it towards 0.
        - -1: threshold 255, every pixel is highlighted (maximal overlay).
        - 0: plain Otsu threshold (default).
        - 1: threshold 0, only pixels equal to 0 are highlighted
          (minimal overlay).
        Default is 0.0.
    filename : str, optional
        Path to an existing .npy file holding a dict of mean images
        (class name -> image), e.g. as saved by
        plot_mean_images_per_class. Must end with ".npy". Default is None.

    Returns
    -------
    matplotlib.figure.Figure or None
        The figure with the thresholded mean images. Returns None when
        ``filename`` is missing, does not exist, lacks the ".npy"
        extension, cannot be loaded, or holds no images.

    Notes
    -----
    Image dtypes that OpenCV's color conversion rejects (e.g. float64)
    are converted to float32 first. Non-uint8 grayscale is min-max
    scaled to 0-255 before thresholding. The file is read with
    ``allow_pickle=True``, so only use trusted files.
    """
    if not (
        filename
        and filename.endswith(".npy")
        and os.path.exists(filename)
    ):
        print("[WARN] No mean images found or invalid file path.")
        return None

    try:
        print(f"[INFO] Loading mean images from {filename}")
        mean_images = np.load(filename, allow_pickle=True).item()
    except Exception as e:
        print(f"[WARN] Could not load mean images from {filename}: {e}")
        return None

    if not mean_images:
        print("[WARN] No mean images found or invalid file path.")
        return None

    # Same adjustment for every class, so clip it once.
    adj = float(np.clip(threshold, -1, 1))
    kernel = np.ones((3, 3), np.uint8)

    n_classes = len(mean_images)
    n_cols = min(3, n_classes)
    n_rows = int(np.ceil(n_classes / n_cols))
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(5 * n_cols, 5 * n_rows)
    )
    axes = np.array(axes).flatten()

    for ax, (cls, mean_image) in zip(axes, mean_images.items()):
        img = np.asarray(mean_image)
        if img.dtype not in (np.uint8, np.uint16, np.float32):
            # cvtColor only accepts 8U/16U/32F input; averaged images are
            # plausibly float64.
            img = img.astype(np.float32)
        if img.ndim == 2:
            gray = img
        elif img.shape[2] == 4:
            gray = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
        else:
            gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

        if gray.dtype != np.uint8:
            # Otsu in cv2.threshold requires 8-bit input.
            gray = cv2.normalize(
                gray, None, 0, 255, cv2.NORM_MINMAX
            ).astype(np.uint8)

        otsu_thresh, _ = cv2.threshold(
            gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
        )

        # Linear interpolation: adj=-1 -> 255, adj=0 -> Otsu, adj=1 -> 0.
        if adj < 0:
            final_thresh = otsu_thresh + (255 - otsu_thresh) * (-adj)
        else:
            final_thresh = otsu_thresh - otsu_thresh * adj

        _, binary = cv2.threshold(
            gray,
            final_thresh,
            255,
            cv2.THRESH_BINARY
        )

        # THRESH_BINARY zeroes pixels <= final_thresh; those are the
        # highlighted (candidate foreground) pixels.
        mask = (binary == 0).astype(np.uint8)
        mask_dilated = cv2.dilate(mask, kernel, iterations=1)
        red_overlay = np.zeros((*mask.shape, 4))
        red_overlay[mask_dilated == 1] = [1, 0, 0, 0.25]

        ax.imshow(mean_image)
        ax.imshow(red_overlay)
        ax.set_title(f"{cls}\nOtsu adj={adj:.2f} (thr={final_thresh:.1f})")
        ax.axis("off")

        contours, _ = cv2.findContours(
            mask_dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        for contour in contours:
            points = contour.reshape(-1, 2)
            if len(points) < 2:
                continue
            # Repeat the first vertex so the outline is drawn closed.
            closed = np.vstack([points, points[:1]])
            ax.plot(
                closed[:, 0],
                closed[:, 1],
                color="red",
                linewidth=2
            )

    # Hide grid cells left over after the last class.
    for ax in axes[n_classes:]:
        ax.axis("off")

    fig.tight_layout()

    return fig
Plot mean images per class applying an adjustable Otsu threshold.
Loads pre-computed mean images and applies a custom thresholding strategy based on Otsu's method with user-defined adjustments. Generates binary masks and overlays them on the original mean images with contour visualization.
Parameters
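- threshold (float, optional): Threshold adjustment parameter, clipped to the range [-1, 1]. Default is 0.0.
  - -1: Maximum threshold (255); every pixel is at or below it, so the red overlay covers the whole image (maximal foreground).
  - 0: Otsu threshold (default).
  - 1: Minimum threshold (0); only pixels equal to 0 are highlighted (minimal foreground).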
- filename (str, optional): Path to the .npy file containing pre-computed mean images. Must end with ".npy" extension. Default is None.
Returns
- matplotlib.figure.Figure or None: The generated figure object containing the thresholded mean images. Returns None if the mean-image file is missing, does not end with ".npy", or cannot be loaded.
Raises
- None
Notes
Red overlays indicate pixels below the threshold (potential foreground objects). Contours are traced around connected components in the binary mask.