dragon-ml-toolbox 14.3.1__py3-none-any.whl → 16.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox has been flagged as potentially problematic; see the package's registry page for more details.
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/METADATA +10 -5
- dragon_ml_toolbox-16.0.0.dist-info/RECORD +51 -0
- ml_tools/ETL_cleaning.py +20 -20
- ml_tools/ETL_engineering.py +23 -25
- ml_tools/GUI_tools.py +20 -20
- ml_tools/MICE_imputation.py +3 -3
- ml_tools/ML_callbacks.py +43 -26
- ml_tools/ML_configuration.py +309 -0
- ml_tools/ML_datasetmaster.py +220 -260
- ml_tools/ML_evaluation.py +317 -81
- ml_tools/ML_evaluation_multi.py +127 -36
- ml_tools/ML_inference.py +249 -207
- ml_tools/ML_models.py +13 -102
- ml_tools/ML_models_advanced.py +1 -1
- ml_tools/ML_optimization.py +12 -12
- ml_tools/ML_scaler.py +11 -11
- ml_tools/ML_sequence_datasetmaster.py +341 -0
- ml_tools/ML_sequence_evaluation.py +215 -0
- ml_tools/ML_sequence_inference.py +391 -0
- ml_tools/ML_sequence_models.py +139 -0
- ml_tools/ML_trainer.py +1247 -338
- ml_tools/ML_utilities.py +51 -2
- ml_tools/ML_vision_datasetmaster.py +262 -118
- ml_tools/ML_vision_evaluation.py +26 -6
- ml_tools/ML_vision_inference.py +117 -140
- ml_tools/ML_vision_models.py +15 -1
- ml_tools/ML_vision_transformers.py +233 -7
- ml_tools/PSO_optimization.py +6 -6
- ml_tools/SQL.py +4 -4
- ml_tools/{keys.py → _keys.py} +45 -1
- ml_tools/_schema.py +1 -1
- ml_tools/ensemble_evaluation.py +54 -11
- ml_tools/ensemble_inference.py +7 -33
- ml_tools/ensemble_learning.py +1 -1
- ml_tools/optimization_tools.py +2 -2
- ml_tools/path_manager.py +5 -5
- ml_tools/utilities.py +1 -2
- dragon_ml_toolbox-14.3.1.dist-info/RECORD +0 -48
- ml_tools/RNN_forecast.py +0 -56
- ml_tools/_ML_vision_recipe.py +0 -88
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/top_level.txt +0 -0
|
@@ -1,14 +1,19 @@
|
|
|
1
|
-
from typing import Union, Dict, Type, Callable
|
|
1
|
+
from typing import Union, Dict, Type, Callable, Optional, Any, List, Literal
|
|
2
2
|
from PIL import ImageOps, Image
|
|
3
|
+
from torchvision import transforms
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
import json
|
|
3
6
|
|
|
4
7
|
from ._logger import _LOGGER
|
|
5
8
|
from ._script_info import _script_info
|
|
6
|
-
from .
|
|
9
|
+
from ._keys import VisionTransformRecipeKeys
|
|
10
|
+
from .path_manager import make_fullpath
|
|
7
11
|
|
|
8
12
|
|
|
9
13
|
__all__ = [
|
|
10
14
|
"TRANSFORM_REGISTRY",
|
|
11
|
-
"ResizeAspectFill"
|
|
15
|
+
"ResizeAspectFill",
|
|
16
|
+
"create_offline_augmentations"
|
|
12
17
|
]
|
|
13
18
|
|
|
14
19
|
# --- Custom Vision Transform Class ---
|
|
@@ -23,9 +28,8 @@ class ResizeAspectFill:
|
|
|
23
28
|
"""
|
|
24
29
|
def __init__(self, pad_color: Union[str, int] = "black") -> None:
|
|
25
30
|
self.pad_color = pad_color
|
|
26
|
-
# Store kwargs to allow for
|
|
31
|
+
# Store kwargs to allow for re-creation
|
|
27
32
|
self.__setattr__(VisionTransformRecipeKeys.KWARGS, {"pad_color": pad_color})
|
|
28
|
-
# self._kwargs = {"pad_color": pad_color}
|
|
29
33
|
|
|
30
34
|
def __call__(self, image: Image.Image) -> Image.Image:
|
|
31
35
|
if not isinstance(image, Image.Image):
|
|
@@ -47,12 +51,234 @@ class ResizeAspectFill:
|
|
|
47
51
|
padding = (left_padding, 0, right_padding, 0)
|
|
48
52
|
|
|
49
53
|
return ImageOps.expand(image, padding, fill=self.pad_color)
|
|
50
|
-
|
|
51
54
|
|
|
52
|
-
|
|
55
|
+
|
|
56
|
+
#############################################################
|
|
57
|
+
#NOTE: Add custom transforms.
|
|
53
58
|
TRANSFORM_REGISTRY: Dict[str, Type[Callable]] = dict(
    # Recipe name -> custom transform class. Recipe step names that are not
    # found in torchvision.transforms fall back to this registry.
    ResizeAspectFill=ResizeAspectFill,
)
|
|
61
|
+
#############################################################
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def create_offline_augmentations(
    input_directory: Union[str, Path],
    output_directory: Union[str, Path],
    results_per_image: int,
    recipe: Optional[Dict[str, Any]] = None,
    save_format: Literal["WEBP", "JPEG", "PNG", "BMP", "TIF"] = "WEBP",
    save_quality: int = 80
) -> None:
    """
    Reads all valid images from an input directory, applies augmentations,
    and saves the new images to an output directory (offline augmentation).

    Only files directly inside `input_directory` with one of the extensions
    .jpg, .jpeg, .png, .bmp, .webp, .tif or .tiff are processed (case-insensitive);
    subdirectories are skipped. Files are processed in sorted path order. Every
    image is converted to RGB before augmentation, and each result is written as
    `<source stem>_aug_<NNN>.<format>`, where NNN is the 1-based, zero-padded
    index of the augmented copy.

    Args:
        input_directory (Union[str, Path]): Path to the directory of source images.
        output_directory (Union[str, Path]): Path to save the augmented images.
        results_per_image (int): The number of augmented versions to create
            for each source image.
        recipe (Optional[Dict[str, Any]]): A transform recipe dictionary. If None
            (or empty), a default set of strong, random augmentations is used.
            The pipeline must return a PIL image; pipelines that end in a
            tensor conversion cannot be saved and are reported per image.
        save_format (str): The format to save images ("WEBP", "JPEG", "PNG",
            "BMP", "TIF"; case-insensitive). Defaults to "WEBP" for good
            compression. "TIF" is written with Pillow's "TIFF" encoder and a
            ".tif" file extension.
        save_quality (int): The quality for lossy formats (1-100). Defaults to 80.

    Notes:
        A recipe that cannot be built is logged and ends the run before any
        image is processed. Failures on individual images are logged as
        warnings and do not stop the run.
    """
    VALID_IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tif', '.tiff')
    # Pillow registers its TIFF encoder as "TIFF"; "TIF" is only a file extension,
    # so passing it through unchanged makes every save fail.
    PIL_FORMAT_ALIASES = {"TIF": "TIFF"}

    # --- 1. Validate Paths ---
    in_path = make_fullpath(input_directory, enforce="directory")
    out_path = make_fullpath(output_directory, make=True, enforce="directory")

    _LOGGER.info(f"Starting offline augmentation:\n\tInput: {in_path}\n\tOutput: {out_path}")

    # --- 2. Find Images ---
    # Sorted so runs process (and log) files in a reproducible order.
    image_files = sorted(
        f for f in in_path.iterdir()
        if f.is_file() and f.suffix.lower() in VALID_IMG_EXTENSIONS
    )

    if not image_files:
        _LOGGER.warning(f"No valid image files found in {in_path}.")
        return

    _LOGGER.info(f"Found {len(image_files)} images to process.")

    # --- 3. Define Transform Pipeline ---
    transform_pipeline: transforms.Compose

    if recipe:
        _LOGGER.info("Building transformations from provided recipe.")
        try:
            transform_pipeline = _build_transform_from_recipe(recipe)
        except Exception as e:
            _LOGGER.error(f"Failed to build transform from recipe: {e}")
            return
    else:
        _LOGGER.info("No recipe provided. Using default random augmentation pipeline.")
        # Default "random" pipeline
        transform_pipeline = transforms.Compose([
            transforms.RandomResizedCrop(256, scale=(0.4, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=90),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.15),
            transforms.RandomPerspective(distortion_scale=0.2, p=0.4),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            transforms.RandomApply([
                transforms.GaussianBlur(kernel_size=3)
            ], p=0.3)
        ])

    # --- 4. Process Images ---
    total_saved = 0
    format_upper = save_format.upper()
    pil_format = PIL_FORMAT_ALIASES.get(format_upper, format_upper)
    file_extension = format_upper.lower()

    for img_path in image_files:
        _LOGGER.debug(f"Processing {img_path.name}...")
        try:
            # Context manager releases the source file handle once the
            # independent RGB copy has been made.
            with Image.open(img_path) as source_image:
                original_image = source_image.convert("RGB")

            for i in range(results_per_image):
                new_stem = f"{img_path.stem}_aug_{i+1:03d}"
                output_path = out_path / f"{new_stem}.{file_extension}"

                # Apply transform
                transformed_image = transform_pipeline(original_image)
                if not isinstance(transformed_image, Image.Image):
                    # e.g. a recipe ending in ToTensor/Normalize; report clearly
                    # instead of failing on a missing `.save` attribute.
                    raise TypeError(
                        f"transform pipeline returned {type(transformed_image).__name__}, "
                        "expected a PIL image"
                    )

                # Save
                transformed_image.save(
                    output_path,
                    format=pil_format,
                    quality=save_quality,
                    optimize=True
                )
                total_saved += 1

        except Exception as e:
            _LOGGER.warning(f"Failed to process or save augmentations for {img_path.name}: {e}")

    _LOGGER.info(f"Offline augmentation complete. Saved {total_saved} new images.")
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _build_transform_from_recipe(recipe: Dict[str, Any]) -> transforms.Compose:
    """
    Internal helper to build a transform pipeline from a recipe dict.

    `recipe[VisionTransformRecipeKeys.PIPELINE]` is iterated in order. Each step
    must be a dict holding a transform name under `VisionTransformRecipeKeys.NAME`
    and, optionally, constructor kwargs under `VisionTransformRecipeKeys.KWARGS`
    (a missing or null entry means "no kwargs"). Names are resolved first as
    classes in `torchvision.transforms`, then in `TRANSFORM_REGISTRY`.

    Args:
        recipe (Dict[str, Any]): The recipe dictionary.

    Returns:
        transforms.Compose: The instantiated transforms, in recipe order.

    Raises:
        ValueError: If the 'pipeline' key is missing, a step is not a dict,
            a step has no name, or a name cannot be resolved.
        Exception: Any error raised by a transform's constructor is logged
            and re-raised unchanged.
    """
    if VisionTransformRecipeKeys.PIPELINE not in recipe:
        _LOGGER.error("Recipe dict is invalid: missing 'pipeline' key.")
        raise ValueError("Invalid recipe format.")

    pipeline_steps: List[Callable] = []

    for step in recipe[VisionTransformRecipeKeys.PIPELINE]:
        if not isinstance(step, dict):
            _LOGGER.error(f"Invalid transform step, expected a dict: {step}")
            raise ValueError(f"Invalid transform step: {step}")

        t_name = step.get(VisionTransformRecipeKeys.NAME)
        # `or {}` also covers an explicit null kwargs entry, which `**` cannot unpack.
        t_kwargs = step.get(VisionTransformRecipeKeys.KWARGS) or {}

        if not t_name or not isinstance(t_name, str):
            # Dropping the step would silently yield a pipeline that no longer
            # matches the recipe; fail loudly, like unknown names do.
            _LOGGER.error(f"Invalid transform step, missing 'name': {step}")
            raise ValueError(f"Invalid transform step, missing 'name': {step}")

        # 1. Check standard torchvision transforms. Only classes are accepted, so
        #    non-class attributes of the module (e.g. helper submodules) are ignored.
        transform_class: Any = getattr(transforms, t_name, None)
        if not isinstance(transform_class, type):
            # 2. Check custom transforms
            transform_class = TRANSFORM_REGISTRY.get(t_name)
        # 3. Not found
        if transform_class is None:
            _LOGGER.error(f"Unknown transform '{t_name}' in recipe. Not found in torchvision.transforms or TRANSFORM_REGISTRY.")
            raise ValueError(f"Unknown transform name: {t_name}")

        # Instantiate the transform
        try:
            pipeline_steps.append(transform_class(**t_kwargs))
        except Exception as e:
            _LOGGER.error(f"Failed to instantiate transform '{t_name}' with kwargs {t_kwargs}: {e}")
            raise

    return transforms.Compose(pipeline_steps)
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def _save_recipe(recipe: Dict[str, Any], filepath: Union[str, Path]) -> None:
    """
    Saves a transform recipe dictionary to a JSON file.

    The final suffix of `filepath` is replaced with ".json" (or added if missing).

    Args:
        recipe (Dict[str, Any]): The recipe dictionary to save. Must be JSON-serializable.
        filepath (Union[str, Path]): The path to the output .json file.

    Raises:
        Exception: Serialization or file-writing errors are logged and re-raised.
    """
    # Accept plain strings as well; `with_suffix` is a Path method.
    final_filepath = Path(filepath).with_suffix(".json")

    try:
        # Serialize before touching the file so a non-serializable recipe fails
        # without leaving a truncated JSON file behind.
        payload = json.dumps(recipe, indent=4)
        final_filepath.write_text(payload, encoding="utf-8")
        _LOGGER.info(f"Transform recipe saved as '{final_filepath.name}'.")
    except Exception as e:
        _LOGGER.error(f"Failed to save recipe to '{final_filepath}': {e}")
        raise
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def _load_recipe_and_build_transform(filepath: Union[str,Path]) -> transforms.Compose:
    """
    Loads a transform recipe from a .json file and reconstructs the
    torchvision.transforms.Compose pipeline.

    Pipeline construction is delegated to `_build_transform_from_recipe`, so
    recipes loaded from disk and recipes passed in memory are resolved by the
    same rules.

    Args:
        filepath (Union[str, Path]): Path to the saved transform recipe .json file.

    Returns:
        transforms.Compose: The reconstructed transformation pipeline.

    Raises:
        ValueError: If the file's top-level JSON value is not an object with a
            'pipeline' key, or a transform name in the recipe is not found in
            torchvision.transforms or the custom TRANSFORM_REGISTRY.
        Exception: Errors reading or parsing the file are logged and re-raised.
    """
    # validate filepath
    final_filepath = make_fullpath(filepath, enforce="file")

    try:
        # JSON text is UTF-8 by specification; don't depend on the platform default.
        with open(final_filepath, 'r', encoding='utf-8') as f:
            recipe = json.load(f)
    except Exception as e:
        _LOGGER.error(f"Failed to load recipe from '{final_filepath}': {e}")
        raise

    # A JSON list/scalar at the top level is not a recipe either.
    if not isinstance(recipe, dict) or VisionTransformRecipeKeys.PIPELINE not in recipe:
        _LOGGER.error("Recipe file is invalid: missing 'pipeline' key.")
        raise ValueError("Invalid recipe format.")

    transform_pipeline = _build_transform_from_recipe(recipe)

    _LOGGER.info(f"Successfully loaded and built transform pipeline from '{final_filepath.name}'.")
    return transform_pipeline
|
|
281
|
+
|
|
56
282
|
|
|
57
283
|
def info():
    """Report this module's public names (``__all__``) via ``_script_info``."""
    public_names = __all__
    _script_info(public_names)
