dragon-ml-toolbox 13.3.0__py3-none-any.whl → 16.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/METADATA +20 -6
- dragon_ml_toolbox-16.2.0.dist-info/RECORD +51 -0
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
- ml_tools/ETL_cleaning.py +20 -20
- ml_tools/ETL_engineering.py +23 -25
- ml_tools/GUI_tools.py +20 -20
- ml_tools/MICE_imputation.py +207 -5
- ml_tools/ML_callbacks.py +43 -26
- ml_tools/ML_configuration.py +788 -0
- ml_tools/ML_datasetmaster.py +303 -448
- ml_tools/ML_evaluation.py +351 -93
- ml_tools/ML_evaluation_multi.py +139 -42
- ml_tools/ML_inference.py +290 -209
- ml_tools/ML_models.py +33 -106
- ml_tools/ML_models_advanced.py +323 -0
- ml_tools/ML_optimization.py +12 -12
- ml_tools/ML_scaler.py +11 -11
- ml_tools/ML_sequence_datasetmaster.py +341 -0
- ml_tools/ML_sequence_evaluation.py +219 -0
- ml_tools/ML_sequence_inference.py +391 -0
- ml_tools/ML_sequence_models.py +139 -0
- ml_tools/ML_trainer.py +1604 -179
- ml_tools/ML_utilities.py +351 -4
- ml_tools/ML_vision_datasetmaster.py +1540 -0
- ml_tools/ML_vision_evaluation.py +284 -0
- ml_tools/ML_vision_inference.py +405 -0
- ml_tools/ML_vision_models.py +641 -0
- ml_tools/ML_vision_transformers.py +284 -0
- ml_tools/PSO_optimization.py +6 -6
- ml_tools/SQL.py +4 -4
- ml_tools/_keys.py +171 -0
- ml_tools/_schema.py +1 -1
- ml_tools/custom_logger.py +37 -14
- ml_tools/data_exploration.py +502 -93
- ml_tools/ensemble_evaluation.py +54 -11
- ml_tools/ensemble_inference.py +7 -33
- ml_tools/ensemble_learning.py +1 -1
- ml_tools/math_utilities.py +1 -1
- ml_tools/optimization_tools.py +2 -2
- ml_tools/path_manager.py +5 -5
- ml_tools/serde.py +2 -2
- ml_tools/utilities.py +192 -4
- dragon_ml_toolbox-13.3.0.dist-info/RECORD +0 -41
- ml_tools/RNN_forecast.py +0 -56
- ml_tools/keys.py +0 -87
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_vision_transformers.py
ADDED
@@ -0,0 +1,284 @@
+from typing import Union, Dict, Type, Callable, Optional, Any, List, Literal
+from PIL import ImageOps, Image
+from torchvision import transforms
+from pathlib import Path
+import json
+
+from ._logger import _LOGGER
+from ._script_info import _script_info
+from ._keys import VisionTransformRecipeKeys
+from .path_manager import make_fullpath
+
+
+__all__ = [
+    "TRANSFORM_REGISTRY",
+    "ResizeAspectFill",
+    "create_offline_augmentations"
+]
+
+# --- Custom Vision Transform Class ---
+class ResizeAspectFill:
+    """
+    Custom transformation to make an image square by padding it to match the
+    longest side, preserving the aspect ratio. The image is finally centered.
+
+    Args:
+        pad_color (Union[str, int]): Color to use for the padding.
+            Defaults to "black".
+    """
+    def __init__(self, pad_color: Union[str, int] = "black") -> None:
+        self.pad_color = pad_color
+        # Store kwargs to allow for re-creation
+        self.__setattr__(VisionTransformRecipeKeys.KWARGS, {"pad_color": pad_color})
+
+    def __call__(self, image: Image.Image) -> Image.Image:
+        if not isinstance(image, Image.Image):
+            _LOGGER.error(f"Expected PIL.Image.Image, got {type(image).__name__}")
+            raise TypeError()
+
+        w, h = image.size
+        if w == h:
+            return image
+
+        # Determine padding to center the image
+        if w > h:
+            top_padding = (w - h) // 2
+            bottom_padding = w - h - top_padding
+            padding = (0, top_padding, 0, bottom_padding)
+        else: # h > w
+            left_padding = (h - w) // 2
+            right_padding = h - w - left_padding
+            padding = (left_padding, 0, right_padding, 0)
+
+        return ImageOps.expand(image, padding, fill=self.pad_color)
+
+
+#############################################################
+#NOTE: Add custom transforms.
+TRANSFORM_REGISTRY: Dict[str, Type[Callable]] = {
+    "ResizeAspectFill": ResizeAspectFill,
+}
+#############################################################
+
+
+def create_offline_augmentations(
+    input_directory: Union[str, Path],
+    output_directory: Union[str, Path],
+    results_per_image: int,
+    recipe: Optional[Dict[str, Any]] = None,
+    save_format: Literal["WEBP", "JPEG", "PNG", "BMP", "TIF"] = "WEBP",
+    save_quality: int = 80
+) -> None:
+    """
+    Reads all valid images from an input directory, applies augmentations,
+    and saves the new images to an output directory (offline augmentation).
+
+    Skips subdirectories in the input path.
+
+    Args:
+        input_directory (Union[str, Path]): Path to the directory of source images.
+        output_directory (Union[str, Path]): Path to save the augmented images.
+        results_per_image (int): The number of augmented versions to create
+            for each source image.
+        recipe (Optional[Dict[str, Any]]): A transform recipe dictionary. If None,
+            a default set of strong, random
+            augmentations will be used.
+        save_format (str): The format to save images (e.g., "WEBP", "JPEG", "PNG").
+            Defaults to "WEBP" for good compression.
+        save_quality (int): The quality for lossy formats (1-100). Defaults to 80.
+    """
+    VALID_IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tif', '.tiff')
+
+    # --- 1. Validate Paths ---
+    in_path = make_fullpath(input_directory, enforce="directory")
+    out_path = make_fullpath(output_directory, make=True, enforce="directory")
+
+    _LOGGER.info(f"Starting offline augmentation:\n\tInput: {in_path}\n\tOutput: {out_path}")
+
+    # --- 2. Find Images ---
+    image_files = [
+        f for f in in_path.iterdir()
+        if f.is_file() and f.suffix.lower() in VALID_IMG_EXTENSIONS
+    ]
+
+    if not image_files:
+        _LOGGER.warning(f"No valid image files found in {in_path}.")
+        return
+
+    _LOGGER.info(f"Found {len(image_files)} images to process.")
+
+    # --- 3. Define Transform Pipeline ---
+    transform_pipeline: transforms.Compose
+
+    if recipe:
+        _LOGGER.info("Building transformations from provided recipe.")
+        try:
+            transform_pipeline = _build_transform_from_recipe(recipe)
+        except Exception as e:
+            _LOGGER.error(f"Failed to build transform from recipe: {e}")
+            return
+    else:
+        _LOGGER.info("No recipe provided. Using default random augmentation pipeline.")
+        # Default "random" pipeline
+        transform_pipeline = transforms.Compose([
+            transforms.RandomResizedCrop(256, scale=(0.4, 1.0)),
+            transforms.RandomHorizontalFlip(p=0.5),
+            transforms.RandomRotation(degrees=90),
+            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.15),
+            transforms.RandomPerspective(distortion_scale=0.2, p=0.4),
+            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
+            transforms.RandomApply([
+                transforms.GaussianBlur(kernel_size=3)
+            ], p=0.3)
+        ])
+
+    # --- 4. Process Images ---
+    total_saved = 0
+    format_upper = save_format.upper()
+
+    for img_path in image_files:
+        _LOGGER.debug(f"Processing {img_path.name}...")
+        try:
+            original_image = Image.open(img_path).convert("RGB")
+
+            for i in range(results_per_image):
+                new_stem = f"{img_path.stem}_aug_{i+1:03d}"
+                output_path = out_path / f"{new_stem}.{format_upper.lower()}"
+
+                # Apply transform
+                transformed_image = transform_pipeline(original_image)
+
+                # Save
+                transformed_image.save(
+                    output_path,
+                    format=format_upper,
+                    quality=save_quality,
+                    optimize=True # Add optimize flag
+                )
+                total_saved += 1
+
+        except Exception as e:
+            _LOGGER.warning(f"Failed to process or save augmentations for {img_path.name}: {e}")
+
+    _LOGGER.info(f"Offline augmentation complete. Saved {total_saved} new images.")
+
+
+def _build_transform_from_recipe(recipe: Dict[str, Any]) -> transforms.Compose:
+    """Internal helper to build a transform pipeline from a recipe dict."""
+    pipeline_steps: List[Callable] = []
+
+    if VisionTransformRecipeKeys.PIPELINE not in recipe:
+        _LOGGER.error("Recipe dict is invalid: missing 'pipeline' key.")
+        raise ValueError("Invalid recipe format.")
+
+    for step in recipe[VisionTransformRecipeKeys.PIPELINE]:
+        t_name = step.get(VisionTransformRecipeKeys.NAME)
+        t_kwargs = step.get(VisionTransformRecipeKeys.KWARGS, {})
+
+        if not t_name:
+            _LOGGER.error(f"Invalid transform step, missing 'name': {step}")
+            continue
+
+        transform_class: Any = None
+
+        # 1. Check standard torchvision transforms
+        if hasattr(transforms, t_name):
+            transform_class = getattr(transforms, t_name)
+        # 2. Check custom transforms
+        elif t_name in TRANSFORM_REGISTRY:
+            transform_class = TRANSFORM_REGISTRY[t_name]
+        # 3. Not found
+        else:
+            _LOGGER.error(f"Unknown transform '{t_name}' in recipe. Not found in torchvision.transforms or TRANSFORM_REGISTRY.")
+            raise ValueError(f"Unknown transform name: {t_name}")
+
+        # Instantiate the transform
+        try:
+            pipeline_steps.append(transform_class(**t_kwargs))
+        except Exception as e:
+            _LOGGER.error(f"Failed to instantiate transform '{t_name}' with kwargs {t_kwargs}: {e}")
+            raise
+
+    return transforms.Compose(pipeline_steps)
+
+
+def _save_recipe(recipe: Dict[str, Any], filepath: Path) -> None:
+    """
+    Saves a transform recipe dictionary to a JSON file.
+
+    Args:
+        recipe (Dict[str, Any]): The recipe dictionary to save.
+        filepath (str): The path to the output .json file.
+    """
+    final_filepath = filepath.with_suffix(".json")
+
+    try:
+        with open(final_filepath, 'w') as f:
+            json.dump(recipe, f, indent=4)
+        _LOGGER.info(f"Transform recipe saved as '{final_filepath.name}'.")
+    except Exception as e:
+        _LOGGER.error(f"Failed to save recipe to '{final_filepath}': {e}")
+        raise
+
+
+def _load_recipe_and_build_transform(filepath: Union[str,Path]) -> transforms.Compose:
+    """
+    Loads a transform recipe from a .json file and reconstructs the
+    torchvision.transforms.Compose pipeline.
+
+    Args:
+        filepath (str): Path to the saved transform recipe .json file.
+
+    Returns:
+        transforms.Compose: The reconstructed transformation pipeline.
+
+    Raises:
+        ValueError: If a transform name in the recipe is not found in
+            torchvision.transforms or the custom TRANSFORM_REGISTRY.
+    """
+    # validate filepath
+    final_filepath = make_fullpath(filepath, enforce="file")
+
+    try:
+        with open(final_filepath, 'r') as f:
+            recipe = json.load(f)
+    except Exception as e:
+        _LOGGER.error(f"Failed to load recipe from '{final_filepath}': {e}")
+        raise
+
+    pipeline_steps: List[Callable] = []
+
+    if VisionTransformRecipeKeys.PIPELINE not in recipe:
+        _LOGGER.error("Recipe file is invalid: missing 'pipeline' key.")
+        raise ValueError("Invalid recipe format.")
+
+    for step in recipe[VisionTransformRecipeKeys.PIPELINE]:
+        t_name = step[VisionTransformRecipeKeys.NAME]
+        t_kwargs = step[VisionTransformRecipeKeys.KWARGS]
+
+        transform_class: Any = None
+
+        # 1. Check standard torchvision transforms
+        if hasattr(transforms, t_name):
+            transform_class = getattr(transforms, t_name)
+        # 2. Check custom transforms
+        elif t_name in TRANSFORM_REGISTRY:
+            transform_class = TRANSFORM_REGISTRY[t_name]
+        # 3. Not found
+        else:
+            _LOGGER.error(f"Unknown transform '{t_name}' in recipe. Not found in torchvision.transforms or TRANSFORM_REGISTRY.")
+            raise ValueError(f"Unknown transform name: {t_name}")
+
+        # Instantiate the transform
+        try:
+            pipeline_steps.append(transform_class(**t_kwargs))
+        except Exception as e:
+            _LOGGER.error(f"Failed to instantiate transform '{t_name}' with kwargs {t_kwargs}: {e}")
+            raise
+
+    _LOGGER.info(f"Successfully loaded and built transform pipeline from '{final_filepath.name}'.")
+    return transforms.Compose(pipeline_steps)
+
+
+def info():
+    _script_info(__all__)
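For orientation, here is a minimal sketch of how the new module might be called. The directory names and the recipe contents below are placeholders, but the function signature and the pipeline/name/kwargs recipe layout follow the code above:

from ml_tools.ML_vision_transformers import create_offline_augmentations

# Hypothetical recipe: each step names a torchvision transform or a
# TRANSFORM_REGISTRY entry, plus the kwargs used to instantiate it.
recipe = {
    "pipeline": [
        {"name": "ResizeAspectFill", "kwargs": {"pad_color": "black"}},
        {"name": "RandomHorizontalFlip", "kwargs": {"p": 0.5}},
        {"name": "ColorJitter", "kwargs": {"brightness": 0.3, "contrast": 0.3}},
    ]
}

# Writes 5 augmented copies of every valid image found in "raw_images/".
create_offline_augmentations(
    input_directory="raw_images",
    output_directory="augmented_images",
    results_per_image=5,
    recipe=recipe,      # None falls back to the default random pipeline
    save_format="WEBP",
    save_quality=80,
)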
ml_tools/PSO_optimization.py
CHANGED
@@ -12,9 +12,9 @@ from .serde import deserialize_object
 from .math_utilities import threshold_binary_values, threshold_binary_values_batch
 from .path_manager import sanitize_filename, make_fullpath, list_files_by_extension
 from ._logger import _LOGGER
-from .keys import EnsembleKeys
+from ._keys import EnsembleKeys
 from ._script_info import _script_info
-from .SQL import DatabaseManager
+from .SQL import DragonSQL
 from .optimization_tools import _save_result
 
 """
@@ -191,7 +191,7 @@ def _set_feature_names(size: int, names: Union[list[str], None]):
     return names
 
 
-def _run_single_pso(objective_function: ObjectiveFunction, pso_args: dict, feature_names: list[str], target_name: str, random_state: int, save_format: Literal['csv', 'sqlite', 'both'], csv_path: Path, db_manager: Optional[DatabaseManager], db_table_name: str):
+def _run_single_pso(objective_function: ObjectiveFunction, pso_args: dict, feature_names: list[str], target_name: str, random_state: int, save_format: Literal['csv', 'sqlite', 'both'], csv_path: Path, db_manager: Optional[DragonSQL], db_table_name: str):
     """Helper for a single PSO run that also handles saving."""
     pso_args.update({"seed": random_state})
 
@@ -213,7 +213,7 @@ def _run_single_pso(objective_function: ObjectiveFunction, pso_args: dict, featu
     return best_features_named, best_target_named
 
 
-def _run_post_hoc_pso(objective_function: ObjectiveFunction, pso_args: dict, feature_names: list[str], target_name: str, repetitions: int, save_format: Literal['csv', 'sqlite', 'both'], csv_path: Path, db_manager: Optional[DatabaseManager], db_table_name: str):
+def _run_post_hoc_pso(objective_function: ObjectiveFunction, pso_args: dict, feature_names: list[str], target_name: str, repetitions: int, save_format: Literal['csv', 'sqlite', 'both'], csv_path: Path, db_manager: Optional[DragonSQL], db_table_name: str):
     """Helper for post-hoc analysis that saves results incrementally."""
     progress = trange(repetitions, desc="Post-Hoc PSO", unit="run")
     for _ in progress:
@@ -342,7 +342,7 @@ def run_pso(lower_boundaries: list[float],
     schema = {"result_id": "INTEGER PRIMARY KEY AUTOINCREMENT", **schema}
 
     # Create table
-    with DatabaseManager(db_path) as db:
+    with DragonSQL(db_path) as db:
         db.create_table(db_table_name, schema)
 
     pso_arguments = {
@@ -357,7 +357,7 @@ def run_pso(lower_boundaries: list[float],
 
     # --- Dispatcher ---
     # Use a real or dummy context manager to handle the DB connection cleanly
-    db_context = DatabaseManager(db_path) if save_format in ['sqlite', 'both'] else nullcontext()
+    db_context = DragonSQL(db_path) if save_format in ['sqlite', 'both'] else nullcontext()
 
     with db_context as db_manager:
         if post_hoc_analysis is None or post_hoc_analysis <= 1:
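The dispatcher above leans on contextlib.nullcontext so a single with block serves both the SQLite and the CSV-only paths. A standalone sketch of that pattern follows; the FakeDB class is illustrative, not part of the package:

from contextlib import nullcontext

class FakeDB:
    """Illustrative stand-in for a context-managed database handle."""
    def __enter__(self):
        return self
    def __exit__(self, *exc_info):
        return False

save_format = "csv"  # pretend no database output was requested

# nullcontext() yields None, so the body can branch on `db_manager is None`.
db_context = FakeDB() if save_format in ("sqlite", "both") else nullcontext()
with db_context as db_manager:
    if db_manager is not None:
        print("saving results to the database")
    else:
        print("database output disabled; CSV only")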
ml_tools/SQL.py
CHANGED
@@ -9,11 +9,11 @@ from .path_manager import make_fullpath, sanitize_filename
 
 
 __all__ = [
-    "DatabaseManager",
+    "DragonSQL",
 ]
 
 
-class DatabaseManager:
+class DragonSQL:
     """
     A user-friendly context manager for handling SQLite database operations.
 
@@ -35,7 +35,7 @@ class DatabaseManager:
     ...     "feature_a": "REAL",
     ...     "score": "REAL"
     ... }
-    >>> with DatabaseManager("my_results.db") as db:
+    >>> with DragonSQL("my_results.db") as db:
     ...     db.create_table("experiments", schema)
     ...     data = {"run_name": "first_run", "feature_a": 0.123, "score": 95.5}
     ...     db.insert_row("experiments", data)
@@ -43,7 +43,7 @@ class DatabaseManager:
     ...     print(df)
     """
     def __init__(self, db_path: Union[str, Path]):
-        """Initializes the DatabaseManager with the path to the database file."""
+        """Initializes the DragonSQL with the path to the database file."""
         if isinstance(db_path, str):
             if not db_path.endswith(".db"):
                 db_path = db_path + ".db"
ml_tools/_keys.py
ADDED
@@ -0,0 +1,171 @@
+class MagicWords:
+    """General purpose keys"""
+    LATEST = "latest"
+    CURRENT = "current"
+    RENAME = "rename"
+
+
+class PyTorchLogKeys:
+    """
+    Used internally for ML scripts module.
+
+    Centralized keys for logging and history.
+    """
+    # --- Epoch Level ---
+    TRAIN_LOSS = 'train_loss'
+    VAL_LOSS = 'val_loss'
+    LEARNING_RATE = 'lr'
+
+    # --- Batch Level ---
+    BATCH_LOSS = 'loss'
+    BATCH_INDEX = 'batch'
+    BATCH_SIZE = 'size'
+
+
+class EnsembleKeys:
+    """
+    Used internally by ensemble_learning.
+    """
+    # Serializing a trained model metadata.
+    MODEL = "model"
+    FEATURES = "feature_names"
+    TARGET = "target_name"
+
+    # Classification keys
+    CLASSIFICATION_LABEL = "labels"
+    CLASSIFICATION_PROBABILITIES = "probabilities"
+
+
+class PyTorchInferenceKeys:
+    """Keys for the output dictionaries of PyTorchInferenceHandler."""
+    # For regression tasks
+    PREDICTIONS = "predictions"
+
+    # For classification tasks
+    LABELS = "labels"
+    PROBABILITIES = "probabilities"
+    LABEL_NAMES = "label_names"
+
+
+class PytorchModelArchitectureKeys:
+    """Keys for saving and loading model architecture."""
+    MODEL = 'model_class'
+    CONFIG = "config"
+    SAVENAME = "architecture"
+
+
+class PytorchArtifactPathKeys:
+    """Keys for model artifact paths."""
+    FEATURES_PATH = "feature_names_path"
+    TARGETS_PATH = "target_names_path"
+    ARCHITECTURE_PATH = "model_architecture_path"
+    WEIGHTS_PATH = "model_weights_path"
+    SCALER_PATH = "scaler_path"
+
+
+class DatasetKeys:
+    """Keys for saving dataset artifacts. Also used by FeatureSchema"""
+    FEATURE_NAMES = "feature_names"
+    TARGET_NAMES = "target_names"
+    SCALER_PREFIX = "scaler_"
+    # Feature Schema
+    CONTINUOUS_NAMES = "continuous_feature_names"
+    CATEGORICAL_NAMES = "categorical_feature_names"
+
+
+class SHAPKeys:
+    """Keys for SHAP functions"""
+    FEATURE_COLUMN = "feature"
+    SHAP_VALUE_COLUMN = "mean_abs_shap_value"
+    SAVENAME = "shap_summary"
+
+
+class PyTorchCheckpointKeys:
+    """Keys for saving/loading a training checkpoint dictionary."""
+    MODEL_STATE = "model_state_dict"
+    OPTIMIZER_STATE = "optimizer_state_dict"
+    SCHEDULER_STATE = "scheduler_state_dict"
+    EPOCH = "epoch"
+    BEST_SCORE = "best_score"
+    HISTORY = "history"
+    CHECKPOINT_NAME = "PyModelCheckpoint"
+    # Finalized config
+    CLASSIFICATION_THRESHOLD = "classification_threshold"
+    CLASS_MAP = "class_map"
+    SEQUENCE_LENGTH = "sequence_length"
+    INITIAL_SEQUENCE = "initial_sequence"
+    TARGET_NAME = "target_name"
+    TARGET_NAMES = "target_names"
+
+
+class UtilityKeys:
+    """Keys used for utility modules"""
+    MODEL_PARAMS_FILE = "model_parameters"
+    TOTAL_PARAMS = "Total Parameters"
+    TRAINABLE_PARAMS = "Trainable Parameters"
+    PTH_FILE = "pth report "
+    MODEL_ARCHITECTURE_FILE = "model_architecture_summary"
+
+
+class VisionKeys:
+    """For vision ML metrics"""
+    SEGMENTATION_REPORT = "segmentation_report"
+    SEGMENTATION_HEATMAP = "segmentation_metrics_heatmap"
+    SEGMENTATION_CONFUSION_MATRIX = "segmentation_confusion_matrix"
+    # Object detection
+    OBJECT_DETECTION_REPORT = "object_detection_report"
+
+
+class VisionTransformRecipeKeys:
+    """Defines the key names for the transform recipe JSON file."""
+    TASK = "task"
+    PIPELINE = "pipeline"
+    NAME = "name"
+    KWARGS = "kwargs"
+    PRE_TRANSFORMS = "pre_transforms"
+
+    RESIZE_SIZE = "resize_size"
+    CROP_SIZE = "crop_size"
+    MEAN = "mean"
+    STD = "std"
+
+
+class ObjectDetectionKeys:
+    """Used by the object detection dataset"""
+    BOXES = "boxes"
+    LABELS = "labels"
+
+
+class MLTaskKeys:
+    """Used by the Trainer and InferenceHandlers"""
+    REGRESSION = "regression"
+    MULTITARGET_REGRESSION = "multitarget regression"
+
+    BINARY_CLASSIFICATION = "binary classification"
+    MULTICLASS_CLASSIFICATION = "multiclass classification"
+    MULTILABEL_BINARY_CLASSIFICATION = "multilabel binary classification"
+
+    BINARY_IMAGE_CLASSIFICATION = "binary image classification"
+    MULTICLASS_IMAGE_CLASSIFICATION = "multiclass image classification"
+
+    BINARY_SEGMENTATION = "binary segmentation"
+    MULTICLASS_SEGMENTATION = "multiclass segmentation"
+
+    OBJECT_DETECTION = "object detection"
+
+    SEQUENCE_SEQUENCE = "sequence-to-sequence"
+    SEQUENCE_VALUE = "sequence-to-value"
+
+    ALL_BINARY_TASKS = [BINARY_CLASSIFICATION, MULTILABEL_BINARY_CLASSIFICATION, BINARY_IMAGE_CLASSIFICATION, BINARY_SEGMENTATION]
+
+
+class DragonTrainerKeys:
+    VALIDATION_METRICS_DIR = "Validation_Metrics"
+    TEST_METRICS_DIR = "Test_Metrics"
+
+
+class _OneHotOtherPlaceholder:
+    """Used internally by GUI_tools."""
+    OTHER_GUI = "OTHER"
+    OTHER_MODEL = "one hot OTHER placeholder"
+    OTHER_DICT = {OTHER_GUI: OTHER_MODEL}
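To see how the new _keys.py constants fit together, here is a hypothetical transform recipe expressed with the VisionTransformRecipeKeys names; _save_recipe and _load_recipe_and_build_transform in ML_vision_transformers.py (above) serialize and consume exactly this pipeline/name/kwargs layout. The specific transforms and the task label are illustrative choices, not values fixed by the release:

from ml_tools._keys import VisionTransformRecipeKeys, MLTaskKeys

recipe = {
    VisionTransformRecipeKeys.TASK: MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION,
    VisionTransformRecipeKeys.PIPELINE: [
        # Custom transform registered in TRANSFORM_REGISTRY
        {VisionTransformRecipeKeys.NAME: "ResizeAspectFill",
         VisionTransformRecipeKeys.KWARGS: {"pad_color": "black"}},
        # Standard torchvision transforms, looked up by attribute name
        {VisionTransformRecipeKeys.NAME: "Resize",
         VisionTransformRecipeKeys.KWARGS: {"size": 256}},
        {VisionTransformRecipeKeys.NAME: "CenterCrop",
         VisionTransformRecipeKeys.KWARGS: {"size": 224}},
    ],
}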
ml_tools/_schema.py
CHANGED
ml_tools/custom_logger.py
CHANGED
@@ -1,6 +1,6 @@
 from pathlib import Path
 from datetime import datetime
-from typing import Union, List, Dict, Any
+from typing import Union, List, Dict, Any, Literal
 import traceback
 import json
 import csv
@@ -29,6 +29,8 @@ def custom_logger(
     ],
     save_directory: Union[str, Path],
     log_name: str,
+    add_timestamp: bool=True,
+    dict_as: Literal['auto', 'json', 'csv'] = 'auto',
 ) -> None:
     """
     Logs various data types to corresponding output formats:
@@ -36,10 +38,10 @@ def custom_logger(
     - list[Any] → .txt
         Each element is written on a new line.
 
-    - dict[str, list[Any]] → .csv
+    - dict[str, list[Any]] → .csv (if dict_as='auto' or 'csv')
         Dictionary is treated as tabular data; keys become columns, values become rows.
 
-    - dict[str, scalar] → .json
+    - dict[str, scalar] → .json (if dict_as='auto' or 'json')
         Dictionary is treated as structured data and serialized as JSON.
 
     - str → .log
@@ -49,29 +51,50 @@ def custom_logger(
         Full traceback is logged for debugging purposes.
 
     Args:
-        data: The data to be logged. Must be one of the supported types.
-        save_directory: Directory where the log will be saved. Created if it does not exist.
-        log_name: Base name for the log file.
+        data (Any): The data to be logged. Must be one of the supported types.
+        save_directory (str | Path): Directory where the log will be saved. Created if it does not exist.
+        log_name (str): Base name for the log file.
+        add_timestamp (bool): Whether to add a timestamp to the filename.
+        dict_as ('auto'|'json'|'csv'):
+            - 'auto': Guesses format (JSON or CSV) based on dictionary content.
+            - 'json': Forces .json format for any dictionary.
+            - 'csv': Forces .csv format. Will fail if dict values are not all lists.
 
     Raises:
        ValueError: If the data type is unsupported.
     """
     try:
+        if not isinstance(data, BaseException) and not data:
+            _LOGGER.warning("Empty data received. No log file will be saved.")
+            return
+
         save_path = make_fullpath(save_directory, make=True)
 
-
-        log_name = sanitize_filename(log_name)
+        sanitized_log_name = sanitize_filename(log_name)
 
-        timestamp = datetime.now().strftime(r"%Y%m%d_%H%M%S")
-        base_path = save_path / f"{log_name}_{timestamp}"
+        if add_timestamp:
+            timestamp = datetime.now().strftime(r"%Y%m%d_%H%M%S")
+            base_path = save_path / f"{sanitized_log_name}_{timestamp}"
+        else:
+            base_path = save_path / sanitized_log_name
+
+        # Router
         if isinstance(data, list):
             _log_list_to_txt(data, base_path.with_suffix(".txt"))
 
         elif isinstance(data, dict):
-            if all(isinstance(v, list) for v in data.values()):
-                _log_dict_to_csv(data, base_path.with_suffix(".csv"))
-            else:
+            if dict_as == 'json':
                 _log_dict_to_json(data, base_path.with_suffix(".json"))
+
+            elif dict_as == 'csv':
+                # This will raise a ValueError if data is not all lists
+                _log_dict_to_csv(data, base_path.with_suffix(".csv"))
+
+            else: # 'auto' mode
+                if all(isinstance(v, list) for v in data.values()):
+                    _log_dict_to_csv(data, base_path.with_suffix(".csv"))
+                else:
+                    _log_dict_to_json(data, base_path.with_suffix(".json"))
 
         elif isinstance(data, str):
             _log_string_to_log(data, base_path.with_suffix(".log"))
@@ -83,7 +106,7 @@ def custom_logger(
             _LOGGER.error("Unsupported data type. Must be list, dict, str, or BaseException.")
             raise ValueError()
 
-        _LOGGER.info(f"Log saved
+        _LOGGER.info(f"Log saved as: '{base_path.name}'")
 
     except Exception:
         _LOGGER.exception(f"Log not saved.")