dragon-ml-toolbox 3.7.0__py3-none-any.whl → 3.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic. Click here for more details.
- {dragon_ml_toolbox-3.7.0.dist-info → dragon_ml_toolbox-3.9.0.dist-info}/METADATA +4 -3
- {dragon_ml_toolbox-3.7.0.dist-info → dragon_ml_toolbox-3.9.0.dist-info}/RECORD +10 -9
- ml_tools/GUI_tools.py +96 -131
- ml_tools/ensemble_learning.py +123 -3
- ml_tools/path_manager.py +212 -0
- ml_tools/utilities.py +2 -2
- {dragon_ml_toolbox-3.7.0.dist-info → dragon_ml_toolbox-3.9.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-3.7.0.dist-info → dragon_ml_toolbox-3.9.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-3.7.0.dist-info → dragon_ml_toolbox-3.9.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-3.7.0.dist-info → dragon_ml_toolbox-3.9.0.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: dragon-ml-toolbox
|
|
3
|
-
Version: 3.7.0
|
|
3
|
+
Version: 3.9.0
|
|
4
4
|
Summary: A collection of tools for data science and machine learning projects.
|
|
5
5
|
Author-email: Karl Loza <luigiloza@gmail.com>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -15,8 +15,8 @@ License-File: LICENSE-THIRD-PARTY.md
|
|
|
15
15
|
Requires-Dist: numpy<2.0
|
|
16
16
|
Requires-Dist: scikit-learn
|
|
17
17
|
Requires-Dist: openpyxl
|
|
18
|
-
Requires-Dist: miceforest
|
|
19
|
-
Requires-Dist: plotnine
|
|
18
|
+
Requires-Dist: miceforest>=6.0.0
|
|
19
|
+
Requires-Dist: plotnine>=0.12
|
|
20
20
|
Requires-Dist: matplotlib
|
|
21
21
|
Requires-Dist: seaborn
|
|
22
22
|
Requires-Dist: pandas
|
|
@@ -129,6 +129,7 @@ ML_callbacks
|
|
|
129
129
|
ML_evaluation
|
|
130
130
|
ML_trainer
|
|
131
131
|
ML_tutorial
|
|
132
|
+
path_manager
|
|
132
133
|
PSO_optimization
|
|
133
134
|
RNN_forecast
|
|
134
135
|
utilities
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
dragon_ml_toolbox-3.
|
|
2
|
-
dragon_ml_toolbox-3.
|
|
1
|
+
dragon_ml_toolbox-3.9.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
|
|
2
|
+
dragon_ml_toolbox-3.9.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=6cfpIeQ6D4Mcs10nkogQrkVyq1T7i2qXjjNHFoUMOyE,1892
|
|
3
3
|
ml_tools/ETL_engineering.py,sha256=yeZsW_7zRvEcuMZbM4E2GV1dxwBoWIeJAcFFk2AK0fY,39502
|
|
4
|
-
ml_tools/GUI_tools.py,sha256=
|
|
4
|
+
ml_tools/GUI_tools.py,sha256=ABR1cqV09iZ2DbLfLZB7jaQVRVDbvCmj09pNkr3TDZk,18800
|
|
5
5
|
ml_tools/MICE_imputation.py,sha256=rYqvwQDVtoAJJ0agXWoGzoZEHedWiA6QzcEKEIkiZ08,11388
|
|
6
6
|
ml_tools/ML_callbacks.py,sha256=OT2zwORLcn49megBEgXsSUxDHoW0Ft0_v7hLEVF3jHM,13063
|
|
7
7
|
ml_tools/ML_evaluation.py,sha256=oiDV6HItQloUUKCUpltV-2pogubWLBieGpc-VUwosAQ,10106
|
|
@@ -15,11 +15,12 @@ ml_tools/_particle_swarm_optimization.py,sha256=b_eNNkA89Y40hj76KauivT8KLScH1B9w
|
|
|
15
15
|
ml_tools/_pytorch_models.py,sha256=bpWZsrSwCvHJQkR6UfoPpElsMv9AvmiNErNHC8NYB_I,10132
|
|
16
16
|
ml_tools/data_exploration.py,sha256=M7bn2q5XN9zJZJGAmMMFSFFZh8LGzC2arFelrXw3N6Q,25241
|
|
17
17
|
ml_tools/datasetmaster.py,sha256=S3PKHNQZ9cyAOck8xQltVLZhaD1gFLfgHFL-aRjz4JU,30077
|
|
18
|
-
ml_tools/ensemble_learning.py,sha256=
|
|
18
|
+
ml_tools/ensemble_learning.py,sha256=p9PZwGY2OGSrJhXNzvMS_kCjK-I2JVcqiJBaVzb0GrM,42616
|
|
19
19
|
ml_tools/handle_excel.py,sha256=lwds7rDLlGSCWiWGI7xNg-Z7kxAepogp0lstSFa0590,12949
|
|
20
20
|
ml_tools/logger.py,sha256=UkbiU9ihBhw9VKyn3rZzisdClWV94EBV6B09_D0iUU0,6026
|
|
21
|
-
ml_tools/
|
|
22
|
-
|
|
23
|
-
dragon_ml_toolbox-3.
|
|
24
|
-
dragon_ml_toolbox-3.
|
|
25
|
-
dragon_ml_toolbox-3.
|
|
21
|
+
ml_tools/path_manager.py,sha256=OCpESgdftbi6mOxetDMIaHhazt4N-W8pJx11X3-yNOs,8305
|
|
22
|
+
ml_tools/utilities.py,sha256=HR36Q_vYnaRcpSjpNISnA7lOZ36TouHop38lPLG_twY,23146
|
|
23
|
+
dragon_ml_toolbox-3.9.0.dist-info/METADATA,sha256=2R3xIuefuR9O_h71q3S49xUm2MLKQtn12jjwNFKl2mE,3273
|
|
24
|
+
dragon_ml_toolbox-3.9.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
25
|
+
dragon_ml_toolbox-3.9.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
|
|
26
|
+
dragon_ml_toolbox-3.9.0.dist-info/RECORD,,
|
ml_tools/GUI_tools.py
CHANGED
|
@@ -4,83 +4,21 @@ from typing import Optional, Callable, Any
|
|
|
4
4
|
import traceback
|
|
5
5
|
import FreeSimpleGUI as sg
|
|
6
6
|
from functools import wraps
|
|
7
|
-
from typing import Any, Dict, Tuple, List
|
|
7
|
+
from typing import Any, Dict, Tuple, List, Literal
|
|
8
8
|
from .utilities import _script_info
|
|
9
9
|
import numpy as np
|
|
10
10
|
from .logger import _LOGGER
|
|
11
|
+
from abc import ABC, abstractmethod
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
__all__ = [
|
|
14
|
-
"PathManager",
|
|
15
15
|
"ConfigManager",
|
|
16
16
|
"GUIFactory",
|
|
17
17
|
"catch_exceptions",
|
|
18
|
-
"
|
|
18
|
+
"BaseFeatureHandler",
|
|
19
19
|
"update_target_fields"
|
|
20
20
|
]
|
|
21
21
|
|
|
22
|
-
|
|
23
|
-
# --- Path Management ---
|
|
24
|
-
class PathManager:
|
|
25
|
-
"""
|
|
26
|
-
Manages paths for a Python application, supporting both development mode and bundled mode via Briefcase.
|
|
27
|
-
"""
|
|
28
|
-
def __init__(self, anchor_file: str):
|
|
29
|
-
"""
|
|
30
|
-
Initializes the PathManager. The package name is automatically inferred
|
|
31
|
-
from the parent directory of the anchor file.
|
|
32
|
-
|
|
33
|
-
Args:
|
|
34
|
-
anchor_file (str): The absolute path to a file within the project's
|
|
35
|
-
package, typically `__file__` from a module inside
|
|
36
|
-
that package (paths.py).
|
|
37
|
-
|
|
38
|
-
Note:
|
|
39
|
-
This inference assumes that the anchor file's parent directory
|
|
40
|
-
has the same name as the package (e.g., `.../src/my_app/paths.py`).
|
|
41
|
-
This is a standard and recommended project structure.
|
|
42
|
-
"""
|
|
43
|
-
resolved_anchor_path = Path(anchor_file).resolve()
|
|
44
|
-
self.package_name = resolved_anchor_path.parent.name
|
|
45
|
-
self._is_bundled, self._resource_path_func = self._check_bundle_status()
|
|
46
|
-
|
|
47
|
-
if self._is_bundled:
|
|
48
|
-
# In a Briefcase bundle, resource_path gives an absolute path
|
|
49
|
-
# to the resource directory.
|
|
50
|
-
self.package_root = self._resource_path_func(self.package_name, "") # type: ignore
|
|
51
|
-
else:
|
|
52
|
-
# In development mode, the package root is the directory
|
|
53
|
-
# containing the anchor file.
|
|
54
|
-
self.package_root = resolved_anchor_path.parent
|
|
55
|
-
|
|
56
|
-
def _check_bundle_status(self) -> tuple[bool, Optional[Callable]]:
|
|
57
|
-
"""Checks if the app is running in a bundled environment."""
|
|
58
|
-
try:
|
|
59
|
-
# This is the function Briefcase provides in a bundled app
|
|
60
|
-
from briefcase.platforms.base import resource_path # type: ignore
|
|
61
|
-
return True, resource_path
|
|
62
|
-
except ImportError:
|
|
63
|
-
return False, None
|
|
64
|
-
|
|
65
|
-
def get_path(self, relative_path: str | Path) -> Path:
|
|
66
|
-
"""
|
|
67
|
-
Gets the absolute path for a given resource file or directory
|
|
68
|
-
relative to the package root.
|
|
69
|
-
|
|
70
|
-
Args:
|
|
71
|
-
relative_path (str | Path): The path relative to the package root (e.g., 'helpers/icon.png').
|
|
72
|
-
|
|
73
|
-
Returns:
|
|
74
|
-
Path: The absolute path to the resource.
|
|
75
|
-
"""
|
|
76
|
-
if self._is_bundled:
|
|
77
|
-
# Briefcase's resource_path handles resolving the path within the app bundle
|
|
78
|
-
return self._resource_path_func(self.package_name, str(relative_path)) # type: ignore
|
|
79
|
-
else:
|
|
80
|
-
# In dev mode, join package root with the relative path.
|
|
81
|
-
return self.package_root / relative_path
|
|
82
|
-
|
|
83
|
-
|
|
84
22
|
# --- Configuration Management ---
|
|
85
23
|
class _SectionProxy:
|
|
86
24
|
"""A helper class to represent a section of the .ini file as an object."""
|
|
@@ -273,8 +211,8 @@ class GUIFactory:
|
|
|
273
211
|
self,
|
|
274
212
|
data_dict: Dict[str, Tuple[float, float]],
|
|
275
213
|
is_target: bool = False,
|
|
276
|
-
layout_mode:
|
|
277
|
-
|
|
214
|
+
layout_mode: Literal["grid", "row"] = 'grid',
|
|
215
|
+
features_per_column: int = 4
|
|
278
216
|
) -> List[List[sg.Column]]:
|
|
279
217
|
"""
|
|
280
218
|
Generates a layout for continuous features or targets.
|
|
@@ -283,7 +221,7 @@ class GUIFactory:
|
|
|
283
221
|
data_dict (dict): Keys are feature names, values are (min, max) tuples.
|
|
284
222
|
is_target (bool): If True, creates disabled inputs for displaying results.
|
|
285
223
|
layout_mode (str): 'grid' for a multi-row grid layout, or 'row' for a single horizontal row.
|
|
286
|
-
|
|
224
|
+
features_per_column (int): Number of features placed in each row of the grid (i.e., the number of grid columns) when `layout_mode` is 'grid'.
|
|
287
225
|
|
|
288
226
|
Returns:
|
|
289
227
|
A list of lists of sg.Column elements, ready to be used in a window layout.
|
|
@@ -294,7 +232,7 @@ class GUIFactory:
|
|
|
294
232
|
|
|
295
233
|
columns = []
|
|
296
234
|
for name, (val_min, val_max) in data_dict.items():
|
|
297
|
-
key =
|
|
235
|
+
key = name
|
|
298
236
|
default_text = "" if is_target else str(val_max)
|
|
299
237
|
|
|
300
238
|
label = sg.Text(name, font=label_font, background_color=bg_color, key=f"_text_{name}")
|
|
@@ -313,6 +251,7 @@ class GUIFactory:
|
|
|
313
251
|
range_text = sg.Text(f"Range: {int(val_min)}-{int(val_max)}", font=range_font, background_color=bg_color)
|
|
314
252
|
layout = [[label], [element], [range_text]]
|
|
315
253
|
|
|
254
|
+
# each feature is wrapped as a column element
|
|
316
255
|
layout.append([sg.Text(" ", font=(cfg.fonts.font_family, 2), background_color=bg_color)]) # type: ignore
|
|
317
256
|
columns.append(sg.Column(layout, background_color=bg_color))
|
|
318
257
|
|
|
@@ -320,13 +259,13 @@ class GUIFactory:
|
|
|
320
259
|
return [columns] # A single row containing all columns
|
|
321
260
|
|
|
322
261
|
# Default to 'grid' layout
|
|
323
|
-
return [columns[i:i +
|
|
262
|
+
return [columns[i:i + features_per_column] for i in range(0, len(columns), features_per_column)]
|
|
324
263
|
|
|
325
264
|
def generate_combo_layout(
|
|
326
265
|
self,
|
|
327
266
|
data_dict: Dict[str, List[Any]],
|
|
328
|
-
layout_mode:
|
|
329
|
-
|
|
267
|
+
layout_mode: Literal["grid", "row"] = 'grid',
|
|
268
|
+
features_per_column: int = 4
|
|
330
269
|
) -> List[List[sg.Column]]:
|
|
331
270
|
"""
|
|
332
271
|
Generates a layout for categorical or binary features using Combo boxes.
|
|
@@ -334,7 +273,7 @@ class GUIFactory:
|
|
|
334
273
|
Args:
|
|
335
274
|
data_dict (dict): Keys are feature names, values are lists of options.
|
|
336
275
|
layout_mode (str): 'grid' for a multi-row grid layout, or 'row' for a single horizontal row.
|
|
337
|
-
|
|
276
|
+
features_per_column (int): Number of features placed in each row of the grid (i.e., the number of grid columns) when `layout_mode` is 'grid'.
|
|
338
277
|
|
|
339
278
|
Returns:
|
|
340
279
|
A list of lists of sg.Column elements, ready to be used in a window layout.
|
|
@@ -352,13 +291,14 @@ class GUIFactory:
|
|
|
352
291
|
)
|
|
353
292
|
layout = [[label], [element]]
|
|
354
293
|
layout.append([sg.Text(" ", font=(cfg.fonts.font_family, 2), background_color=bg_color)]) # type: ignore
|
|
294
|
+
# each feature is wrapped in a Column element
|
|
355
295
|
columns.append(sg.Column(layout, background_color=bg_color))
|
|
356
296
|
|
|
357
297
|
if layout_mode == 'row':
|
|
358
298
|
return [columns] # A single row containing all columns
|
|
359
299
|
|
|
360
300
|
# Default to 'grid' layout
|
|
361
|
-
return [columns[i:i +
|
|
301
|
+
return [columns[i:i + features_per_column] for i in range(0, len(columns), features_per_column)]
|
|
362
302
|
|
|
363
303
|
# --- Window Creation ---
|
|
364
304
|
def create_window(self, title: str, layout: List[List[sg.Element]], **kwargs) -> sg.Window:
|
|
@@ -412,68 +352,93 @@ def catch_exceptions(show_popup: bool = True):
|
|
|
412
352
|
return decorator
|
|
413
353
|
|
|
414
354
|
|
|
415
|
-
# --- Inference
|
|
416
|
-
|
|
417
|
-
"""
|
|
418
|
-
Default processor for binary 'True'/'False' strings.
|
|
419
|
-
Returns a list containing a single float.
|
|
420
|
-
"""
|
|
421
|
-
return [1.0] if str(chosen_value) == 'True' else [0.0]
|
|
422
|
-
|
|
423
|
-
def prepare_feature_vector(
|
|
424
|
-
values: Dict[str, Any],
|
|
425
|
-
feature_order: List[str],
|
|
426
|
-
continuous_features: List[str],
|
|
427
|
-
categorical_features: List[str],
|
|
428
|
-
categorical_processor: Optional[Callable[[str, Any], List[float]]] = None
|
|
429
|
-
) -> np.ndarray:
|
|
355
|
+
# --- Inference Helper ---
|
|
356
|
+
class BaseFeatureHandler(ABC):
    """
    An abstract base class that defines the template for preparing a model input
    feature vector from GUI inputs, ready for inference.

    A subclass must implement the `gui_input_map` property and the `process_categorical` method.
    Calling an instance with the GUI values dictionary returns a 1D float32 array whose
    entries follow the feature order given at construction time.
    """
    def __init__(self, expected_columns_in_order: list[str]):
        """
        Validates and stores the feature names in the order the model expects.

        Args:
            expected_columns_in_order (List[str]): A list of strings with the feature names in the correct order.

        Raises:
            TypeError: If the input is not a list, or if any element is not a string.
        """
        # --- Validation Logic ---
        if not isinstance(expected_columns_in_order, list):
            raise TypeError("Input 'expected_columns_in_order' must be a list.")

        if not all(isinstance(col, str) for col in expected_columns_in_order):
            raise TypeError("All elements in the 'expected_columns_in_order' list must be strings.")
        # -----------------------

        # Store a copy so later mutation of the caller's list cannot silently
        # change the order of the vector handed to the model.
        self._model_feature_order = list(expected_columns_in_order)

    @property
    @abstractmethod
    def gui_input_map(self) -> Dict[str, Literal["continuous","categorical"]]:
        """
        Must be implemented by the subclass.

        Should return a dictionary mapping each GUI input name to its type ('continuous' or 'categorical').

        ```python
        #Example:
        {'temperature': 'continuous', 'material_type': 'categorical'}
        ```
        """
        pass

    @abstractmethod
    def process_categorical(self, feature_name: str, chosen_value: Any) -> Dict[str, float]:
        """
        Must be implemented by the subclass.

        Should take a GUI categorical feature name and its chosen value, and return a dictionary mapping the one-hot-encoded feature names to their
        float values (as expected by the inference model).
        """
        pass

    def __call__(self, window_values: Dict[str, Any]) -> np.ndarray:
        """
        Performs the full vector preparation, returning a 1D numpy array.

        Should not be overridden by subclasses.

        Args:
            window_values (Dict[str, Any]): Mapping of GUI input names to the values entered by the user.

        Returns:
            np.ndarray: A 1D float32 array ordered like the feature names given at construction.

        Raises:
            ValueError: If a GUI input is missing or empty, a continuous value cannot be
                converted to float, or `gui_input_map` declares an unknown feature type.
            RuntimeError: If the processed inputs do not cover every feature the model expects.
        """
        # Stage 1: Process GUI inputs into a dictionary
        processed_features: Dict[str, float] = {}
        for gui_name, feature_type in self.gui_input_map.items():
            chosen_value = window_values.get(gui_name)

            if chosen_value is None or str(chosen_value) == '':
                raise ValueError(f"GUI input '{gui_name}' is missing a value.")

            if feature_type == 'continuous':
                try:
                    processed_features[gui_name] = float(chosen_value)
                except (ValueError, TypeError) as e:
                    raise ValueError(f"Invalid number '{chosen_value}' for '{gui_name}'.") from e

            elif feature_type == 'categorical':
                feature_dict = self.process_categorical(gui_name, chosen_value)
                processed_features.update(feature_dict)

            else:
                # Fail here with a precise message instead of dropping the input
                # and surfacing later as a vague "missing feature" error.
                raise ValueError(
                    f"Unknown feature type '{feature_type}' for GUI input '{gui_name}'. "
                    f"Expected 'continuous' or 'categorical'."
                )

        # Stage 2: Assemble the final vector using the model's required order
        missing_features = [name for name in self._model_feature_order if name not in processed_features]
        if missing_features:
            missing_text = ", ".join(f"'{name}'" for name in missing_features)
            raise RuntimeError(
                f"Configuration Error: Implemented methods failed to generate "
                f"the required model feature(s): {missing_text}. "
                f"Check the gui_input_map and process_categorical logic."
            )

        final_vector: List[float] = [processed_features[name] for name in self._model_feature_order]

        return np.array(final_vector, dtype=np.float32)
|
|
477
442
|
|
|
478
443
|
|
|
479
444
|
def update_target_fields(window: sg.Window, results_dict: Dict[str, Any]):
|
|
@@ -482,12 +447,12 @@ def update_target_fields(window: sg.Window, results_dict: Dict[str, Any]):
|
|
|
482
447
|
|
|
483
448
|
Args:
|
|
484
449
|
window (sg.Window): The application's window object.
|
|
485
|
-
results_dict (dict): A dictionary where keys are target
|
|
450
|
+
results_dict (dict): A dictionary where keys are target element-keys and values are the predicted results to update.
|
|
486
451
|
"""
|
|
487
452
|
for target_name, result in results_dict.items():
|
|
488
453
|
# Format numbers to 2 decimal places, leave other types as-is
|
|
489
454
|
display_value = f"{result:.2f}" if isinstance(result, (int, float)) else result
|
|
490
|
-
window[target_name].update(display_value)
|
|
455
|
+
window[target_name].update(display_value) # type: ignore
|
|
491
456
|
|
|
492
457
|
|
|
493
458
|
def info():
|
ml_tools/ensemble_learning.py
CHANGED
|
@@ -6,7 +6,7 @@ from matplotlib.colors import Colormap
|
|
|
6
6
|
from matplotlib import rcdefaults
|
|
7
7
|
|
|
8
8
|
from pathlib import Path
|
|
9
|
-
from typing import Literal, Union, Optional, Iterator, Tuple
|
|
9
|
+
from typing import Literal, Union, Optional, Iterator, Tuple, Dict, Any, List
|
|
10
10
|
|
|
11
11
|
from imblearn.over_sampling import ADASYN, SMOTE, RandomOverSampler
|
|
12
12
|
from imblearn.under_sampling import RandomUnderSampler
|
|
@@ -19,7 +19,7 @@ from sklearn.model_selection import train_test_split
|
|
|
19
19
|
from sklearn.metrics import accuracy_score, classification_report, ConfusionMatrixDisplay, mean_absolute_error, mean_squared_error, r2_score, roc_curve, roc_auc_score
|
|
20
20
|
import shap
|
|
21
21
|
|
|
22
|
-
from .utilities import yield_dataframes_from_dir, sanitize_filename, _script_info, serialize_object, make_fullpath
|
|
22
|
+
from .utilities import yield_dataframes_from_dir, sanitize_filename, _script_info, serialize_object, make_fullpath, list_files_by_extension, deserialize_object
|
|
23
23
|
from .logger import _LOGGER
|
|
24
24
|
|
|
25
25
|
import warnings # Ignore warnings
|
|
@@ -38,7 +38,8 @@ __all__ = [
|
|
|
38
38
|
"evaluate_model_regression",
|
|
39
39
|
"get_shap_values",
|
|
40
40
|
"train_test_pipeline",
|
|
41
|
-
"run_ensemble_pipeline"
|
|
41
|
+
"run_ensemble_pipeline",
|
|
42
|
+
"InferenceHandler"
|
|
42
43
|
]
|
|
43
44
|
|
|
44
45
|
## Type aliases
|
|
@@ -937,5 +938,124 @@ def run_ensemble_pipeline(datasets_dir: Union[str,Path], save_dir: Union[str,Pat
|
|
|
937
938
|
_LOGGER.info("✅ Training and evaluation complete.")
|
|
938
939
|
|
|
939
940
|
|
|
941
|
+
###### 6. Inference ######
|
|
942
|
+
class InferenceHandler:
    """
    Handles loading ensemble models and performing inference for either regression or classification tasks.

    Each `.joblib` file is expected to deserialize to a dictionary with the keys
    "model", "target_name" and "feature_names". Models are stored per target name.
    """
    def __init__(self,
                 models_dir: Union[str,Path],
                 task: TaskType,
                 verbose: bool=True) -> None:
        """
        Initializes the handler by loading all models from a directory.

        Files that fail to load or parse are logged and skipped, so one bad file
        does not prevent the remaining models from being used.

        Args:
            models_dir (str | Path): The directory containing the saved .joblib model files.
            task ("regression" | "classification"): The type of task the models perform.
            verbose (bool): If True, logs progress messages while loading and predicting.

        Raises:
            ValueError: If `task` is neither "regression" nor "classification".
        """
        # Validate up front: the predict methods treat anything other than
        # "regression" as classification, so a typo would otherwise only show up
        # later as a confusing failure inside `predict_proba`.
        if task not in ("regression", "classification"):
            raise ValueError(f"Invalid task '{task}'. Expected 'regression' or 'classification'.")

        self.models: Dict[str, Any] = dict()
        self.task: str = task
        self.verbose = verbose
        self._feature_names: Optional[List[str]] = None

        model_files = list_files_by_extension(directory=models_dir, extension="joblib")

        for fname, fpath in model_files.items():
            try:
                full_object: dict
                full_object = deserialize_object(filepath=fpath,
                                                 verbose=self.verbose,
                                                 raise_on_error=True) # type: ignore

                model: Any = full_object["model"]
                target_name: str = full_object["target_name"]
                feature_names_list: List[str] = full_object["feature_names"]

                # Check that feature names match
                if self._feature_names is None:
                    # Store the feature names from the first model loaded.
                    self._feature_names = feature_names_list
                elif self._feature_names != feature_names_list:
                    # Add a warning if subsequent models have different feature names.
                    _LOGGER.warning(f"⚠️ Mismatched feature names in {fname}. Using feature order from the first model loaded.")

                self.models[target_name] = model
                if self.verbose:
                    _LOGGER.info(f"✅ Loaded model for target: {target_name}")

            except Exception as e:
                # Deliberate best-effort: report the problem and keep loading the rest.
                _LOGGER.warning(f"⚠️ Failed to load or parse {fname}: {e}")

        if not self.models:
            # Otherwise every prediction would silently return an empty dictionary.
            _LOGGER.warning(f"⚠️ No models were loaded from '{models_dir}'.")

    @property
    def feature_names(self) -> List[str]:
        """
        Getter for the list of feature names the models expect.
        Returns an empty list if no models were loaded.
        """
        # Return a copy so callers cannot mutate the handler's internal list.
        return list(self._feature_names) if self._feature_names is not None else []

    def predict(self, features: np.ndarray) -> Dict[str, Any]:
        """
        Predicts on a single feature vector.

        Args:
            features (np.ndarray): A 1D array, or a 2D array with exactly one row, holding a single sample.

        Returns:
            Dict[str, Any]: A dictionary where keys are target names.
                - For regression: The value is the single predicted float.
                - For classification: The value is another dictionary {'label': ..., 'probabilities': ...}.

        Raises:
            ValueError: If the input is not a single sample (wrong number of dimensions or more than one row).
        """
        features = np.asarray(features)
        if features.ndim == 1:
            features = features.reshape(1, -1)

        # Reject 0-D and >2-D input explicitly instead of failing on `shape[0]`
        # or inside the model.
        if features.ndim != 2 or features.shape[0] != 1:
            raise ValueError("The predict() method is for a single sample. Use predict_batch() for multiple samples.")

        results: Dict[str, Any] = dict()
        for target_name, model in self.models.items():
            if self.task == "regression":
                prediction = model.predict(features)
                results[target_name] = prediction.item()
            else: # Classification
                label = model.predict(features)[0]
                probabilities = model.predict_proba(features)[0]
                results[target_name] = {"label": label, "probabilities": probabilities}

        if self.verbose:
            _LOGGER.info("✅ Inference process complete.")
        return results

    def predict_batch(self, features: np.ndarray) -> Dict[str, Any]:
        """
        Predicts on a batch of feature vectors.

        Args:
            features (np.ndarray): A 2D NumPy array where each row is a sample.

        Returns:
            Dict[str, Any]: A dictionary where keys are target names.
                - For regression: The value is a NumPy array of predictions.
                - For classification: The value is another dictionary {'labels': ..., 'probabilities': ...}.

        Raises:
            ValueError: If the input is not a 2D array.
        """
        features = np.asarray(features)
        if features.ndim != 2:
            raise ValueError("Input for batch prediction must be a 2D array.")

        results: Dict[str, Any] = dict()
        for target_name, model in self.models.items():
            if self.task == "regression":
                results[target_name] = model.predict(features)
            else: # Classification
                labels = model.predict(features)
                probabilities = model.predict_proba(features)
                results[target_name] = {"labels": labels, "probabilities": probabilities}

        if self.verbose:
            _LOGGER.info("✅ Inference process complete.")

        return results
|
|
1058
|
+
|
|
1059
|
+
|
|
940
1060
|
def info():
    """
    Passes this module's `__all__` list to `_script_info` (imported from `.utilities`).

    NOTE(review): presumably this reports the module's public names; the helper's
    exact output is not visible here - confirm against `.utilities`.
    """
    _script_info(__all__)
|
ml_tools/path_manager.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
from pprint import pprint
|
|
2
|
+
from typing import Optional, List, Dict, Callable, Union
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from .utilities import _script_info
|
|
5
|
+
from .logger import _LOGGER
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
__all__ = [
|
|
9
|
+
"PathManager"
|
|
10
|
+
]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class PathManager:
    """
    Manages and stores a project's file paths, acting as a centralized
    "path database". It supports both development mode and applications
    bundled with Briefcase.

    Supports python dictionary syntax: item access, item assignment (without
    overwriting), `in`, `len()`, iteration over keys, and `keys()` / `values()` / `items()`.
    """
    def __init__(
        self,
        anchor_file: str,
        base_directories: Optional[List[str]] = None
    ):
        """
        The initializer determines the project's root directory and can pre-register
        a list of base directories relative to that root.

        The package root is always registered under the key "ROOT".

        Args:
            anchor_file (str): The absolute path to a file whose parent directory will be considered the package root and name. Typically, `__file__`.
            base_directories (Optional[List[str]]): A list of directory names located at the same level as the anchor file to be registered immediately.
        """
        resolved_anchor_path = Path(anchor_file).resolve()
        self._package_name = resolved_anchor_path.parent.name
        self._is_bundled, self._resource_path_func = self._check_bundle_status()
        self._paths: Dict[str, Path] = {}

        if self._is_bundled:
            # In a bundle, the package root is obtained from the resource function.
            # Coerce to Path so `.suffix` / `.is_relative_to` work later.
            # NOTE(review): assumes resource_path returns a str or path-like - confirm.
            package_root = Path(self._resource_path_func(self._package_name)) # type: ignore
        else:
            # In dev mode, the package root is the directory containing the anchor file.
            package_root = resolved_anchor_path.parent

        # Register the root of the package itself
        self._paths["ROOT"] = package_root

        # Register all the base directories
        for dir_name in base_directories or []:
            # In dev mode, this is simple. In a bundle, we must resolve
            # each path from the package root.
            if self._is_bundled:
                self._paths[dir_name] = Path(self._resource_path_func(self._package_name, dir_name)) # type: ignore
            else:
                self._paths[dir_name] = package_root / dir_name

    # A helper function to find the briefcase resource function
    def _check_bundle_status(self) -> tuple[bool, Optional[Callable]]:
        """
        Checks if the app is running in a Briefcase bundle.

        Returns:
            tuple[bool, Optional[Callable]]: `(True, resource_path)` when the import
            succeeds, otherwise `(False, None)`.
        """
        try:
            # Bundle detection relies on this import succeeding.
            from briefcase.platforms.base import resource_path # type: ignore
            return True, resource_path
        except (ImportError, NameError):
            return False, None

    def get(self, key: str) -> Path:
        """
        Retrieves a stored path by its key.

        Args:
            key (str): The key of the path to retrieve.

        Returns:
            Path: The Path object stored under `key`, exactly as it was registered.

        Raises:
            KeyError: If the key is not found in the manager.
        """
        try:
            return self._paths[key]
        except KeyError:
            _LOGGER.error(f"❌ Path key '{key}' not found.")
            raise

    def update(self, new_paths: Dict[str, Union[str, Path]], overwrite: bool = False) -> None:
        """
        Adds new paths or overwrites existing ones in the manager.

        Args:
            new_paths (Dict[str, Union[str, Path]]): A dictionary where keys are
                                    the identifiers and values are the
                                    Path objects or strings to store.
            overwrite (bool): If False (default), raises a KeyError if any
                            key in new_paths already exists. If True,
                            allows overwriting existing keys.

        Raises:
            KeyError: If `overwrite` is False and a key already exists. Nothing is stored in that case.
        """
        if not overwrite:
            for key in new_paths:
                if key in self._paths:
                    raise KeyError(
                        f"Path key '{key}' already exists in the manager. To replace it, call update() with overwrite=True."
                    )

        # Convert any string paths to Path objects before storing
        # (paths are stored as given, not made absolute).
        converted_paths = {k: Path(v) for k, v in new_paths.items()}
        self._paths.update(converted_paths)

    def make_dirs(self, keys: Optional[List[str]] = None, verbose: bool = False) -> None:
        """
        Creates directory structures for registered paths in writable locations.

        This method treats paths without a file suffix as directories and creates them on the filesystem.
        Paths with a suffix are assumed to be files and skipped, so a directory whose
        name contains a dot (e.g. 'v1.2') is skipped as well.

        In a bundled application, this method will NOT attempt to create directories inside the read-only app package, preventing crashes. It
        will only operate on paths outside of the package (e.g., user data dirs).

        Args:
            keys (Optional[List[str]]): If provided, only the directories
                                        corresponding to these keys will be
                                        created; an empty list creates nothing.
                                        If None (default), all
                                        registered directory paths are used.
            verbose (bool): If True, logs a message for each action.
        """
        if keys is not None:
            # An explicit (possibly empty) selection must never fall back to "all paths".
            path_items = []
            for key in keys:
                if key in self._paths:
                    path_items.append((key, self._paths[key]))
                elif verbose:
                    _LOGGER.warning(f"⚠️ Key '{key}' not found in PathManager, skipping.")
        else:
            path_items = list(self._paths.items())

        # Get the package root to check against.
        package_root = self._paths.get("ROOT")

        for key, path in path_items:
            if path.suffix: # It's a file, not a directory
                continue

            # --- THE CRITICAL CHECK ---
            # Determine if the path is inside the main application package.
            is_internal_path = package_root is not None and path.is_relative_to(package_root)

            if self._is_bundled and is_internal_path:
                if verbose:
                    _LOGGER.warning(f"⚠️ Skipping internal directory '{key}' in bundled app (read-only).")
                continue
            # -------------------------

            if verbose:
                _LOGGER.info(f"📁 Ensuring directory exists for key '{key}': {path}")

            path.mkdir(parents=True, exist_ok=True)

    def status(self) -> None:
        """
        Checks the status of all registered paths on the filesystem and prints a formatted report.
        """
        report = {}
        for key, path in self.items():
            if path.is_dir():
                report[key] = "📁 Directory"
            elif path.is_file():
                report[key] = "📄 File"
            else:
                report[key] = "❌ Not Found"

        print("\n--- Path Status Report ---")
        pprint(report)

    def __repr__(self) -> str:
        """Provides a string representation of the stored paths."""
        path_list = "\n".join(f"  '{k}': '{v}'" for k, v in self._paths.items())
        return f"PathManager(\n{path_list}\n)"

    # --- Dictionary-Style Methods ---
    def __getitem__(self, key: str) -> Path:
        """Allows dictionary-style getting, e.g., PM['my_key']"""
        return self.get(key)

    def __setitem__(self, key: str, value: Union[str, Path]):
        """Allows dictionary-style setting, does not allow overwriting, e.g., PM['my_key'] = path"""
        self.update({key: value}, overwrite=False)

    def __contains__(self, key: str) -> bool:
        """Allows checking for a key's existence, e.g., if 'my_key' in PM"""
        return key in self._paths

    def __len__(self) -> int:
        """Allows getting the number of paths, e.g., len(PM)"""
        return len(self._paths)

    def __iter__(self):
        """Allows iterating over the registered keys, e.g., for key in PM"""
        # Without this, iteration falls back to __getitem__(0) and raises KeyError.
        return iter(self._paths)

    def keys(self):
        """Returns all registered path keys."""
        return self._paths.keys()

    def values(self):
        """Returns all registered Path objects."""
        return self._paths.values()

    def items(self):
        """Returns all registered (key, Path) pairs."""
        return self._paths.items()
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def info():
    """
    Passes this module's `__all__` list to `_script_info` (imported from `.utilities`).

    NOTE(review): presumably this reports the module's public names; the helper's
    exact output is not visible here - confirm against `.utilities`.
    """
    _script_info(__all__)
|
ml_tools/utilities.py
CHANGED
|
@@ -25,7 +25,7 @@ __all__ = [
|
|
|
25
25
|
"serialize_object",
|
|
26
26
|
"deserialize_object",
|
|
27
27
|
"distribute_datasets_by_target",
|
|
28
|
-
"train_dataset_orchestrator"
|
|
28
|
+
"train_dataset_orchestrator",
|
|
29
29
|
]
|
|
30
30
|
|
|
31
31
|
|
|
@@ -645,7 +645,7 @@ def train_dataset_orchestrator(list_of_dirs: list[Union[str,Path]],
|
|
|
645
645
|
|
|
646
646
|
class LogKeys:
|
|
647
647
|
"""
|
|
648
|
-
Used for ML scripts
|
|
648
|
+
Used internally for ML scripts.
|
|
649
649
|
|
|
650
650
|
Centralized keys for logging and history.
|
|
651
651
|
"""
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|