|
ml_tools/PSO_optimization.py
CHANGED
|
@@ -12,9 +12,9 @@ from .serde import deserialize_object
|
|
|
12
12
|
from .math_utilities import threshold_binary_values, threshold_binary_values_batch
|
|
13
13
|
from .path_manager import sanitize_filename, make_fullpath, list_files_by_extension
|
|
14
14
|
from ._logger import _LOGGER
|
|
15
|
-
from .
|
|
15
|
+
from ._keys import EnsembleKeys
|
|
16
16
|
from ._script_info import _script_info
|
|
17
|
-
from .SQL import
|
|
17
|
+
from .SQL import DragonSQL
|
|
18
18
|
from .optimization_tools import _save_result
|
|
19
19
|
|
|
20
20
|
"""
|
|
@@ -191,7 +191,7 @@ def _set_feature_names(size: int, names: Union[list[str], None]):
|
|
|
191
191
|
return names
|
|
192
192
|
|
|
193
193
|
|
|
194
|
-
def _run_single_pso(objective_function: ObjectiveFunction, pso_args: dict, feature_names: list[str], target_name: str, random_state: int, save_format: Literal['csv', 'sqlite', 'both'], csv_path: Path, db_manager: Optional[
|
|
194
|
+
def _run_single_pso(objective_function: ObjectiveFunction, pso_args: dict, feature_names: list[str], target_name: str, random_state: int, save_format: Literal['csv', 'sqlite', 'both'], csv_path: Path, db_manager: Optional[DragonSQL], db_table_name: str):
|
|
195
195
|
"""Helper for a single PSO run that also handles saving."""
|
|
196
196
|
pso_args.update({"seed": random_state})
|
|
197
197
|
|
|
@@ -213,7 +213,7 @@ def _run_single_pso(objective_function: ObjectiveFunction, pso_args: dict, featu
|
|
|
213
213
|
return best_features_named, best_target_named
|
|
214
214
|
|
|
215
215
|
|
|
216
|
-
def _run_post_hoc_pso(objective_function: ObjectiveFunction, pso_args: dict, feature_names: list[str], target_name: str, repetitions: int, save_format: Literal['csv', 'sqlite', 'both'], csv_path: Path, db_manager: Optional[
|
|
216
|
+
def _run_post_hoc_pso(objective_function: ObjectiveFunction, pso_args: dict, feature_names: list[str], target_name: str, repetitions: int, save_format: Literal['csv', 'sqlite', 'both'], csv_path: Path, db_manager: Optional[DragonSQL], db_table_name: str):
|
|
217
217
|
"""Helper for post-hoc analysis that saves results incrementally."""
|
|
218
218
|
progress = trange(repetitions, desc="Post-Hoc PSO", unit="run")
|
|
219
219
|
for _ in progress:
|
|
@@ -342,7 +342,7 @@ def run_pso(lower_boundaries: list[float],
|
|
|
342
342
|
schema = {"result_id": "INTEGER PRIMARY KEY AUTOINCREMENT", **schema}
|
|
343
343
|
|
|
344
344
|
# Create table
|
|
345
|
-
with
|
|
345
|
+
with DragonSQL(db_path) as db:
|
|
346
346
|
db.create_table(db_table_name, schema)
|
|
347
347
|
|
|
348
348
|
pso_arguments = {
|
|
@@ -357,7 +357,7 @@ def run_pso(lower_boundaries: list[float],
|
|
|
357
357
|
|
|
358
358
|
# --- Dispatcher ---
|
|
359
359
|
# Use a real or dummy context manager to handle the DB connection cleanly
|
|
360
|
-
db_context =
|
|
360
|
+
db_context = DragonSQL(db_path) if save_format in ['sqlite', 'both'] else nullcontext()
|
|
361
361
|
|
|
362
362
|
with db_context as db_manager:
|
|
363
363
|
if post_hoc_analysis is None or post_hoc_analysis <= 1:
|
ml_tools/SQL.py
CHANGED
|
@@ -9,11 +9,11 @@ from .path_manager import make_fullpath, sanitize_filename
|
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
__all__ = [
|
|
12
|
-
"
|
|
12
|
+
"DragonSQL",
|
|
13
13
|
]
|
|
14
14
|
|
|
15
15
|
|
|
16
|
-
class
|
|
16
|
+
class DragonSQL:
|
|
17
17
|
"""
|
|
18
18
|
A user-friendly context manager for handling SQLite database operations.
|
|
19
19
|
|
|
@@ -35,7 +35,7 @@ class DatabaseManager:
|
|
|
35
35
|
... "feature_a": "REAL",
|
|
36
36
|
... "score": "REAL"
|
|
37
37
|
... }
|
|
38
|
-
>>> with
|
|
38
|
+
>>> with DragonSQL("my_results.db") as db:
|
|
39
39
|
... db.create_table("experiments", schema)
|
|
40
40
|
... data = {"run_name": "first_run", "feature_a": 0.123, "score": 95.5}
|
|
41
41
|
... db.insert_row("experiments", data)
|
|
@@ -43,7 +43,7 @@ class DatabaseManager:
|
|
|
43
43
|
... print(df)
|
|
44
44
|
"""
|
|
45
45
|
def __init__(self, db_path: Union[str, Path]):
|
|
46
|
-
"""Initializes the
|
|
46
|
+
"""Initializes the DragonSQL with the path to the database file."""
|
|
47
47
|
if isinstance(db_path, str):
|
|
48
48
|
if not db_path.endswith(".db"):
|
|
49
49
|
db_path = db_path + ".db"
|
ml_tools/{keys.py → _keys.py}
RENAMED
|
@@ -1,3 +1,10 @@
|
|
|
1
|
+
class MagicWords:
    """General purpose string keywords."""

    LATEST: str = "latest"
    CURRENT: str = "current"
    RENAME: str = "rename"
