dragon-ml-toolbox 14.3.0__py3-none-any.whl → 14.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dragon-ml-toolbox might be problematic. Click here for more details.

@@ -1,14 +1,18 @@
1
- from typing import Union, Dict, Type, Callable
1
+ from typing import Union, Dict, Type, Callable, Optional, Any, List, Literal
2
2
  from PIL import ImageOps, Image
3
+ from torchvision import transforms
4
+ from pathlib import Path
3
5
 
4
6
  from ._logger import _LOGGER
5
7
  from ._script_info import _script_info
6
8
  from .keys import VisionTransformRecipeKeys
9
+ from .path_manager import make_fullpath
7
10
 
8
11
 
9
12
  __all__ = [
10
13
  "TRANSFORM_REGISTRY",
11
- "ResizeAspectFill"
14
+ "ResizeAspectFill",
15
+ "create_offline_augmentations"
12
16
  ]
13
17
 
14
18
  # --- Custom Vision Transform Class ---
@@ -23,9 +27,8 @@ class ResizeAspectFill:
23
27
  """
24
28
  def __init__(self, pad_color: Union[str, int] = "black") -> None:
25
29
  self.pad_color = pad_color
26
- # Store kwargs to allow for recreation
30
+ # Store kwargs to allow for re-creation
27
31
  self.__setattr__(VisionTransformRecipeKeys.KWARGS, {"pad_color": pad_color})
28
- # self._kwargs = {"pad_color": pad_color}
29
32
 
30
33
  def __call__(self, image: Image.Image) -> Image.Image:
31
34
  if not isinstance(image, Image.Image):
@@ -47,12 +50,154 @@ class ResizeAspectFill:
47
50
  padding = (left_padding, 0, right_padding, 0)
48
51
 
49
52
  return ImageOps.expand(image, padding, fill=self.pad_color)
50
-
51
53
 
52
- #NOTE: Add custom transforms here.
54
+
55
# NOTE: Register custom transform classes here. Recipe steps whose name is not
# found in torchvision.transforms are looked up in this mapping by name.
TRANSFORM_REGISTRY: Dict[str, Type[Callable]] = dict(
    ResizeAspectFill=ResizeAspectFill,
)
56
59
 
60
+
61
def _build_transform_from_recipe(recipe: Dict[str, Any]) -> transforms.Compose:
    """
    Internal helper to build a transform pipeline from a recipe dict.

    The recipe must contain a ``VisionTransformRecipeKeys.PIPELINE`` entry holding a
    list (or tuple) of step mappings. Each step names a transform under
    ``VisionTransformRecipeKeys.NAME`` and may give constructor keyword arguments
    under ``VisionTransformRecipeKeys.KWARGS`` (a missing or ``None`` value means
    "no kwargs"). A name is resolved first against ``torchvision.transforms``, then
    against ``TRANSFORM_REGISTRY``. Only callable attributes count as a match.

    Args:
        recipe (Dict[str, Any]): The transform recipe dictionary.

    Returns:
        transforms.Compose: The instantiated transforms, in recipe order.

    Raises:
        ValueError: If the recipe is not a mapping, has no pipeline entry, or its
            pipeline is not a list/tuple. Also raised if a step is not a mapping, a
            step's name is not a string, a step's kwargs are not a mapping, or a name
            resolves to no callable transform.
        Exception: Whatever a transform constructor raises for bad kwargs. It is
            logged first, then re-raised unchanged.
    """
    from collections.abc import Mapping

    if not isinstance(recipe, Mapping) or VisionTransformRecipeKeys.PIPELINE not in recipe:
        _LOGGER.error("Recipe dict is invalid: missing 'pipeline' key.")
        raise ValueError("Invalid recipe format.")

    steps = recipe[VisionTransformRecipeKeys.PIPELINE]
    # Guard against e.g. a string or dict, which would otherwise be iterated
    # element-by-element and fail later with an opaque AttributeError.
    if not isinstance(steps, (list, tuple)):
        _LOGGER.error(f"Recipe 'pipeline' must be a list of steps, got {type(steps).__name__}.")
        raise ValueError("Invalid recipe format.")

    pipeline_steps: List[Callable] = []

    for step in steps:
        if not isinstance(step, Mapping):
            _LOGGER.error(f"Invalid transform step, expected a dict: {step}")
            raise ValueError(f"Invalid transform step: {step!r}")

        t_name = step.get(VisionTransformRecipeKeys.NAME)
        if not t_name:
            # Best-effort behavior kept from the original: a nameless step is
            # logged and skipped rather than aborting the whole pipeline.
            _LOGGER.error(f"Invalid transform step, missing 'name': {step}")
            continue
        if not isinstance(t_name, str):
            # hasattr/getattr would raise TypeError on a non-string name.
            _LOGGER.error(f"Invalid transform step, 'name' must be a string: {step}")
            raise ValueError(f"Transform name must be a string, got {t_name!r}")

        # Recipes serialized to JSON may carry "kwargs": null; treat that as no kwargs.
        t_kwargs = step.get(VisionTransformRecipeKeys.KWARGS)
        if t_kwargs is None:
            t_kwargs = {}
        if not isinstance(t_kwargs, Mapping):
            _LOGGER.error(f"Invalid kwargs for transform '{t_name}': expected a dict, got {t_kwargs!r}")
            raise ValueError(f"Invalid kwargs for transform '{t_name}'.")

        # 1. Standard torchvision transforms (only callable attributes count, so
        #    e.g. the 'functional' submodule is never "instantiated").
        transform_class: Any = getattr(transforms, t_name, None)
        # 2. Custom transforms
        if not callable(transform_class):
            transform_class = TRANSFORM_REGISTRY.get(t_name)
        # 3. Not found
        if not callable(transform_class):
            _LOGGER.error(f"Unknown transform '{t_name}' in recipe. Not found in torchvision.transforms or TRANSFORM_REGISTRY.")
            raise ValueError(f"Unknown transform name: {t_name}")

        # Instantiate the transform
        try:
            pipeline_steps.append(transform_class(**t_kwargs))
        except Exception as e:
            _LOGGER.error(f"Failed to instantiate transform '{t_name}' with kwargs {t_kwargs}: {e}")
            raise

    return transforms.Compose(pipeline_steps)
98
+
99
+
100
def _default_augmentation_pipeline() -> transforms.Compose:
    """Return the strong, random augmentation pipeline used when no recipe is given."""
    return transforms.Compose([
        transforms.RandomResizedCrop(256, scale=(0.4, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=90),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.15),
        transforms.RandomPerspective(distortion_scale=0.2, p=0.4),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.RandomApply([
            transforms.GaussianBlur(kernel_size=3)
        ], p=0.3)
    ])


# User-facing format spellings that are not Pillow format identifiers.
_PIL_FORMAT_ALIASES: Dict[str, str] = {"TIF": "TIFF", "JPG": "JPEG"}
# Pillow writers that ignore `quality` (PNG, BMP) or reject it unless JPEG
# compression is selected (TIFF), so the option is not passed to them.
_FORMATS_WITHOUT_QUALITY = frozenset({"PNG", "BMP", "TIFF"})


def _build_save_kwargs(save_format: str, save_quality: int) -> Dict[str, Any]:
    """Translate the requested format/quality into keyword arguments for `Image.save`."""
    requested = save_format.upper()
    pil_format = _PIL_FORMAT_ALIASES.get(requested, requested)
    save_kwargs: Dict[str, Any] = {"format": pil_format, "optimize": True}
    if pil_format not in _FORMATS_WITHOUT_QUALITY:
        save_kwargs["quality"] = save_quality
    return save_kwargs


def _output_stems(image_files: List[Path]) -> Dict[Path, str]:
    """Map each source image to the stem used for its augmented outputs."""
    from collections import Counter

    stem_counts = Counter(p.stem for p in image_files)
    # "photo.jpg" and "photo.png" would otherwise both write "photo_aug_001.<ext>"
    # and silently overwrite each other. Only in that case is the source extension
    # appended; unique stems keep the plain name.
    return {
        p: p.stem if stem_counts[p.stem] == 1 else f"{p.stem}_{p.suffix.lstrip('.')}"
        for p in image_files
    }


def create_offline_augmentations(
    input_directory: Union[str, Path],
    output_directory: Union[str, Path],
    results_per_image: int,
    recipe: Optional[Dict[str, Any]] = None,
    save_format: Literal["WEBP", "JPEG", "PNG", "BMP", "TIF"] = "WEBP",
    save_quality: int = 80
) -> None:
    """
    Reads all valid images from an input directory, applies augmentations,
    and saves the new images to an output directory (offline augmentation).

    Skips subdirectories in the input path. Images are processed in sorted path
    order. Each output is named ``<stem>_aug_<NNN>.<ext>``. When several inputs
    share a stem (e.g. ``a.jpg`` and ``a.png``), the source extension is appended
    to the stem so that their outputs do not overwrite each other.

    Args:
        input_directory (Union[str, Path]): Path to the directory of source images.
        output_directory (Union[str, Path]): Path to save the augmented images.
        results_per_image (int): The number of augmented versions to create
            for each source image. Must be at least 1.
        recipe (Optional[Dict[str, Any]]): A transform recipe dictionary. If None,
            a default set of strong, random augmentations will be used. Any other
            value (including an empty dict) is built as a recipe. If building fails,
            the error is logged and nothing is written.
        save_format (str): The format to save images (e.g., "WEBP", "JPEG", "PNG").
            Defaults to "WEBP" for good compression. "TIF" is written with Pillow's
            TIFF writer, and "JPG" is accepted as an alias of "JPEG". The file
            extension is the requested format in lowercase.
        save_quality (int): The quality for lossy formats (1-100). Defaults to 80.
            Not passed to PNG, BMP or TIFF output.

    Raises:
        ValueError: If ``results_per_image`` is less than 1.

    Per-image read, transform or save failures are logged as warnings and skipped.
    """
    VALID_IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tif', '.tiff')

    # --- 0. Validate arguments (before any directory is created) ---
    if results_per_image < 1:
        _LOGGER.error(f"'results_per_image' must be at least 1, got {results_per_image}.")
        raise ValueError(f"results_per_image must be >= 1, got {results_per_image}")

    # --- 1. Validate Paths ---
    in_path = make_fullpath(input_directory, enforce="directory")
    out_path = make_fullpath(output_directory, make=True, enforce="directory")

    _LOGGER.info(f"Starting offline augmentation:\n\tInput: {in_path}\n\tOutput: {out_path}")

    # --- 2. Find Images ---
    # Listed (and sorted, for a deterministic order) before anything is written,
    # so outputs saved into the input directory are never picked up as sources.
    image_files = sorted(
        f for f in in_path.iterdir()
        if f.is_file() and f.suffix.lower() in VALID_IMG_EXTENSIONS
    )

    if not image_files:
        _LOGGER.warning(f"No valid image files found in {in_path}.")
        return

    _LOGGER.info(f"Found {len(image_files)} images to process.")

    # --- 3. Define Transform Pipeline ---
    transform_pipeline: transforms.Compose

    if recipe is not None:
        _LOGGER.info("Building transformations from provided recipe.")
        try:
            transform_pipeline = _build_transform_from_recipe(recipe)
        except Exception as e:
            _LOGGER.error(f"Failed to build transform from recipe: {e}")
            return
    else:
        _LOGGER.info("No recipe provided. Using default random augmentation pipeline.")
        transform_pipeline = _default_augmentation_pipeline()

    # --- 4. Process Images ---
    save_kwargs = _build_save_kwargs(save_format, save_quality)
    extension = save_format.lower()
    output_stems = _output_stems(image_files)
    total_saved = 0

    for img_path in image_files:
        _LOGGER.debug(f"Processing {img_path.name}...")
        try:
            # The context manager closes the source file handle. convert() returns
            # an independent, fully loaded image, so it stays usable afterwards.
            with Image.open(img_path) as source_image:
                original_image = source_image.convert("RGB")

            for i in range(results_per_image):
                new_stem = f"{output_stems[img_path]}_aug_{i+1:03d}"
                output_path = out_path / f"{new_stem}.{extension}"

                # Apply transform (random transforms give a new variant each call)
                transformed_image = transform_pipeline(original_image)

                transformed_image.save(output_path, **save_kwargs)
                total_saved += 1

        except Exception as e:
            _LOGGER.warning(f"Failed to process or save augmentations for {img_path.name}: {e}")

    _LOGGER.info(f"Offline augmentation complete. Saved {total_saved} new images.")
200
+
201
+
def info():
    """Delegate to ``_script_info`` with this module's exported names (``__all__``)."""
    exported_names = __all__
    _script_info(exported_names)
