dragon-ml-toolbox 13.3.0__py3-none-any.whl → 16.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/METADATA +20 -6
  2. dragon_ml_toolbox-16.2.0.dist-info/RECORD +51 -0
  3. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
  4. ml_tools/ETL_cleaning.py +20 -20
  5. ml_tools/ETL_engineering.py +23 -25
  6. ml_tools/GUI_tools.py +20 -20
  7. ml_tools/MICE_imputation.py +207 -5
  8. ml_tools/ML_callbacks.py +43 -26
  9. ml_tools/ML_configuration.py +788 -0
  10. ml_tools/ML_datasetmaster.py +303 -448
  11. ml_tools/ML_evaluation.py +351 -93
  12. ml_tools/ML_evaluation_multi.py +139 -42
  13. ml_tools/ML_inference.py +290 -209
  14. ml_tools/ML_models.py +33 -106
  15. ml_tools/ML_models_advanced.py +323 -0
  16. ml_tools/ML_optimization.py +12 -12
  17. ml_tools/ML_scaler.py +11 -11
  18. ml_tools/ML_sequence_datasetmaster.py +341 -0
  19. ml_tools/ML_sequence_evaluation.py +219 -0
  20. ml_tools/ML_sequence_inference.py +391 -0
  21. ml_tools/ML_sequence_models.py +139 -0
  22. ml_tools/ML_trainer.py +1604 -179
  23. ml_tools/ML_utilities.py +351 -4
  24. ml_tools/ML_vision_datasetmaster.py +1540 -0
  25. ml_tools/ML_vision_evaluation.py +284 -0
  26. ml_tools/ML_vision_inference.py +405 -0
  27. ml_tools/ML_vision_models.py +641 -0
  28. ml_tools/ML_vision_transformers.py +284 -0
  29. ml_tools/PSO_optimization.py +6 -6
  30. ml_tools/SQL.py +4 -4
  31. ml_tools/_keys.py +171 -0
  32. ml_tools/_schema.py +1 -1
  33. ml_tools/custom_logger.py +37 -14
  34. ml_tools/data_exploration.py +502 -93
  35. ml_tools/ensemble_evaluation.py +54 -11
  36. ml_tools/ensemble_inference.py +7 -33
  37. ml_tools/ensemble_learning.py +1 -1
  38. ml_tools/math_utilities.py +1 -1
  39. ml_tools/optimization_tools.py +2 -2
  40. ml_tools/path_manager.py +5 -5
  41. ml_tools/serde.py +2 -2
  42. ml_tools/utilities.py +192 -4
  43. dragon_ml_toolbox-13.3.0.dist-info/RECORD +0 -41
  44. ml_tools/RNN_forecast.py +0 -56
  45. ml_tools/keys.py +0 -87
  46. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/WHEEL +0 -0
  47. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE +0 -0
  48. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_vision_transformers.py ADDED
@@ -0,0 +1,284 @@
+ from typing import Union, Dict, Type, Callable, Optional, Any, List, Literal
+ from PIL import ImageOps, Image
+ from torchvision import transforms
+ from pathlib import Path
+ import json
+
+ from ._logger import _LOGGER
+ from ._script_info import _script_info
+ from ._keys import VisionTransformRecipeKeys
+ from .path_manager import make_fullpath
+
+
+ __all__ = [
+     "TRANSFORM_REGISTRY",
+     "ResizeAspectFill",
+     "create_offline_augmentations"
+ ]
+
+ # --- Custom Vision Transform Class ---
+ class ResizeAspectFill:
+     """
+     Custom transformation that makes an image square by padding it to match its
+     longest side, preserving the aspect ratio. The image content is centered.
+
+     Args:
+         pad_color (Union[str, int]): Color to use for the padding.
+             Defaults to "black".
+     """
+     def __init__(self, pad_color: Union[str, int] = "black") -> None:
+         self.pad_color = pad_color
+         # Store kwargs to allow for re-creation
+         self.__setattr__(VisionTransformRecipeKeys.KWARGS, {"pad_color": pad_color})
+
+     def __call__(self, image: Image.Image) -> Image.Image:
+         if not isinstance(image, Image.Image):
+             _LOGGER.error(f"Expected PIL.Image.Image, got {type(image).__name__}")
+             raise TypeError()
+
+         w, h = image.size
+         if w == h:
+             return image
+
+         # Determine padding to center the image
+         if w > h:
+             top_padding = (w - h) // 2
+             bottom_padding = w - h - top_padding
+             padding = (0, top_padding, 0, bottom_padding)
+         else:  # h > w
+             left_padding = (h - w) // 2
+             right_padding = h - w - left_padding
+             padding = (left_padding, 0, right_padding, 0)
+
+         return ImageOps.expand(image, padding, fill=self.pad_color)
+
+
+ #############################################################
+ # NOTE: Add custom transforms.
+ TRANSFORM_REGISTRY: Dict[str, Type[Callable]] = {
+     "ResizeAspectFill": ResizeAspectFill,
+ }
+ #############################################################
+
+
+ def create_offline_augmentations(
+     input_directory: Union[str, Path],
+     output_directory: Union[str, Path],
+     results_per_image: int,
+     recipe: Optional[Dict[str, Any]] = None,
+     save_format: Literal["WEBP", "JPEG", "PNG", "BMP", "TIF"] = "WEBP",
+     save_quality: int = 80
+ ) -> None:
+     """
+     Reads all valid images from an input directory, applies augmentations,
+     and saves the new images to an output directory (offline augmentation).
+
+     Skips subdirectories in the input path.
+
+     Args:
+         input_directory (Union[str, Path]): Path to the directory of source images.
+         output_directory (Union[str, Path]): Path to save the augmented images.
+         results_per_image (int): The number of augmented versions to create
+             for each source image.
+         recipe (Optional[Dict[str, Any]]): A transform recipe dictionary. If None,
+             a default set of strong, random
+             augmentations will be used.
+         save_format (str): The format to save images (e.g., "WEBP", "JPEG", "PNG").
+             Defaults to "WEBP" for good compression.
+         save_quality (int): The quality for lossy formats (1-100). Defaults to 80.
+     """
+     VALID_IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tif', '.tiff')
+
+     # --- 1. Validate Paths ---
+     in_path = make_fullpath(input_directory, enforce="directory")
+     out_path = make_fullpath(output_directory, make=True, enforce="directory")
+
+     _LOGGER.info(f"Starting offline augmentation:\n\tInput: {in_path}\n\tOutput: {out_path}")
+
+     # --- 2. Find Images ---
+     image_files = [
+         f for f in in_path.iterdir()
+         if f.is_file() and f.suffix.lower() in VALID_IMG_EXTENSIONS
+     ]
+
+     if not image_files:
+         _LOGGER.warning(f"No valid image files found in {in_path}.")
+         return
+
+     _LOGGER.info(f"Found {len(image_files)} images to process.")
+
+     # --- 3. Define Transform Pipeline ---
+     transform_pipeline: transforms.Compose
+
+     if recipe:
+         _LOGGER.info("Building transformations from provided recipe.")
+         try:
+             transform_pipeline = _build_transform_from_recipe(recipe)
+         except Exception as e:
+             _LOGGER.error(f"Failed to build transform from recipe: {e}")
+             return
+     else:
+         _LOGGER.info("No recipe provided. Using default random augmentation pipeline.")
+         # Default "random" pipeline
+         transform_pipeline = transforms.Compose([
+             transforms.RandomResizedCrop(256, scale=(0.4, 1.0)),
+             transforms.RandomHorizontalFlip(p=0.5),
+             transforms.RandomRotation(degrees=90),
+             transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.15),
+             transforms.RandomPerspective(distortion_scale=0.2, p=0.4),
+             transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
+             transforms.RandomApply([
+                 transforms.GaussianBlur(kernel_size=3)
+             ], p=0.3)
+         ])
+
+     # --- 4. Process Images ---
+     total_saved = 0
+     format_upper = save_format.upper()
+
+     for img_path in image_files:
+         _LOGGER.debug(f"Processing {img_path.name}...")
+         try:
+             original_image = Image.open(img_path).convert("RGB")
+
+             for i in range(results_per_image):
+                 new_stem = f"{img_path.stem}_aug_{i+1:03d}"
+                 output_path = out_path / f"{new_stem}.{format_upper.lower()}"
+
+                 # Apply transform
+                 transformed_image = transform_pipeline(original_image)
+
+                 # Save
+                 transformed_image.save(
+                     output_path,
+                     format=format_upper,
+                     quality=save_quality,
+                     optimize=True  # Add optimize flag
+                 )
+                 total_saved += 1
+
+         except Exception as e:
+             _LOGGER.warning(f"Failed to process or save augmentations for {img_path.name}: {e}")
+
+     _LOGGER.info(f"Offline augmentation complete. Saved {total_saved} new images.")
+
+
+ def _build_transform_from_recipe(recipe: Dict[str, Any]) -> transforms.Compose:
+     """Internal helper to build a transform pipeline from a recipe dict."""
+     pipeline_steps: List[Callable] = []
+
+     if VisionTransformRecipeKeys.PIPELINE not in recipe:
+         _LOGGER.error("Recipe dict is invalid: missing 'pipeline' key.")
+         raise ValueError("Invalid recipe format.")
+
+     for step in recipe[VisionTransformRecipeKeys.PIPELINE]:
+         t_name = step.get(VisionTransformRecipeKeys.NAME)
+         t_kwargs = step.get(VisionTransformRecipeKeys.KWARGS, {})
+
+         if not t_name:
+             _LOGGER.error(f"Invalid transform step, missing 'name': {step}")
+             continue
+
+         transform_class: Any = None
+
+         # 1. Check standard torchvision transforms
+         if hasattr(transforms, t_name):
+             transform_class = getattr(transforms, t_name)
+         # 2. Check custom transforms
+         elif t_name in TRANSFORM_REGISTRY:
+             transform_class = TRANSFORM_REGISTRY[t_name]
+         # 3. Not found
+         else:
+             _LOGGER.error(f"Unknown transform '{t_name}' in recipe. Not found in torchvision.transforms or TRANSFORM_REGISTRY.")
+             raise ValueError(f"Unknown transform name: {t_name}")
+
+         # Instantiate the transform
+         try:
+             pipeline_steps.append(transform_class(**t_kwargs))
+         except Exception as e:
+             _LOGGER.error(f"Failed to instantiate transform '{t_name}' with kwargs {t_kwargs}: {e}")
+             raise
+
+     return transforms.Compose(pipeline_steps)
+
+
+ def _save_recipe(recipe: Dict[str, Any], filepath: Path) -> None:
+     """
+     Saves a transform recipe dictionary to a JSON file.
+
+     Args:
+         recipe (Dict[str, Any]): The recipe dictionary to save.
+         filepath (Path): The path to the output .json file.
+     """
+     final_filepath = filepath.with_suffix(".json")
+
+     try:
+         with open(final_filepath, 'w') as f:
+             json.dump(recipe, f, indent=4)
+         _LOGGER.info(f"Transform recipe saved as '{final_filepath.name}'.")
+     except Exception as e:
+         _LOGGER.error(f"Failed to save recipe to '{final_filepath}': {e}")
+         raise
+
+
+ def _load_recipe_and_build_transform(filepath: Union[str, Path]) -> transforms.Compose:
+     """
+     Loads a transform recipe from a .json file and reconstructs the
+     torchvision.transforms.Compose pipeline.
+
+     Args:
+         filepath (Union[str, Path]): Path to the saved transform recipe .json file.
+
+     Returns:
+         transforms.Compose: The reconstructed transformation pipeline.
+
+     Raises:
+         ValueError: If a transform name in the recipe is not found in
+             torchvision.transforms or the custom TRANSFORM_REGISTRY.
+     """
+     # validate filepath
+     final_filepath = make_fullpath(filepath, enforce="file")
+
+     try:
+         with open(final_filepath, 'r') as f:
+             recipe = json.load(f)
+     except Exception as e:
+         _LOGGER.error(f"Failed to load recipe from '{final_filepath}': {e}")
+         raise
+
+     pipeline_steps: List[Callable] = []
+
+     if VisionTransformRecipeKeys.PIPELINE not in recipe:
+         _LOGGER.error("Recipe file is invalid: missing 'pipeline' key.")
+         raise ValueError("Invalid recipe format.")
+
+     for step in recipe[VisionTransformRecipeKeys.PIPELINE]:
+         t_name = step[VisionTransformRecipeKeys.NAME]
+         t_kwargs = step[VisionTransformRecipeKeys.KWARGS]
+
+         transform_class: Any = None
+
+         # 1. Check standard torchvision transforms
+         if hasattr(transforms, t_name):
+             transform_class = getattr(transforms, t_name)
+         # 2. Check custom transforms
+         elif t_name in TRANSFORM_REGISTRY:
+             transform_class = TRANSFORM_REGISTRY[t_name]
+         # 3. Not found
+         else:
+             _LOGGER.error(f"Unknown transform '{t_name}' in recipe. Not found in torchvision.transforms or TRANSFORM_REGISTRY.")
+             raise ValueError(f"Unknown transform name: {t_name}")
+
+         # Instantiate the transform
+         try:
+             pipeline_steps.append(transform_class(**t_kwargs))
+         except Exception as e:
+             _LOGGER.error(f"Failed to instantiate transform '{t_name}' with kwargs {t_kwargs}: {e}")
+             raise
+
+     _LOGGER.info(f"Successfully loaded and built transform pipeline from '{final_filepath.name}'.")
+     return transforms.Compose(pipeline_steps)
+
+
+ def info():
+     _script_info(__all__)
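
For context, a minimal usage sketch of the new module. The import path and the file paths below are assumptions based on this diff; the recipe keys ("pipeline", "name", "kwargs") mirror the VisionTransformRecipeKeys constants added in _keys.py.

# Minimal sketch, assuming ml_tools.ML_vision_transformers exposes the functions above.
from ml_tools.ML_vision_transformers import create_offline_augmentations

# A recipe mixes torchvision transforms with entries from TRANSFORM_REGISTRY.
recipe = {
    "pipeline": [
        {"name": "ResizeAspectFill", "kwargs": {"pad_color": "black"}},  # custom transform
        {"name": "Resize", "kwargs": {"size": 256}},                     # torchvision transform
        {"name": "RandomHorizontalFlip", "kwargs": {"p": 0.5}},
    ]
}

create_offline_augmentations(
    input_directory="data/raw_images",      # hypothetical paths
    output_directory="data/augmented",
    results_per_image=5,
    recipe=recipe,        # omit to fall back to the default random pipeline
    save_format="WEBP",
    save_quality=80,
)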
ml_tools/PSO_optimization.py CHANGED
@@ -12,9 +12,9 @@ from .serde import deserialize_object
  from .math_utilities import threshold_binary_values, threshold_binary_values_batch
  from .path_manager import sanitize_filename, make_fullpath, list_files_by_extension
  from ._logger import _LOGGER
- from .keys import EnsembleKeys
+ from ._keys import EnsembleKeys
  from ._script_info import _script_info
- from .SQL import DatabaseManager
+ from .SQL import DragonSQL
  from .optimization_tools import _save_result
 
  """
@@ -191,7 +191,7 @@ def _set_feature_names(size: int, names: Union[list[str], None]):
      return names
 
 
- def _run_single_pso(objective_function: ObjectiveFunction, pso_args: dict, feature_names: list[str], target_name: str, random_state: int, save_format: Literal['csv', 'sqlite', 'both'], csv_path: Path, db_manager: Optional[DatabaseManager], db_table_name: str):
+ def _run_single_pso(objective_function: ObjectiveFunction, pso_args: dict, feature_names: list[str], target_name: str, random_state: int, save_format: Literal['csv', 'sqlite', 'both'], csv_path: Path, db_manager: Optional[DragonSQL], db_table_name: str):
      """Helper for a single PSO run that also handles saving."""
      pso_args.update({"seed": random_state})
 
@@ -213,7 +213,7 @@ def _run_single_pso(objective_function: ObjectiveFunction, pso_args: dict, featu
      return best_features_named, best_target_named
 
 
- def _run_post_hoc_pso(objective_function: ObjectiveFunction, pso_args: dict, feature_names: list[str], target_name: str, repetitions: int, save_format: Literal['csv', 'sqlite', 'both'], csv_path: Path, db_manager: Optional[DatabaseManager], db_table_name: str):
+ def _run_post_hoc_pso(objective_function: ObjectiveFunction, pso_args: dict, feature_names: list[str], target_name: str, repetitions: int, save_format: Literal['csv', 'sqlite', 'both'], csv_path: Path, db_manager: Optional[DragonSQL], db_table_name: str):
      """Helper for post-hoc analysis that saves results incrementally."""
      progress = trange(repetitions, desc="Post-Hoc PSO", unit="run")
      for _ in progress:
@@ -342,7 +342,7 @@ def run_pso(lower_boundaries: list[float],
      schema = {"result_id": "INTEGER PRIMARY KEY AUTOINCREMENT", **schema}
 
      # Create table
-     with DatabaseManager(db_path) as db:
+     with DragonSQL(db_path) as db:
          db.create_table(db_table_name, schema)
 
      pso_arguments = {
@@ -357,7 +357,7 @@ def run_pso(lower_boundaries: list[float],
 
      # --- Dispatcher ---
      # Use a real or dummy context manager to handle the DB connection cleanly
-     db_context = DatabaseManager(db_path) if save_format in ['sqlite', 'both'] else nullcontext()
+     db_context = DragonSQL(db_path) if save_format in ['sqlite', 'both'] else nullcontext()
 
      with db_context as db_manager:
          if post_hoc_analysis is None or post_hoc_analysis <= 1:
ml_tools/SQL.py CHANGED
@@ -9,11 +9,11 @@ from .path_manager import make_fullpath, sanitize_filename
 
 
  __all__ = [
-     "DatabaseManager",
+     "DragonSQL",
  ]
 
 
- class DatabaseManager:
+ class DragonSQL:
      """
      A user-friendly context manager for handling SQLite database operations.
 
@@ -35,7 +35,7 @@ class DatabaseManager:
      ...     "feature_a": "REAL",
      ...     "score": "REAL"
      ... }
-     >>> with DatabaseManager("my_results.db") as db:
+     >>> with DragonSQL("my_results.db") as db:
      ...     db.create_table("experiments", schema)
      ...     data = {"run_name": "first_run", "feature_a": 0.123, "score": 95.5}
      ...     db.insert_row("experiments", data)
@@ -43,7 +43,7 @@ class DatabaseManager:
      ...     print(df)
      """
      def __init__(self, db_path: Union[str, Path]):
-         """Initializes the DatabaseManager with the path to the database file."""
+         """Initializes the DragonSQL with the path to the database file."""
          if isinstance(db_path, str):
              if not db_path.endswith(".db"):
                  db_path = db_path + ".db"
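
For downstream code, the DatabaseManager → DragonSQL rename is mechanical. A minimal update sketch, assuming no API changes beyond the class name (create_table and insert_row are taken from the docstring above):

# Before (<= 13.3.0):
#   from ml_tools.SQL import DatabaseManager
#   with DatabaseManager("my_results.db") as db:
#       ...

# After (>= 16.2.0):
from ml_tools.SQL import DragonSQL

with DragonSQL("my_results.db") as db:
    db.create_table("experiments", {"run_name": "TEXT", "score": "REAL"})
    db.insert_row("experiments", {"run_name": "first_run", "score": 95.5})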
ml_tools/_keys.py ADDED
@@ -0,0 +1,171 @@
+ class MagicWords:
+     """General purpose keys"""
+     LATEST = "latest"
+     CURRENT = "current"
+     RENAME = "rename"
+
+
+ class PyTorchLogKeys:
+     """
+     Used internally for ML scripts module.
+
+     Centralized keys for logging and history.
+     """
+     # --- Epoch Level ---
+     TRAIN_LOSS = 'train_loss'
+     VAL_LOSS = 'val_loss'
+     LEARNING_RATE = 'lr'
+
+     # --- Batch Level ---
+     BATCH_LOSS = 'loss'
+     BATCH_INDEX = 'batch'
+     BATCH_SIZE = 'size'
+
+
+ class EnsembleKeys:
+     """
+     Used internally by ensemble_learning.
+     """
+     # Serializing a trained model metadata.
+     MODEL = "model"
+     FEATURES = "feature_names"
+     TARGET = "target_name"
+
+     # Classification keys
+     CLASSIFICATION_LABEL = "labels"
+     CLASSIFICATION_PROBABILITIES = "probabilities"
+
+
+ class PyTorchInferenceKeys:
+     """Keys for the output dictionaries of PyTorchInferenceHandler."""
+     # For regression tasks
+     PREDICTIONS = "predictions"
+
+     # For classification tasks
+     LABELS = "labels"
+     PROBABILITIES = "probabilities"
+     LABEL_NAMES = "label_names"
+
+
+ class PytorchModelArchitectureKeys:
+     """Keys for saving and loading model architecture."""
+     MODEL = 'model_class'
+     CONFIG = "config"
+     SAVENAME = "architecture"
+
+
+ class PytorchArtifactPathKeys:
+     """Keys for model artifact paths."""
+     FEATURES_PATH = "feature_names_path"
+     TARGETS_PATH = "target_names_path"
+     ARCHITECTURE_PATH = "model_architecture_path"
+     WEIGHTS_PATH = "model_weights_path"
+     SCALER_PATH = "scaler_path"
+
+
+ class DatasetKeys:
+     """Keys for saving dataset artifacts. Also used by FeatureSchema"""
+     FEATURE_NAMES = "feature_names"
+     TARGET_NAMES = "target_names"
+     SCALER_PREFIX = "scaler_"
+     # Feature Schema
+     CONTINUOUS_NAMES = "continuous_feature_names"
+     CATEGORICAL_NAMES = "categorical_feature_names"
+
+
+ class SHAPKeys:
+     """Keys for SHAP functions"""
+     FEATURE_COLUMN = "feature"
+     SHAP_VALUE_COLUMN = "mean_abs_shap_value"
+     SAVENAME = "shap_summary"
+
+
+ class PyTorchCheckpointKeys:
+     """Keys for saving/loading a training checkpoint dictionary."""
+     MODEL_STATE = "model_state_dict"
+     OPTIMIZER_STATE = "optimizer_state_dict"
+     SCHEDULER_STATE = "scheduler_state_dict"
+     EPOCH = "epoch"
+     BEST_SCORE = "best_score"
+     HISTORY = "history"
+     CHECKPOINT_NAME = "PyModelCheckpoint"
+     # Finalized config
+     CLASSIFICATION_THRESHOLD = "classification_threshold"
+     CLASS_MAP = "class_map"
+     SEQUENCE_LENGTH = "sequence_length"
+     INITIAL_SEQUENCE = "initial_sequence"
+     TARGET_NAME = "target_name"
+     TARGET_NAMES = "target_names"
+
+
+ class UtilityKeys:
+     """Keys used for utility modules"""
+     MODEL_PARAMS_FILE = "model_parameters"
+     TOTAL_PARAMS = "Total Parameters"
+     TRAINABLE_PARAMS = "Trainable Parameters"
+     PTH_FILE = "pth report "
+     MODEL_ARCHITECTURE_FILE = "model_architecture_summary"
+
+
+ class VisionKeys:
+     """For vision ML metrics"""
+     SEGMENTATION_REPORT = "segmentation_report"
+     SEGMENTATION_HEATMAP = "segmentation_metrics_heatmap"
+     SEGMENTATION_CONFUSION_MATRIX = "segmentation_confusion_matrix"
+     # Object detection
+     OBJECT_DETECTION_REPORT = "object_detection_report"
+
+
+ class VisionTransformRecipeKeys:
+     """Defines the key names for the transform recipe JSON file."""
+     TASK = "task"
+     PIPELINE = "pipeline"
+     NAME = "name"
+     KWARGS = "kwargs"
+     PRE_TRANSFORMS = "pre_transforms"
+
+     RESIZE_SIZE = "resize_size"
+     CROP_SIZE = "crop_size"
+     MEAN = "mean"
+     STD = "std"
+
+
+ class ObjectDetectionKeys:
+     """Used by the object detection dataset"""
+     BOXES = "boxes"
+     LABELS = "labels"
+
+
+ class MLTaskKeys:
+     """Used by the Trainer and InferenceHandlers"""
+     REGRESSION = "regression"
+     MULTITARGET_REGRESSION = "multitarget regression"
+
+     BINARY_CLASSIFICATION = "binary classification"
+     MULTICLASS_CLASSIFICATION = "multiclass classification"
+     MULTILABEL_BINARY_CLASSIFICATION = "multilabel binary classification"
+
+     BINARY_IMAGE_CLASSIFICATION = "binary image classification"
+     MULTICLASS_IMAGE_CLASSIFICATION = "multiclass image classification"
+
+     BINARY_SEGMENTATION = "binary segmentation"
+     MULTICLASS_SEGMENTATION = "multiclass segmentation"
+
+     OBJECT_DETECTION = "object detection"
+
+     SEQUENCE_SEQUENCE = "sequence-to-sequence"
+     SEQUENCE_VALUE = "sequence-to-value"
+
+     ALL_BINARY_TASKS = [BINARY_CLASSIFICATION, MULTILABEL_BINARY_CLASSIFICATION, BINARY_IMAGE_CLASSIFICATION, BINARY_SEGMENTATION]
+
+
+ class DragonTrainerKeys:
+     VALIDATION_METRICS_DIR = "Validation_Metrics"
+     TEST_METRICS_DIR = "Test_Metrics"
+
+
+ class _OneHotOtherPlaceholder:
+     """Used internally by GUI_tools."""
+     OTHER_GUI = "OTHER"
+     OTHER_MODEL = "one hot OTHER placeholder"
+     OTHER_DICT = {OTHER_GUI: OTHER_MODEL}
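
A quick illustration of how these key classes are consumed by callers (import path assumed from this diff; only constants defined above are referenced):

from ml_tools._keys import MLTaskKeys, VisionTransformRecipeKeys

task = MLTaskKeys.BINARY_IMAGE_CLASSIFICATION
# Membership check against the shared list, e.g. to gate logic common to binary tasks.
is_binary_task = task in MLTaskKeys.ALL_BINARY_TASKS  # True

# Recipe steps are keyed with the same constants the vision transforms module reads.
step = {VisionTransformRecipeKeys.NAME: "RandomHorizontalFlip",
        VisionTransformRecipeKeys.KWARGS: {"p": 0.5}}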
ml_tools/_schema.py CHANGED
@@ -2,7 +2,7 @@ from typing import NamedTuple, Tuple, Optional, Dict, Union
  from pathlib import Path
 
  from .custom_logger import save_list_strings
- from .keys import DatasetKeys
+ from ._keys import DatasetKeys
  from ._logger import _LOGGER
 
 
ml_tools/custom_logger.py CHANGED
@@ -1,6 +1,6 @@
  from pathlib import Path
  from datetime import datetime
- from typing import Union, List, Dict, Any
+ from typing import Union, List, Dict, Any, Literal
  import traceback
  import json
  import csv
@@ -29,6 +29,8 @@ def custom_logger(
      ],
      save_directory: Union[str, Path],
      log_name: str,
+     add_timestamp: bool=True,
+     dict_as: Literal['auto', 'json', 'csv'] = 'auto',
  ) -> None:
      """
      Logs various data types to corresponding output formats:
@@ -36,10 +38,10 @@ def custom_logger(
      - list[Any] → .txt
          Each element is written on a new line.
 
-     - dict[str, list[Any]] → .csv
+     - dict[str, list[Any]] → .csv (if dict_as='auto' or 'csv')
          Dictionary is treated as tabular data; keys become columns, values become rows.
 
-     - dict[str, scalar] → .json
+     - dict[str, scalar] → .json (if dict_as='auto' or 'json')
          Dictionary is treated as structured data and serialized as JSON.
 
      - str → .log
@@ -49,29 +51,50 @@ def custom_logger(
          Full traceback is logged for debugging purposes.
 
      Args:
-         data: The data to be logged. Must be one of the supported types.
-         save_directory: Directory where the log will be saved. Created if it does not exist.
-         log_name: Base name for the log file. Timestamp will be appended automatically.
+         data (Any): The data to be logged. Must be one of the supported types.
+         save_directory (str | Path): Directory where the log will be saved. Created if it does not exist.
+         log_name (str): Base name for the log file.
+         add_timestamp (bool): Whether to add a timestamp to the filename.
+         dict_as ('auto'|'json'|'csv'):
+             - 'auto': Guesses format (JSON or CSV) based on dictionary content.
+             - 'json': Forces .json format for any dictionary.
+             - 'csv': Forces .csv format. Will fail if dict values are not all lists.
 
      Raises:
          ValueError: If the data type is unsupported.
      """
      try:
+         if not isinstance(data, BaseException) and not data:
+             _LOGGER.warning("Empty data received. No log file will be saved.")
+             return
+
          save_path = make_fullpath(save_directory, make=True)
 
-         timestamp = datetime.now().strftime(r"%Y%m%d_%H%M%S")
-         log_name = sanitize_filename(log_name)
+         sanitized_log_name = sanitize_filename(log_name)
 
-         base_path = save_path / f"{log_name}_{timestamp}"
-
+         if add_timestamp:
+             timestamp = datetime.now().strftime(r"%Y%m%d_%H%M%S")
+             base_path = save_path / f"{sanitized_log_name}_{timestamp}"
+         else:
+             base_path = save_path / sanitized_log_name
+
+         # Router
          if isinstance(data, list):
              _log_list_to_txt(data, base_path.with_suffix(".txt"))
 
          elif isinstance(data, dict):
-             if all(isinstance(v, list) for v in data.values()):
-                 _log_dict_to_csv(data, base_path.with_suffix(".csv"))
-             else:
+             if dict_as == 'json':
                  _log_dict_to_json(data, base_path.with_suffix(".json"))
+
+             elif dict_as == 'csv':
+                 # This will raise a ValueError if data is not all lists
+                 _log_dict_to_csv(data, base_path.with_suffix(".csv"))
+
+             else:  # 'auto' mode
+                 if all(isinstance(v, list) for v in data.values()):
+                     _log_dict_to_csv(data, base_path.with_suffix(".csv"))
+                 else:
+                     _log_dict_to_json(data, base_path.with_suffix(".json"))
 
          elif isinstance(data, str):
              _log_string_to_log(data, base_path.with_suffix(".log"))
@@ -83,7 +106,7 @@ def custom_logger(
              _LOGGER.error("Unsupported data type. Must be list, dict, str, or BaseException.")
              raise ValueError()
 
-         _LOGGER.info(f"Log saved to: '{base_path}'")
+         _LOGGER.info(f"Log saved as: '{base_path.name}'")
 
      except Exception:
          _LOGGER.exception(f"Log not saved.")
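
To make the new routing options concrete, a small usage sketch (directory and file names are hypothetical; the signature and behavior follow the hunks above):

from ml_tools.custom_logger import custom_logger

# dict of lists -> tabular data; 'auto' mode writes a .csv
history = {"epoch": [1, 2, 3], "val_loss": [0.91, 0.72, 0.65]}
custom_logger(history, save_directory="logs", log_name="training_history")

# dict of scalars forced to .json, with a stable filename (no timestamp suffix)
run_config = {"lr": 1e-3, "batch_size": 32}
custom_logger(run_config, save_directory="logs", log_name="run_config",
              add_timestamp=False, dict_as="json")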