|
|
6
|
+
|
|
7
|
+
|
|
1
8
|
class PyTorchLogKeys:
|
|
2
9
|
"""
|
|
3
10
|
Used internally for ML scripts module.
|
|
@@ -7,6 +14,7 @@ class PyTorchLogKeys:
|
|
|
7
14
|
# --- Epoch Level ---
|
|
8
15
|
TRAIN_LOSS = 'train_loss'
|
|
9
16
|
VAL_LOSS = 'val_loss'
|
|
17
|
+
LEARNING_RATE = 'lr'
|
|
10
18
|
|
|
11
19
|
# --- Batch Level ---
|
|
12
20
|
BATCH_LOSS = 'loss'
|
|
@@ -79,6 +87,13 @@ class PyTorchCheckpointKeys:
|
|
|
79
87
|
SCHEDULER_STATE = "scheduler_state_dict"
|
|
80
88
|
EPOCH = "epoch"
|
|
81
89
|
BEST_SCORE = "best_score"
|
|
90
|
+
HISTORY = "history"
|
|
91
|
+
CHECKPOINT_NAME = "PyModelCheckpoint"
|
|
92
|
+
# Finalized config
|
|
93
|
+
CLASSIFICATION_THRESHOLD = "classification_threshold"
|
|
94
|
+
CLASS_MAP = "class_map"
|
|
95
|
+
SEQUENCE_LENGTH = "sequence_length"
|
|
96
|
+
INITIAL_SEQUENCE = "initial_sequence"
|
|
82
97
|
|
|
83
98
|
|
|
84
99
|
class UtilityKeys:
|
|
@@ -104,8 +119,9 @@ class VisionTransformRecipeKeys:
|
|
|
104
119
|
TASK = "task"
|
|
105
120
|
PIPELINE = "pipeline"
|
|
106
121
|
NAME = "name"
|
|
107
|
-
KWARGS = "
|
|
122
|
+
KWARGS = "kwargs"
|
|
108
123
|
PRE_TRANSFORMS = "pre_transforms"
|
|
124
|
+
|
|
109
125
|
RESIZE_SIZE = "resize_size"
|
|
110
126
|
CROP_SIZE = "crop_size"
|
|
111
127
|
MEAN = "mean"
|
|
@@ -118,6 +134,34 @@ class ObjectDetectionKeys:
|
|
|
118
134
|
LABELS = "labels"
|
|
119
135
|
|
|
120
136
|
|
|
137
|
+
class MLTaskKeys:
    """Used by the Trainer and InferenceHandlers"""

    # --- Regression ---
    REGRESSION: str = "regression"
    MULTITARGET_REGRESSION: str = "multitarget regression"

    # --- Classification ---
    BINARY_CLASSIFICATION: str = "binary classification"
    MULTICLASS_CLASSIFICATION: str = "multiclass classification"
    MULTILABEL_BINARY_CLASSIFICATION: str = "multilabel binary classification"

    # --- Image classification ---
    BINARY_IMAGE_CLASSIFICATION: str = "binary image classification"
    MULTICLASS_IMAGE_CLASSIFICATION: str = "multiclass image classification"

    # --- Segmentation ---
    BINARY_SEGMENTATION: str = "binary segmentation"
    MULTICLASS_SEGMENTATION: str = "multiclass segmentation"

    # --- Detection ---
    OBJECT_DETECTION: str = "object detection"

    # --- Sequence ---
    SEQUENCE_SEQUENCE: str = "sequence-to-sequence"
    SEQUENCE_VALUE: str = "sequence-to-value"

    # Every task key above that denotes a binary task, in this fixed order.
    ALL_BINARY_TASKS = [
        BINARY_CLASSIFICATION,
        MULTILABEL_BINARY_CLASSIFICATION,
        BINARY_IMAGE_CLASSIFICATION,
        BINARY_SEGMENTATION,
    ]
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class DragonTrainerKeys:
    """Directory-name keys for the validation and test metrics outputs."""

    VALIDATION_METRICS_DIR: str = "Validation_Metrics"
    TEST_METRICS_DIR: str = "Test_Metrics"
|
|
163
|
+
|
|
164
|
+
|
|
121
165
|
class _OneHotOtherPlaceholder:
|
|
122
166
|
"""Used internally by GUI_tools."""
|
|
123
167
|
OTHER_GUI = "OTHER"
|
ml_tools/_schema.py
CHANGED
ml_tools/ensemble_evaluation.py
CHANGED
|
@@ -25,7 +25,7 @@ from typing import Union, Optional, Literal
|
|
|
25
25
|
from .path_manager import sanitize_filename, make_fullpath
|
|
26
26
|
from ._script_info import _script_info
|
|
27
27
|
from ._logger import _LOGGER
|
|
28
|
-
from .
|
|
28
|
+
from ._keys import SHAPKeys
|
|
29
29
|
|
|
30
30
|
|
|
31
31
|
__all__ = [
|
|
@@ -112,7 +112,7 @@ def evaluate_model_classification(
|
|
|
112
112
|
report_df = pd.DataFrame(report_dict).iloc[:-1, :].T
|
|
113
113
|
plt.figure(figsize=figsize)
|
|
114
114
|
sns.heatmap(report_df, annot=True, cmap=heatmap_cmap, fmt='.2f',
|
|
115
|
-
annot_kws={"size": base_fontsize - 4})
|
|
115
|
+
annot_kws={"size": base_fontsize - 4}, vmin=0.0, vmax=1.0)
|
|
116
116
|
plt.title(f"{model_name} - {target_name}", fontsize=base_fontsize)
|
|
117
117
|
plt.xticks(fontsize=base_fontsize - 2)
|
|
118
118
|
plt.yticks(fontsize=base_fontsize - 2)
|
|
@@ -133,6 +133,7 @@ def evaluate_model_classification(
|
|
|
133
133
|
normalize="true",
|
|
134
134
|
ax=ax
|
|
135
135
|
)
|
|
136
|
+
disp.im_.set_clim(vmin=0.0, vmax=1.0)
|
|
136
137
|
|
|
137
138
|
ax.set_title(f"{model_name} - {target_name}", fontsize=base_fontsize)
|
|
138
139
|
ax.tick_params(axis='both', labelsize=base_fontsize)
|
|
@@ -327,7 +328,8 @@ def plot_calibration_curve(
|
|
|
327
328
|
target_name: str,
|
|
328
329
|
figure_size: tuple = (10, 10),
|
|
329
330
|
base_fontsize: int = 24,
|
|
330
|
-
n_bins: int = 15
|
|
331
|
+
n_bins: int = 15,
|
|
332
|
+
line_color: str = 'darkorange'
|
|
331
333
|
) -> plt.Figure: # type: ignore
|
|
332
334
|
"""
|
|
333
335
|
Plots the calibration curve (reliability diagram) for a classifier.
|
|
@@ -348,22 +350,63 @@ def plot_calibration_curve(
|
|
|
348
350
|
"""
|
|
349
351
|
fig, ax = plt.subplots(figsize=figure_size)
|
|
350
352
|
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
353
|
+
# --- Step 1: Get probabilities from the estimator ---
|
|
354
|
+
# We do this manually so we can pass them to from_predictions
|
|
355
|
+
try:
|
|
356
|
+
y_prob = model.predict_proba(x_test)
|
|
357
|
+
# Use probabilities for the positive class (assuming binary)
|
|
358
|
+
y_score = y_prob[:, 1]
|
|
359
|
+
except Exception as e:
|
|
360
|
+
_LOGGER.error(f"Could not get probabilities from model: {e}")
|
|
361
|
+
plt.close(fig)
|
|
362
|
+
return fig # Return empty figure
|
|
363
|
+
|
|
364
|
+
# --- Step 2: Get binned data *without* plotting ---
|
|
365
|
+
with plt.ioff():
|
|
366
|
+
fig_temp, ax_temp = plt.subplots()
|
|
367
|
+
cal_display_temp = CalibrationDisplay.from_predictions(
|
|
368
|
+
y_test,
|
|
369
|
+
y_score,
|
|
370
|
+
n_bins=n_bins,
|
|
371
|
+
ax=ax_temp,
|
|
372
|
+
name="temp"
|
|
373
|
+
)
|
|
374
|
+
line_x, line_y = cal_display_temp.line_.get_data() # type: ignore
|
|
375
|
+
plt.close(fig_temp)
|
|
376
|
+
|
|
377
|
+
# --- Step 3: Build the plot from scratch on ax ---
|
|
378
|
+
|
|
379
|
+
# 3a. Plot the ideal diagonal line
|
|
380
|
+
ax.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
|
|
381
|
+
|
|
382
|
+
# 3b. Use regplot for the regression line and its CI
|
|
383
|
+
sns.regplot(
|
|
384
|
+
x=line_x,
|
|
385
|
+
y=line_y,
|
|
386
|
+
ax=ax,
|
|
387
|
+
scatter=False, # No scatter dots
|
|
388
|
+
label=f"Calibration Curve ({n_bins} bins)",
|
|
389
|
+
line_kws={
|
|
390
|
+
'color': line_color,
|
|
391
|
+
'linestyle': '--',
|
|
392
|
+
'linewidth': 2
|
|
393
|
+
}
|
|
357
394
|
)
|
|
358
395
|
|
|
396
|
+
# --- Step 4: Apply original formatting ---
|
|
359
397
|
ax.set_title(f"{model_name} - Reliability Curve for {target_name}", fontsize=base_fontsize)
|
|
360
398
|
ax.tick_params(axis='both', labelsize=base_fontsize - 2)
|
|
361
399
|
ax.set_xlabel("Mean Predicted Probability", fontsize=base_fontsize)
|
|
362
400
|
ax.set_ylabel("Fraction of Positives", fontsize=base_fontsize)
|
|
363
|
-
|
|
401
|
+
|
|
402
|
+
# Set limits
|
|
403
|
+
ax.set_ylim(0.0, 1.0)
|
|
404
|
+
ax.set_xlim(0.0, 1.0)
|
|
405
|
+
|
|
406
|
+
ax.legend(fontsize=base_fontsize - 4, loc='lower right')
|
|
364
407
|
fig.tight_layout()
|
|
365
408
|
|
|
366
|
-
# Save figure
|
|
409
|
+
# --- Step 5: Save figure (using original logic) ---
|
|
367
410
|
save_path = make_fullpath(save_dir, make=True)
|
|
368
411
|
sanitized_target_name = sanitize_filename(target_name)
|
|
369
412
|
full_save_path = save_path / f"Calibration_Plot_{sanitized_target_name}.svg"
|
ml_tools/ensemble_inference.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
from typing import Union, Literal, Dict, Any, Optional, List
|
|
2
2
|
from pathlib import Path
|
|
3
3
|
import json
|
|
4
|
-
import joblib
|
|
5
4
|
import numpy as np
|
|
6
5
|
# Inference models
|
|
7
6
|
import xgboost
|
|
@@ -10,16 +9,17 @@ import lightgbm
|
|
|
10
9
|
from ._script_info import _script_info
|
|
11
10
|
from ._logger import _LOGGER
|
|
12
11
|
from .path_manager import make_fullpath, list_files_by_extension
|
|
13
|
-
from .
|
|
12
|
+
from ._keys import EnsembleKeys
|
|
13
|
+
from .serde import deserialize_object
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
__all__ = [
|
|
17
|
-
"
|
|
17
|
+
"DragonEnsembleInferenceHandler",
|
|
18
18
|
"model_report"
|
|
19
19
|
]
|
|
20
20
|
|
|
21
21
|
|
|
22
|
-
class
|
|
22
|
+
class DragonEnsembleInferenceHandler:
|
|
23
23
|
"""
|
|
24
24
|
Handles loading ensemble models and performing inference for either regression or classification tasks.
|
|
25
25
|
"""
|
|
@@ -44,9 +44,9 @@ class InferenceHandler:
|
|
|
44
44
|
for fname, fpath in model_files.items():
|
|
45
45
|
try:
|
|
46
46
|
full_object: dict
|
|
47
|
-
full_object =
|
|
47
|
+
full_object = deserialize_object(filepath=fpath,
|
|
48
48
|
verbose=self.verbose,
|
|
49
|
-
|
|
49
|
+
expected_type=dict)
|
|
50
50
|
|
|
51
51
|
model: Any = full_object[EnsembleKeys.MODEL]
|
|
52
52
|
target_name: str = full_object[EnsembleKeys.TARGET]
|
|
@@ -170,7 +170,7 @@ def model_report(
|
|
|
170
170
|
|
|
171
171
|
# --- 2. Deserialize and Extract Info ---
|
|
172
172
|
try:
|
|
173
|
-
full_object: dict =
|
|
173
|
+
full_object: dict = deserialize_object(model_p, expected_type=dict, verbose=verbose) # type: ignore
|
|
174
174
|
model = full_object[EnsembleKeys.MODEL]
|
|
175
175
|
target = full_object[EnsembleKeys.TARGET]
|
|
176
176
|
features = full_object[EnsembleKeys.FEATURES]
|
|
@@ -218,31 +218,5 @@ def model_report(
|
|
|
218
218
|
return report_data
|
|
219
219
|
|
|
220
220
|
|
|
221
|
-
# Local implementation to avoid calling utilities dependencies
|
|
222
|
-
def _deserialize_object(filepath: Union[str,Path], verbose: bool=True, raise_on_error: bool=True) -> Optional[Any]:
    """
    Loads a serialized object from a .joblib file.

    Parameters:
        filepath (str | Path): Full path to the serialized .joblib file.
        verbose (bool): If True, log the type of the loaded object on success.
        raise_on_error (bool): If True, re-raise the load error after logging it;
            if False, return None instead.

    Returns:
        (Any | None): The deserialized Python object, or None if loading fails
            and `raise_on_error` is False.

    Raises:
        IOError, OSError, EOFError, TypeError, ValueError: Re-raised from
            `joblib.load` when `raise_on_error` is True.
    """
    true_filepath = make_fullpath(filepath)

    try:
        # NOTE(review): joblib.load unpickles the file, which can execute
        # arbitrary code — only load model files from trusted sources.
        obj = joblib.load(true_filepath)
    except (IOError, OSError, EOFError, TypeError, ValueError) as e:
        _LOGGER.error(f"Failed to deserialize object from '{true_filepath}'.")
        if raise_on_error:
            raise e
        return None
    else:
        if verbose:
            _LOGGER.info(f"Loaded object of type '{type(obj)}'")
        return obj
|
|
245
|
-
|
|
246
|
-
|
|
247
221
|
def info():
    """Report this module's public names (``__all__``) via ``_script_info``."""
    public_names = __all__
    _script_info(public_names)
|
ml_tools/ensemble_learning.py
CHANGED
|
@@ -17,7 +17,7 @@ from .utilities import yield_dataframes_from_dir, train_dataset_yielder
|
|
|
17
17
|
from .serde import serialize_object_filename
|
|
18
18
|
from .path_manager import sanitize_filename, make_fullpath
|
|
19
19
|
from ._script_info import _script_info
|
|
20
|
-
from .
|
|
20
|
+
from ._keys import EnsembleKeys
|
|
21
21
|
from ._logger import _LOGGER
|
|
22
22
|
from .ensemble_evaluation import (evaluate_model_classification,
|
|
23
23
|
plot_roc_curve,
|
ml_tools/optimization_tools.py
CHANGED
|
@@ -8,7 +8,7 @@ from .path_manager import make_fullpath, list_csv_paths, sanitize_filename
|
|
|
8
8
|
from .utilities import yield_dataframes_from_dir
|
|
9
9
|
from ._logger import _LOGGER
|
|
10
10
|
from ._script_info import _script_info
|
|
11
|
-
from .SQL import
|
|
11
|
+
from .SQL import DragonSQL
|
|
12
12
|
from ._schema import FeatureSchema
|
|
13
13
|
|
|
14
14
|
|
|
@@ -262,7 +262,7 @@ def _save_result(
|
|
|
262
262
|
result_dict: dict,
|
|
263
263
|
save_format: Literal['csv', 'sqlite', 'both'],
|
|
264
264
|
csv_path: Path,
|
|
265
|
-
db_manager: Optional[
|
|
265
|
+
db_manager: Optional[DragonSQL] = None,
|
|
266
266
|
db_table_name: Optional[str] = None,
|
|
267
267
|
categorical_mappings: Optional[Dict[str, Dict[str, int]]] = None
|
|
268
268
|
):
|