dragon-ml-toolbox 3.6.0__tar.gz → 3.8.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-3.6.0/dragon_ml_toolbox.egg-info → dragon_ml_toolbox-3.8.0}/PKG-INFO +1 -1
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0/dragon_ml_toolbox.egg-info}/PKG-INFO +1 -1
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/ml_tools/ETL_engineering.py +2 -2
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/ml_tools/GUI_tools.py +22 -84
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/ml_tools/MICE_imputation.py +2 -2
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/ml_tools/ML_callbacks.py +0 -5
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/ml_tools/ML_evaluation.py +10 -10
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/ml_tools/ML_trainer.py +2 -2
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/ml_tools/PSO_optimization.py +57 -65
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/ml_tools/VIF_factor.py +2 -2
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/ml_tools/data_exploration.py +5 -4
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/ml_tools/datasetmaster.py +11 -14
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/ml_tools/ensemble_learning.py +2 -2
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/ml_tools/logger.py +3 -4
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/ml_tools/utilities.py +208 -4
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/pyproject.toml +1 -1
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/LICENSE +0 -0
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/README.md +0 -0
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/dragon_ml_toolbox.egg-info/SOURCES.txt +0 -0
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/dragon_ml_toolbox.egg-info/dependency_links.txt +0 -0
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/dragon_ml_toolbox.egg-info/requires.txt +0 -0
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/dragon_ml_toolbox.egg-info/top_level.txt +0 -0
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/ml_tools/ML_tutorial.py +0 -0
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/ml_tools/RNN_forecast.py +0 -0
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/ml_tools/__init__.py +0 -0
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/ml_tools/_particle_swarm_optimization.py +0 -0
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/ml_tools/_pytorch_models.py +0 -0
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/ml_tools/handle_excel.py +0 -0
- {dragon_ml_toolbox-3.6.0 → dragon_ml_toolbox-3.8.0}/setup.cfg +0 -0
ml_tools/ETL_engineering.py

@@ -294,7 +294,7 @@ class DataProcessor:
             raise TypeError(f"Invalid 'transform' action for '{input_col_name}': {transform_action}")
 
         if not processed_columns:
-            _LOGGER.warning("The transformation resulted in an empty DataFrame.")
+            _LOGGER.warning("⚠️ The transformation resulted in an empty DataFrame.")
             return pl.DataFrame()
 
         return pl.DataFrame(processed_columns)

@@ -588,7 +588,7 @@ class NumberExtractor:
         if not isinstance(round_digits, int):
             raise TypeError("round_digits must be an integer.")
         if dtype == "int":
-            _LOGGER.warning(f"'round_digits' is specified but dtype is 'int'. Rounding will be ignored.")
+            _LOGGER.warning(f"⚠️ 'round_digits' is specified but dtype is 'int'. Rounding will be ignored.")
 
         self.regex_pattern = regex_pattern
         self.dtype = dtype
ml_tools/GUI_tools.py

@@ -4,14 +4,13 @@ from typing import Optional, Callable, Any
 import traceback
 import FreeSimpleGUI as sg
 from functools import wraps
-from typing import Any, Dict, Tuple, List
+from typing import Any, Dict, Tuple, List, Literal
 from .utilities import _script_info
 import numpy as np
 from .logger import _LOGGER
 
 
 __all__ = [
-    "PathManager",
     "ConfigManager",
     "GUIFactory",
     "catch_exceptions",

@@ -19,68 +18,6 @@ __all__ = [
     "update_target_fields"
 ]
 
-
-# --- Path Management ---
-class PathManager:
-    """
-    Manages paths for a Python application, supporting both development mode and bundled mode via Briefcase.
-    """
-    def __init__(self, anchor_file: str):
-        """
-        Initializes the PathManager. The package name is automatically inferred
-        from the parent directory of the anchor file.
-
-        Args:
-            anchor_file (str): The absolute path to a file within the project's
-                               package, typically `__file__` from a module inside
-                               that package (paths.py).
-
-        Note:
-            This inference assumes that the anchor file's parent directory
-            has the same name as the package (e.g., `.../src/my_app/paths.py`).
-            This is a standard and recommended project structure.
-        """
-        resolved_anchor_path = Path(anchor_file).resolve()
-        self.package_name = resolved_anchor_path.parent.name
-        self._is_bundled, self._resource_path_func = self._check_bundle_status()
-
-        if self._is_bundled:
-            # In a Briefcase bundle, resource_path gives an absolute path
-            # to the resource directory.
-            self.package_root = self._resource_path_func(self.package_name, "") # type: ignore
-        else:
-            # In development mode, the package root is the directory
-            # containing the anchor file.
-            self.package_root = resolved_anchor_path.parent
-
-    def _check_bundle_status(self) -> tuple[bool, Optional[Callable]]:
-        """Checks if the app is running in a bundled environment."""
-        try:
-            # This is the function Briefcase provides in a bundled app
-            from briefcase.platforms.base import resource_path # type: ignore
-            return True, resource_path
-        except ImportError:
-            return False, None
-
-    def get_path(self, relative_path: str | Path) -> Path:
-        """
-        Gets the absolute path for a given resource file or directory
-        relative to the package root.
-
-        Args:
-            relative_path (str | Path): The path relative to the package root (e.g., 'helpers/icon.png').
-
-        Returns:
-            Path: The absolute path to the resource.
-        """
-        if self._is_bundled:
-            # Briefcase's resource_path handles resolving the path within the app bundle
-            return self._resource_path_func(self.package_name, str(relative_path)) # type: ignore
-        else:
-            # In dev mode, join package root with the relative path.
-            return self.package_root / relative_path
-
-
 # --- Configuration Management ---
 class _SectionProxy:
     """A helper class to represent a section of the .ini file as an object."""

@@ -148,7 +85,7 @@ class ConfigManager:
         """
         path = Path(file_path)
         if path.exists() and not force_overwrite:
-            _LOGGER.warning(f"Configuration file already exists at {path}. Aborting.")
+            _LOGGER.warning(f"⚠️ Configuration file already exists at {path}. Aborting.")
             return
 
         config = configparser.ConfigParser()

@@ -206,7 +143,7 @@ class ConfigManager:
 
         with open(path, 'w') as configfile:
             config.write(configfile)
-        _LOGGER.info(f"Successfully generated config template at: '{path}'")
+        _LOGGER.info(f"📝 Successfully generated config template at: '{path}'")
 
 
 # --- GUI Factory ---

@@ -273,8 +210,8 @@ class GUIFactory:
         self,
         data_dict: Dict[str, Tuple[float, float]],
         is_target: bool = False,
-        layout_mode:
-
+        layout_mode: Literal["grid", "row"] = 'grid',
+        features_per_column: int = 4
     ) -> List[List[sg.Column]]:
         """
         Generates a layout for continuous features or targets.

@@ -283,7 +220,7 @@ class GUIFactory:
             data_dict (dict): Keys are feature names, values are (min, max) tuples.
             is_target (bool): If True, creates disabled inputs for displaying results.
             layout_mode (str): 'grid' for a multi-row grid layout, or 'row' for a single horizontal row.
-
+            features_per_column (int): Number of features per column when `layout_mode` is 'grid'.
 
         Returns:
             A list of lists of sg.Column elements, ready to be used in a window layout.

@@ -294,7 +231,7 @@ class GUIFactory:
 
         columns = []
         for name, (val_min, val_max) in data_dict.items():
-            key =
+            key = name
             default_text = "" if is_target else str(val_max)
 
             label = sg.Text(name, font=label_font, background_color=bg_color, key=f"_text_{name}")

@@ -313,6 +250,7 @@ class GUIFactory:
             range_text = sg.Text(f"Range: {int(val_min)}-{int(val_max)}", font=range_font, background_color=bg_color)
             layout = [[label], [element], [range_text]]
 
+            # each feature is wrapped as a column element
             layout.append([sg.Text(" ", font=(cfg.fonts.font_family, 2), background_color=bg_color)]) # type: ignore
             columns.append(sg.Column(layout, background_color=bg_color))
 

@@ -320,13 +258,13 @@ class GUIFactory:
             return [columns] # A single row containing all columns
 
         # Default to 'grid' layout
-        return [columns[i:i +
+        return [columns[i:i + features_per_column] for i in range(0, len(columns), features_per_column)]
 
     def generate_combo_layout(
         self,
         data_dict: Dict[str, List[Any]],
-        layout_mode:
-
+        layout_mode: Literal["grid", "row"] = 'grid',
+        features_per_column: int = 4
     ) -> List[List[sg.Column]]:
         """
         Generates a layout for categorical or binary features using Combo boxes.

@@ -334,7 +272,7 @@ class GUIFactory:
         Args:
             data_dict (dict): Keys are feature names, values are lists of options.
             layout_mode (str): 'grid' for a multi-row grid layout, or 'row' for a single horizontal row.
-
+            features_per_column (int): Number of features per column when `layout_mode` is 'grid'.
 
         Returns:
             A list of lists of sg.Column elements, ready to be used in a window layout.

@@ -352,13 +290,14 @@ class GUIFactory:
             )
             layout = [[label], [element]]
             layout.append([sg.Text(" ", font=(cfg.fonts.font_family, 2), background_color=bg_color)]) # type: ignore
+            # each feature is wrapped in a Column element
             columns.append(sg.Column(layout, background_color=bg_color))
 
         if layout_mode == 'row':
             return [columns] # A single row containing all columns
 
         # Default to 'grid' layout
-        return [columns[i:i +
+        return [columns[i:i + features_per_column] for i in range(0, len(columns), features_per_column)]
 
     # --- Window Creation ---
     def create_window(self, title: str, layout: List[List[sg.Element]], **kwargs) -> sg.Window:

@@ -421,8 +360,8 @@ def _default_categorical_processor(feature_name: str, chosen_value: Any) -> List
     return [1.0] if str(chosen_value) == 'True' else [0.0]
 
 def prepare_feature_vector(
-
-
+    window_values: Dict[str, Any],
+    gui_feature_order: List[str],
     continuous_features: List[str],
     categorical_features: List[str],
     categorical_processor: Optional[Callable[[str, Any], List[float]]] = None

@@ -432,8 +371,8 @@ def prepare_feature_vector(
     This function supports label encoding and one-hot encoding via the processor.
 
     Args:
-
-
+        window_values (dict): The values dictionary from a `window.read()` call.
+        gui_feature_order (list): A list of all feature names that have a GUI element.
                                   For one-hot encoding, this should be the name of the
                                   single GUI element (e.g., 'material_type'), not the
                                   expanded feature names (e.g., 'material_is_steel').

@@ -456,8 +395,8 @@ def prepare_feature_vector(
     cont_set = set(continuous_features)
     cat_set = set(categorical_features)
 
-    for name in
-        chosen_value =
+    for name in gui_feature_order:
+        chosen_value = window_values.get(name)
 
         if chosen_value is None or chosen_value == '':
             raise ValueError(f"Feature '{name}' is missing a value.")

@@ -482,13 +421,12 @@ def update_target_fields(window: sg.Window, results_dict: Dict[str, Any]):
 
     Args:
         window (sg.Window): The application's window object.
-        results_dict (dict): A dictionary where keys are target
-                             'TARGET_' prefix) and values are the predicted results.
+        results_dict (dict): A dictionary where keys are target element-keys and values are the predicted results to update.
     """
     for target_name, result in results_dict.items():
         # Format numbers to 2 decimal places, leave other types as-is
         display_value = f"{result:.2f}" if isinstance(result, (int, float)) else result
-        window[
+        window[target_name].update(display_value) # type: ignore
 
 
 def info():
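Both layout generators now chunk their per-feature `sg.Column` elements into grid rows of `features_per_column`. A minimal sketch of that chunking expression, using plain strings in place of the column elements the factory actually builds (the feature names are made up for illustration):

# Sketch of the 'grid' chunking used by generate_continuous_layout / generate_combo_layout.
columns = [f"feature_col_{i}" for i in range(10)]   # stand-ins for sg.Column elements
features_per_column = 4

grid = [columns[i:i + features_per_column]
        for i in range(0, len(columns), features_per_column)]

print(grid)
# [['feature_col_0', 'feature_col_1', 'feature_col_2', 'feature_col_3'],
#  ['feature_col_4', 'feature_col_5', 'feature_col_6', 'feature_col_7'],
#  ['feature_col_8', 'feature_col_9']]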
ml_tools/MICE_imputation.py

@@ -128,7 +128,7 @@ def get_convergence_diagnostic(kernel: mf.ImputationKernel, imputed_dataset_name
         plt.savefig(save_path, bbox_inches='tight', format="svg")
         plt.close()
 
-    _LOGGER.info(f"{dataset_file_dir} completed.")
+    _LOGGER.info(f"✅ {dataset_file_dir} process completed.")
 
 
 # Imputed distributions

@@ -213,7 +213,7 @@ def get_imputed_distributions(kernel: mf.ImputationKernel, df_name: str, root_di
         fig = kernel.plot_imputed_distributions(variables=[feature])
         _process_figure(fig, feature)
 
-    _LOGGER.info(f"{local_dir_name} completed.")
+    _LOGGER.info(f"✅ {local_dir_name} completed.")
 
 
 def run_mice_pipeline(df_path_or_dir: Union[str,Path], target_columns: list[str],
ml_tools/ML_callbacks.py

@@ -178,7 +178,6 @@ class EarlyStopping(Callback):
             self.stopped_epoch = epoch
             self.trainer.stop_training = True # type: ignore
             if self.verbose > 0:
-                print("")
                 _LOGGER.info(f"Epoch {epoch+1}: early stopping after {self.wait} epochs with no improvement.")
 
 

@@ -256,7 +255,6 @@ class ModelCheckpoint(Callback):
         new_filepath = self.save_dir / filename
 
         if self.verbose > 0:
-            print("")
             _LOGGER.info(f"Epoch {epoch}: {self.monitor} improved from {old_best_str} to {current:.4f}, saving model to {new_filepath}")
 
         # Save the new best model

@@ -276,7 +274,6 @@ class ModelCheckpoint(Callback):
         filepath = self.save_dir / filename
 
         if self.verbose > 0:
-            print("")
             _LOGGER.info(f'Epoch {epoch}: saving model to {filepath}')
         torch.save(self.trainer.model.state_dict(), filepath) # type: ignore
 

@@ -325,7 +322,6 @@ class LRScheduler(Callback):
             if metric_val is not None:
                 self.scheduler.step(metric_val)
             else:
-                print("")
                 _LOGGER.warning(f"LRScheduler could not find metric '{self.monitor}' in logs.")
 
         # For all other schedulers

@@ -335,7 +331,6 @@ class LRScheduler(Callback):
         # Log the change if the LR was updated
         current_lr = self.trainer.optimizer.param_groups[0]['lr'] # type: ignore
         if current_lr != self.previous_lr:
-            print("")
             _LOGGER.info(f"Epoch {epoch}: Learning rate changed to {current_lr:.6f}")
             self.previous_lr = current_lr
 
ml_tools/ML_evaluation.py

@@ -65,7 +65,7 @@ def plot_losses(history: dict, save_dir: Optional[Union[str, Path]] = None):
         save_dir_path = make_fullpath(save_dir, make=True)
         save_path = save_dir_path / "loss_plot.svg"
         plt.savefig(save_path)
-        _LOGGER.info(f"Loss plot saved as '{save_path.name}'")
+        _LOGGER.info(f"📉 Loss plot saved as '{save_path.name}'")
     else:
         plt.show()
     plt.close(fig)

@@ -92,7 +92,7 @@ def classification_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_prob: Optio
         # Save text report
         report_path = save_dir_path / "classification_report.txt"
         report_path.write_text(report, encoding="utf-8")
-        _LOGGER.info(f"Classification report saved as '{report_path.name}'")
+        _LOGGER.info(f"📝 Classification report saved as '{report_path.name}'")
 
         # Save Confusion Matrix
         fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=100)

@@ -100,7 +100,7 @@ def classification_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_prob: Optio
         ax_cm.set_title("Confusion Matrix")
         cm_path = save_dir_path / "confusion_matrix.svg"
         plt.savefig(cm_path)
-        _LOGGER.info(f"Confusion matrix saved as '{cm_path.name}'")
+        _LOGGER.info(f"❇️ Confusion matrix saved as '{cm_path.name}'")
         plt.close(fig_cm)
 
         # Save ROC Curve

@@ -117,7 +117,7 @@ def classification_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_prob: Optio
         ax_roc.grid(True)
         roc_path = save_dir_path / "roc_curve.svg"
         plt.savefig(roc_path)
-        _LOGGER.info(f"ROC curve saved as '{roc_path.name}'")
+        _LOGGER.info(f"📈 ROC curve saved as '{roc_path.name}'")
         plt.close(fig_roc)
     else:
         # Show plots if not saving

@@ -162,7 +162,7 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Optiona
         # Save text report
         report_path = save_dir_path / "regression_report.txt"
         report_path.write_text(report_string)
-        _LOGGER.info(f"Regression report saved as '{report_path.name}'")
+        _LOGGER.info(f"📝 Regression report saved as '{report_path.name}'")
 
         # Save residual plot
         residuals = y_true - y_pred

@@ -176,7 +176,7 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Optiona
         plt.tight_layout()
         res_path = save_dir_path / "residual_plot.svg"
         plt.savefig(res_path)
-        _LOGGER.info(f"Residual plot saved as '{res_path.name}'")
+        _LOGGER.info(f"📈 Residual plot saved as '{res_path.name}'")
         plt.close(fig_res)
 
         # Save true vs predicted plot

@@ -190,7 +190,7 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Optiona
         plt.tight_layout()
         tvp_path = save_dir_path / "true_vs_predicted_plot.svg"
         plt.savefig(tvp_path)
-        _LOGGER.info(f"True vs. Predicted plot saved as '{tvp_path.name}'")
+        _LOGGER.info(f"📉 True vs. Predicted plot saved as '{tvp_path.name}'")
         plt.close(fig_tvp)
 
 

@@ -227,7 +227,7 @@ def shap_summary_plot(model, background_data: torch.Tensor, instances_to_explain
         plt.title("SHAP Feature Importance")
         plt.tight_layout()
         plt.savefig(bar_path)
-        _LOGGER.info(f"SHAP bar plot saved as '{bar_path.name}'")
+        _LOGGER.info(f"📊 SHAP bar plot saved as '{bar_path.name}'")
         plt.close()
 
         # Save Dot Plot

@@ -236,7 +236,7 @@ def shap_summary_plot(model, background_data: torch.Tensor, instances_to_explain
         plt.title("SHAP Feature Importance")
         plt.tight_layout()
         plt.savefig(dot_path)
-        _LOGGER.info(f"SHAP dot plot saved as '{dot_path.name}'")
+        _LOGGER.info(f"📊 SHAP dot plot saved as '{dot_path.name}'")
         plt.close()
 
         # Save Summary Data to CSV

@@ -249,7 +249,7 @@ def shap_summary_plot(model, background_data: torch.Tensor, instances_to_explain
             'mean_abs_shap_value': mean_abs_shap
         }).sort_values('mean_abs_shap_value', ascending=False)
         summary_df.to_csv(summary_path, index=False)
-        _LOGGER.info(f"SHAP summary data saved as '{summary_path.name}'")
+        _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
     else:
         _LOGGER.info("No save directory provided. Displaying SHAP dot plot.")
         shap.summary_plot(shap_values_for_plot, instances_to_explain, feature_names=feature_names, plot_type="dot")
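As the hunks above show, these evaluation helpers write their reports and `.svg` plots into `save_dir` when one is given, and otherwise display the figures. A minimal sketch for the regression case; the import path and output directory are assumptions for illustration:

import numpy as np
from ml_tools.ML_evaluation import regression_metrics  # assumed import path

# Hypothetical predictions; with save_dir set, regression_report.txt,
# residual_plot.svg and true_vs_predicted_plot.svg are written there.
y_true = np.array([3.1, 2.4, 5.0, 4.2])
y_pred = np.array([2.9, 2.6, 4.7, 4.5])
regression_metrics(y_true, y_pred, save_dir="evaluation_output")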
ml_tools/ML_trainer.py

@@ -72,10 +72,10 @@ class MyTrainer:
         """Validates the selected device and returns a torch.device object."""
         device_lower = device.lower()
         if "cuda" in device_lower and not torch.cuda.is_available():
-            _LOGGER.warning("CUDA not available, switching to CPU.")
+            _LOGGER.warning("⚠️ CUDA not available, switching to CPU.")
             device = "cpu"
         elif device_lower == "mps" and not torch.backends.mps.is_available():
-            _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
+            _LOGGER.warning("⚠️ Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
             device = "cpu"
         return torch.device(device)
 
ml_tools/PSO_optimization.py

@@ -22,7 +22,6 @@ import torch
 from tqdm import trange
 import matplotlib.pyplot as plt
 import seaborn as sns
-from collections import defaultdict
 from .logger import _LOGGER
 
 

@@ -307,7 +306,7 @@ def run_pso(lower_boundaries: list[float],
     else:
         device = torch.device("cpu")
 
-    _LOGGER.info(f"Using device: '{device}'")
+    _LOGGER.info(f"👾 Using device: '{device}'")
 
     # set local deep copies to prevent in place list modification
     local_lower_boundaries = deepcopy(lower_boundaries)

@@ -511,13 +510,13 @@ def _pso(func: ObjectiveFunction,
     return best_position, best_score
 
 
-def plot_optimal_feature_distributions(results_dir: Union[str, Path], save_dir: Union[str, Path]
+def plot_optimal_feature_distributions(results_dir: Union[str, Path], save_dir: Union[str, Path]):
     """
     Analyzes optimization results and plots the distribution of optimal values for each feature.
 
-
-
-
+    For features with more than two unique values, this function generates a color-coded
+    Kernel Density Estimate (KDE) plot. For binary or constant features, it generates a bar plot
+    showing relative frequency.
 
     Parameters
     ----------

@@ -525,76 +524,69 @@ def plot_optimal_feature_distributions(results_dir: Union[str, Path], save_dir:
         The path to the directory containing the optimization result CSV files.
     save_dir : str or Path
         The directory where the output plots will be saved.
-    color_by_target : bool, optional
-        If True, generates comparative plots with distributions colored by their source target.
     """
-
-    _LOGGER.info(f"Starting analysis in '{mode}' mode from results in: '{results_dir}'")
-
-    # Check results_dir
+    # Check results_dir and create output path
     results_path = make_fullpath(results_dir)
-    # make output path
     output_path = make_fullpath(save_dir, make=True)
 
     all_csvs = list_csv_paths(results_path)
-
     if not all_csvs:
-        _LOGGER.warning("No data found. No plots will be generated.")
+        _LOGGER.warning("⚠️ No data found. No plots will be generated.")
         return
 
-    # ---
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    # --- Data Loading and Preparation ---
+    _LOGGER.info(f"📁 Starting analysis from results in: '{results_dir}'")
+    data_to_plot = []
+    for df, df_name in yield_dataframes_from_dir(results_path):
+        melted_df = df.iloc[:, :-1].melt(var_name='feature', value_name='value')
+        melted_df['target'] = df_name.replace("Optimization_", "")
+        data_to_plot.append(melted_df)
+
+    long_df = pd.concat(data_to_plot, ignore_index=True)
+    features = long_df['feature'].unique()
+    _LOGGER.info(f"📂 Found data for {len(features)} features across {len(long_df['target'].unique())} targets. Generating plots...")
+
+    # --- Plotting Loop ---
+    for feature_name in features:
+        plt.figure(figsize=(12, 7))
+        feature_df = long_df[long_df['feature'] == feature_name]
+
+        # Check if the feature is binary or constant
+        if feature_df['value'].nunique() <= 2:
+            # PLOT 1: For discrete values, calculate percentages and use a true bar plot.
+            # This ensures the X-axis is clean (e.g., just 0 and 1).
+            norm_df = (feature_df.groupby('target')['value']
+                       .value_counts(normalize=True)
+                       .mul(100)
+                       .rename('percent')
+                       .reset_index())
 
-            sns.
+            ax = sns.barplot(data=norm_df, x='value', y='percent', hue='target')
 
-            plt.title(f"
-            plt.
-
-
-
-
-
-
-
-            plt.
-
-
-
-
-
-        feature_columns = df.iloc[:, :-1]
-        for feature_name in feature_columns:
-            feature_distributions[feature_name].extend(df[feature_name].tolist())
+            plt.title(f"Optimal Value Distribution for '{feature_name}'", fontsize=16)
+            plt.ylabel("Frequency (%)", fontsize=12)
+            ax.set_ylim(0, 100) # Set Y-axis from 0 to 100
+
+        else:
+            # PLOT 2: KDE plot for continuous values.
+            ax = sns.kdeplot(data=feature_df, x='value', hue='target',
+                             fill=True, alpha=0.1, warn_singular=False)
+
+            plt.title(f"Optimal Value Distribution for '{feature_name}'", fontsize=16)
+            plt.ylabel("Density", fontsize=12) # Y-axis is "Density" for KDE plots
+
+        # --- Common settings for both plot types ---
+        plt.xlabel("Feature Value", fontsize=12)
+        plt.grid(axis='y', alpha=0.5, linestyle='--')
 
-
-
-
-
-
-
-
-
-        plt.grid(axis='y', alpha=0.5, linestyle='--')
-
-        sanitized_feature_name = sanitize_filename(feature_name)
-        plot_filename = output_path / f"Aggregate_{sanitized_feature_name}.svg"
-        plt.savefig(plot_filename, bbox_inches='tight')
-        plt.close()
+        legend = ax.get_legend()
+        if legend:
+            legend.set_title('Target')
+
+        sanitized_feature_name = sanitize_filename(feature_name)
+        plot_filename = output_path / f"Distribution_{sanitized_feature_name}.svg"
+        plt.savefig(plot_filename, bbox_inches='tight')
+        plt.close()
 
     _LOGGER.info(f"✅ All plots saved successfully to: '{output_path}'")
 
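A minimal usage sketch of the rewritten plotting helper; the directory names are hypothetical, and the function expects the `Optimization_<target>.csv` files produced by the PSO run in `results_dir`:

from ml_tools.PSO_optimization import plot_optimal_feature_distributions  # assumed import path

# One Distribution_<feature>.svg per feature is written to save_dir:
# a KDE plot for continuous features, a percentage bar plot for binary/constant ones.
plot_optimal_feature_distributions(
    results_dir="pso_results",
    save_dir="pso_results/distribution_plots",
)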
ml_tools/VIF_factor.py

@@ -168,12 +168,12 @@ def drop_vif_based(df: pd.DataFrame, vif_df: pd.DataFrame, threshold: float = 10
 
     # Identify features to drop
     to_drop = vif_df[vif_df["VIF"] > threshold]["feature"].tolist()
-    _LOGGER.info(f"
+    _LOGGER.info(f"🗑️ Dropping {len(to_drop)} column(s) with VIF > {threshold}: {to_drop}")
 
     result_df = df.drop(columns=to_drop)
 
     if result_df.empty:
-        _LOGGER.warning(f"
+        _LOGGER.warning(f"⚠️ All columns were dropped.")
 
     return result_df, to_drop
 
ml_tools/data_exploration.py

@@ -100,10 +100,11 @@ def drop_constant_columns(df: pd.DataFrame, verbose: bool = True) -> pd.DataFram
             cols_to_keep.append(col_name)
 
     dropped_columns = original_columns - set(cols_to_keep)
-    if
-        print(f"Dropped {len(dropped_columns)} constant columns
-
-
+    if verbose:
+        print(f"🧹 Dropped {len(dropped_columns)} constant columns.")
+        if dropped_columns:
+            for dropped_column in dropped_columns:
+                print(f"  {dropped_column}")
 
     return df[cols_to_keep]
 
ml_tools/datasetmaster.py

@@ -13,7 +13,7 @@ from torchvision.datasets import ImageFolder
 from torchvision import transforms
 import matplotlib.pyplot as plt
 from pathlib import Path
-from .utilities import _script_info
+from .utilities import _script_info, make_fullpath
 from .logger import _LOGGER
 
 

@@ -204,7 +204,7 @@ class DatasetMaker(_BaseMaker):
         if not self._is_split:
             raise RuntimeError("Continuous features must be normalized AFTER splitting data. Call .split_data() first.")
         if self._is_normalized:
-            _LOGGER.warning("Data has already been normalized.")
+            _LOGGER.warning("⚠️ Data has already been normalized.")
             return self
 
         # Use continuous features columns

@@ -232,7 +232,7 @@ class DatasetMaker(_BaseMaker):
     def split_data(self, test_size: float = 0.2, stratify: bool = False, random_state: Optional[int] = None) -> 'DatasetMaker':
         """Splits the data into training and testing sets."""
         if self._is_split:
-            _LOGGER.warning("Data has already been split.")
+            _LOGGER.warning("⚠️ Data has already been split.")
             return self
 
         if self.labels.dtype == 'object' or self.labels.dtype.name == 'category':

@@ -260,9 +260,9 @@ class DatasetMaker(_BaseMaker):
             Defaults to `SMOTETomek`.
         """
         if not self._is_split:
-            raise RuntimeError("Cannot balance data before it has been split. Call .split_data() first.")
+            raise RuntimeError("❌ Cannot balance data before it has been split. Call .split_data() first.")
         if self._is_balanced:
-            _LOGGER.warning("Training data has already been balanced.")
+            _LOGGER.warning("⚠️ Training data has already been balanced.")
             return self
 
         if resampler is None:

@@ -278,13 +278,13 @@ class DatasetMaker(_BaseMaker):
     def process(self, test_size: float = 0.2, cat_method: Literal["one-hot", "embed"] = "one-hot", normalize_method: Literal["standard", "minmax"] = "standard",
                 balance: bool = False, random_state: Optional[int] = None) -> 'DatasetMaker':
         """Runs a standard, fully automated preprocessing pipeline."""
-        _LOGGER.info("--- Running Automated Processing Pipeline ---")
+        _LOGGER.info("--- 🤖 Running Automated Processing Pipeline ---")
         self.process_categoricals(method=cat_method)
         self.split_data(test_size=test_size, stratify=True, random_state=random_state)
         self.normalize_continuous(method=normalize_method)
         if balance:
             self.balance_data()
-        _LOGGER.info("--- Automated Processing Complete ---")
+        _LOGGER.info("--- 🤖 Automated Processing Complete ---")
         return self
 
     def denormalize(self, data: Union[torch.Tensor, numpy.ndarray, pandas.DataFrame]) -> Union[numpy.ndarray, pandas.DataFrame]:

@@ -400,10 +400,7 @@ class VisionDatasetMaker(_BaseMaker):
         Logs a report of the types, sizes, and channels of image files
         found in the directory and its subdirectories.
         """
-        path_obj =
-        if not path_obj.is_dir():
-            _LOGGER.error(f"Path is not a valid directory: {path_obj}")
-            return
+        path_obj = make_fullpath(path)
 
         non_image_files = set()
         img_types = set()

@@ -505,7 +502,7 @@ class VisionDatasetMaker(_BaseMaker):
         if not self._is_split:
             raise RuntimeError("Data has not been split. Call .split_data() first.")
         if not self._are_transforms_configured:
-            _LOGGER.warning("Transforms have not been configured. Using default ToTensor only.")
+            _LOGGER.warning("⚠️ Transforms have not been configured. Using default ToTensor only.")
 
         if self._test_dataset:
             return self._train_dataset, self._val_dataset, self._test_dataset

@@ -555,7 +552,7 @@ class SequenceMaker(_BaseMaker):
             raise RuntimeError("Data must be split BEFORE normalizing. Call .split_data() first.")
 
         if self.scaler:
-            _LOGGER.warning("Data has already been normalized.")
+            _LOGGER.warning("⚠️ Data has already been normalized.")
             return self
 
         if method == "standard":

@@ -579,7 +576,7 @@ class SequenceMaker(_BaseMaker):
     def split_data(self, test_size: float = 0.2) -> 'SequenceMaker':
         """Splits the sequence into training and testing portions."""
         if self._is_split:
-            _LOGGER.warning("Data has already been split.")
+            _LOGGER.warning("⚠️ Data has already been split.")
             return self
 
         split_idx = int(len(self.sequence) * (1 - test_size))
ml_tools/ensemble_learning.py

@@ -915,7 +915,7 @@ def run_ensemble_pipeline(datasets_dir: Union[str,Path], save_dir: Union[str,Pat
     datasets_path = make_fullpath(datasets_dir)
     save_path = make_fullpath(save_dir, make=True)
 
-    _LOGGER.info("Training starting...")
+    _LOGGER.info("🏁 Training starting...")
     #Yield imputed dataset
     for dataframe, dataframe_name in yield_dataframes_from_dir(datasets_path):
         #Yield features dataframe and target dataframe

@@ -933,7 +933,7 @@ def run_ensemble_pipeline(datasets_dir: Union[str,Path], save_dir: Union[str,Pat
                       test_features=X_test, test_target=y_test,
                       feature_names=feature_names,target_name=target_name,
                       debug=debug, save_dir=save_path, save_model=save_model)
-
+
     _LOGGER.info("✅ Training and evaluation complete.")
 
 
ml_tools/logger.py

@@ -10,7 +10,6 @@ import logging
 import sys
 
 
-
 __all__ = [
     "custom_logger"
 ]

@@ -85,10 +84,10 @@ def custom_logger(
         else:
             raise ValueError("Unsupported data type. Must be list, dict, DataFrame, str, or BaseException.")
 
-        _LOGGER.info(f"Log saved to: '{base_path}'")
+        _LOGGER.info(f"🗄️ Log saved to: '{base_path}'")
 
     except Exception as e:
-        _LOGGER.error(f"Log not saved: {e}")
+        _LOGGER.error(f"❌ Log not saved: {e}")
 
 
 def _log_list_to_txt(data: List[Any], path: Path) -> None:

@@ -176,7 +175,7 @@ def _get_logger(name: str = "ml_tools", level: int = logging.INFO):
     handler = logging.StreamHandler(sys.stdout)
 
     # Define the format string and the date format separately
-    log_format = '
+    log_format = '\n🐉%(asctime)s - %(name)s - %(levelname)s - %(message)s'
     date_format = '%Y-%m-%d %H:%M' # Format: Year-Month-Day Hour:Minute
 
     # Pass both the format and the date format to the Formatter
ml_tools/utilities.py

@@ -4,9 +4,10 @@ import pandas as pd
 import polars as pl
 from pathlib import Path
 import re
-from typing import Literal, Union, Sequence, Optional, Any, Iterator, Tuple
+from typing import Literal, Union, Sequence, Optional, Any, Iterator, Tuple, Callable, List, Dict
 import joblib
 from joblib.externals.loky.process_executor import TerminatedWorkerError
+from pprint import pprint
 
 
 # Keep track of available tools

@@ -25,7 +26,8 @@ __all__ = [
     "serialize_object",
     "deserialize_object",
     "distribute_datasets_by_target",
-    "train_dataset_orchestrator"
+    "train_dataset_orchestrator",
+    "PathManager"
 ]
 
 

@@ -640,12 +642,214 @@ def train_dataset_orchestrator(list_of_dirs: list[Union[str,Path]],
             print(f"⚠️ Failed to process file '{df_path}'. Reason: {e}")
             continue
 
-    print(f"{total_saved} single-target datasets were created.")
+    print(f"\n✅ {total_saved} single-target datasets were created.")
+
+
+### Path Manager
+class PathManager:
+    """
+    Manages and stores a project's file paths, acting as a centralized
+    "path database". It supports both development mode and applications
+    bundled with Briefcase.
+
+    Supports python dictionary syntax.
+    """
+    def __init__(
+        self,
+        anchor_file: str,
+        base_directories: Optional[List[str]] = None
+    ):
+        """
+        The initializer determines the project's root directory and can pre-register
+        a list of base directories relative to that root.
+
+        Args:
+            anchor_file (str): The absolute path to a file whose parent directory will be considered the package root and name. Typically, `__file__`.
+            base_directories (Optional[List[str]]): A list of directory names
+                                                    located at the same level as the anchor file's
+                                                    parent directory to register immediately.
+        """
+        resolved_anchor_path = Path(anchor_file).resolve()
+        self._package_name = resolved_anchor_path.parent.name
+        self._is_bundled, self._resource_path_func = self._check_bundle_status()
+        self._paths: Dict[str, Path] = {}
+
+        if self._is_bundled:
+            # In a bundle, resource_path gives the absolute path to the 'app_packages' dir
+            # when given the package name.
+            package_root = self._resource_path_func(self._package_name) # type: ignore
+        else:
+            # In dev mode, the package root is the directory containing the anchor file.
+            package_root = resolved_anchor_path.parent
+
+        # Register the root of the package itself
+        self._paths["ROOT"] = package_root
+
+        # Register all the base directories
+        if base_directories:
+            for dir_name in base_directories:
+                # In dev mode, this is simple. In a bundle, we must resolve
+                # each path from the package root.
+                if self._is_bundled:
+                    self._paths[dir_name] = self._resource_path_func(self._package_name, dir_name) # type: ignore
+                else:
+                    self._paths[dir_name] = package_root / dir_name
+
+    # A helper function to find the briefcase-injected resource function
+    def _check_bundle_status(self) -> tuple[bool, Optional[Callable]]:
+        """Checks if the app is running in a Briefcase bundle."""
+        try:
+            # This function is injected by Briefcase into the global scope
+            from briefcase.platforms.base import resource_path # type: ignore
+            return True, resource_path
+        except (ImportError, NameError):
+            return False, None
+
+    def get(self, key: str) -> Path:
+        """
+        Retrieves a stored path by its key.
+
+        Args:
+            key (str): The key of the path to retrieve.
+
+        Returns:
+            Path: The resolved, absolute Path object.
+
+        Raises:
+            KeyError: If the key is not found in the manager.
+        """
+        try:
+            return self._paths[key]
+        except KeyError:
+            print(f"❌ Path key '{key}' not found.")
+            # Consider suggesting close matches if you want to get fancy
+            raise
+
+    def update(self, new_paths: Dict[str, Union[str, Path]], overwrite: bool = False) -> None:
+        """
+        Adds new paths or overwrites existing ones in the manager.
+
+        Args:
+            new_paths (Dict[str, Union[str, Path]]): A dictionary where keys are
+                                                     the identifiers and values are the
+                                                     Path objects or strings to store.
+            overwrite (bool): If False (default), raises a KeyError if any
+                              key in new_paths already exists. If True,
+                              allows overwriting existing keys.
+        """
+        if not overwrite:
+            for key in new_paths:
+                if key in self._paths:
+                    raise KeyError(
+                        f"Path key '{key}' already exists in the manager. To replace it, call update() with overwrite=True."
+                    )
+
+        # Resolve any string paths to Path objects before storing
+        resolved_new_paths = {k: Path(v) for k, v in new_paths.items()}
+        self._paths.update(resolved_new_paths)
+
+    def make_dirs(self, keys: Optional[List[str]] = None, verbose: bool = False) -> None:
+        """
+        Creates directory structures for registered paths in writable locations.
+
+        This method identifies paths that are directories (no file suffix) and creates them on the filesystem.
+
+        In a bundled application, this method will NOT attempt to create directories inside the read-only app package, preventing crashes. It
+        will only operate on paths outside of the package (e.g., user data dirs).
+
+        Args:
+            keys (Optional[List[str]]): If provided, only the directories
+                                        corresponding to these keys will be
+                                        created. If None (default), all
+                                        registered directory paths are used.
+            verbose (bool): If True, prints a message for each action.
+        """
+        path_items = []
+        if keys:
+            for key in keys:
+                if key in self._paths:
+                    path_items.append((key, self._paths[key]))
+                elif verbose:
+                    print(f"⚠️ Key '{key}' not found in PathManager, skipping.")
+        else:
+            path_items = self._paths.items()
+
+        # Get the package root to check against.
+        package_root = self._paths.get("ROOT")
+
+        for key, path in path_items:
+            if path.suffix: # It's a file, not a directory
+                continue
+
+            # --- THE CRITICAL CHECK ---
+            # Determine if the path is inside the main application package.
+            is_internal_path = package_root and path.is_relative_to(package_root)
+
+            if self._is_bundled and is_internal_path:
+                if verbose:
+                    print(f"ℹ️ Skipping internal directory '{key}' in bundled app (read-only).")
+                continue
+            # -------------------------
+
+            if verbose:
+                print(f"📁 Ensuring directory exists for key '{key}': {path}")
+
+            path.mkdir(parents=True, exist_ok=True)
+
+    def status(self) -> None:
+        """
+        Checks the status of all registered paths on the filesystem and prints a formatted report.
+        """
+        report = {}
+        for key, path in self.items():
+            if path.is_dir():
+                report[key] = "📁 Directory"
+            elif path.is_file():
+                report[key] = "📄 File"
+            else:
+                report[key] = "❌ Not Found"
+
+        print("\n--- Path Status Report ---")
+        pprint(report)
+
+    def __repr__(self) -> str:
+        """Provides a string representation of the stored paths."""
+        path_list = "\n".join(f"  '{k}': '{v}'" for k, v in self._paths.items())
+        return f"PathManager(\n{path_list}\n)"
+
+    # --- Dictionary-Style Methods ---
+    def __getitem__(self, key: str) -> Path:
+        """Allows dictionary-style getting, e.g., PM['my_key']"""
+        return self.get(key)
+
+    def __setitem__(self, key: str, value: Union[str, Path]):
+        """Allows dictionary-style setting, e.g., PM['my_key'] = path"""
+        self.update({key: value}, overwrite=True)
+
+    def __contains__(self, key: str) -> bool:
+        """Allows checking for a key's existence, e.g., if 'my_key' in PM"""
+        return key in self._paths
+
+    def __len__(self) -> int:
+        """Allows getting the number of paths, e.g., len(PM)"""
+        return len(self._paths)
+
+    def keys(self):
+        """Returns all registered path keys."""
+        return self._paths.keys()
+
+    def values(self):
+        """Returns all registered Path objects."""
+        return self._paths.values()
+
+    def items(self):
+        """Returns all registered (key, Path) pairs."""
+        return self._paths.items()
 
 
 class LogKeys:
     """
-    Used for ML scripts
+    Used internally for ML scripts.
 
     Centralized keys for logging and history.
     """
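A minimal usage sketch of the relocated `PathManager`, e.g. from a hypothetical `my_app/paths.py`; the directory names are made up, and `make_dirs` skips package-internal paths when the app runs from a read-only Briefcase bundle:

from ml_tools.utilities import PathManager  # assumed import path

PM = PathManager(anchor_file=__file__, base_directories=["helpers", "images"])
PM.update({"output": PM["ROOT"] / "output"})   # dict-style access via __getitem__
PM.make_dirs(keys=["output"], verbose=True)
PM.status()                                    # prints a 📁 / 📄 / ❌ report per key
icon_path = PM["images"] / "icon.png"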