@@ -112,7 +112,7 @@ def evaluate_model_classification(
112
112
  report_df = pd.DataFrame(report_dict).iloc[:-1, :].T
113
113
  plt.figure(figsize=figsize)
114
114
  sns.heatmap(report_df, annot=True, cmap=heatmap_cmap, fmt='.2f',
115
- annot_kws={"size": base_fontsize - 4})
115
+ annot_kws={"size": base_fontsize - 4}, vmin=0.0, vmax=1.0)
116
116
  plt.title(f"{model_name} - {target_name}", fontsize=base_fontsize)
117
117
  plt.xticks(fontsize=base_fontsize - 2)
118
118
  plt.yticks(fontsize=base_fontsize - 2)
@@ -133,6 +133,7 @@ def evaluate_model_classification(
133
133
  normalize="true",
134
134
  ax=ax
135
135
  )
136
+ disp.im_.set_clim(vmin=0.0, vmax=1.0)
136
137
 
137
138
  ax.set_title(f"{model_name} - {target_name}", fontsize=base_fontsize)
138
139
  ax.tick_params(axis='both', labelsize=base_fontsize)
@@ -327,7 +328,8 @@ def plot_calibration_curve(
327
328
  target_name: str,
328
329
  figure_size: tuple = (10, 10),
329
330
  base_fontsize: int = 24,
330
- n_bins: int = 15
331
+ n_bins: int = 15,
332
+ line_color: str = 'darkorange'
331
333
  ) -> plt.Figure: # type: ignore
332
334
  """
333
335
  Plots the calibration curve (reliability diagram) for a classifier.
@@ -348,22 +350,63 @@ def plot_calibration_curve(
348
350
  """
349
351
  fig, ax = plt.subplots(figsize=figure_size)
350
352
 
351
- disp = CalibrationDisplay.from_estimator(
352
- model,
353
- x_test,
354
- y_test,
355
- n_bins=n_bins,
356
- ax=ax
353
+ # --- Step 1: Get probabilities from the estimator ---
354
+ # We do this manually so we can pass them to from_predictions
355
+ try:
356
+ y_prob = model.predict_proba(x_test)
357
+ # Use probabilities for the positive class (assuming binary)
358
+ y_score = y_prob[:, 1]
359
+ except Exception as e:
360
+ _LOGGER.error(f"Could not get probabilities from model: {e}")
361
+ plt.close(fig)
362
+ return fig # Return empty figure
363
+
364
+ # --- Step 2: Get binned data *without* plotting ---
365
+ with plt.ioff():
366
+ fig_temp, ax_temp = plt.subplots()
367
+ cal_display_temp = CalibrationDisplay.from_predictions(
368
+ y_test,
369
+ y_score,
370
+ n_bins=n_bins,
371
+ ax=ax_temp,
372
+ name="temp"
373
+ )
374
+ line_x, line_y = cal_display_temp.line_.get_data() # type: ignore
375
+ plt.close(fig_temp)
376
+
377
+ # --- Step 3: Build the plot from scratch on ax ---
378
+
379
+ # 3a. Plot the ideal diagonal line
380
+ ax.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
381
+
382
+ # 3b. Use regplot for the regression line and its CI
383
+ sns.regplot(
384
+ x=line_x,
385
+ y=line_y,
386
+ ax=ax,
387
+ scatter=False, # No scatter dots
388
+ label=f"Calibration Curve ({n_bins} bins)",
389
+ line_kws={
390
+ 'color': line_color,
391
+ 'linestyle': '--',
392
+ 'linewidth': 2
393
+ }
357
394
  )
358
395
 
396
+ # --- Step 4: Apply original formatting ---
359
397
  ax.set_title(f"{model_name} - Reliability Curve for {target_name}", fontsize=base_fontsize)
360
398
  ax.tick_params(axis='both', labelsize=base_fontsize - 2)
361
399
  ax.set_xlabel("Mean Predicted Probability", fontsize=base_fontsize)
362
400
  ax.set_ylabel("Fraction of Positives", fontsize=base_fontsize)
363
- ax.legend(fontsize=base_fontsize - 4)
401
+
402
+ # Set limits
403
+ ax.set_ylim(0.0, 1.0)
404
+ ax.set_xlim(0.0, 1.0)
405
+
406
+ ax.legend(fontsize=base_fontsize - 4, loc='lower right')
364
407
  fig.tight_layout()
365
408
 
366
- # Save figure
409
+ # --- Step 5: Save figure (using original logic) ---
367
410
  save_path = make_fullpath(save_dir, make=True)
368
411
  sanitized_target_name = sanitize_filename(target_name)
369
412
  full_save_path = save_path / f"Calibration_Plot_{sanitized_target_name}.svg"
ml_tools/keys.py CHANGED
@@ -104,8 +104,9 @@ class VisionTransformRecipeKeys:
104
104
  TASK = "task"
105
105
  PIPELINE = "pipeline"
106
106
  NAME = "name"
107
- KWARGS = "_kwargs"
107
+ KWARGS = "kwargs"
108
108
  PRE_TRANSFORMS = "pre_transforms"
109
+
109
110
  RESIZE_SIZE = "resize_size"
110
111
  CROP_SIZE = "crop_size"
111
112
  MEAN = "mean"