dragon-ml-toolbox 3.8.0__py3-none-any.whl → 3.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic. Click here for more details.
- {dragon_ml_toolbox-3.8.0.dist-info → dragon_ml_toolbox-3.9.0.dist-info}/METADATA +4 -3
- {dragon_ml_toolbox-3.8.0.dist-info → dragon_ml_toolbox-3.9.0.dist-info}/RECORD +10 -9
- ml_tools/GUI_tools.py +82 -56
- ml_tools/ensemble_learning.py +123 -3
- ml_tools/path_manager.py +212 -0
- ml_tools/utilities.py +1 -205
- {dragon_ml_toolbox-3.8.0.dist-info → dragon_ml_toolbox-3.9.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-3.8.0.dist-info → dragon_ml_toolbox-3.9.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-3.8.0.dist-info → dragon_ml_toolbox-3.9.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-3.8.0.dist-info → dragon_ml_toolbox-3.9.0.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: dragon-ml-toolbox
|
|
3
|
-
Version: 3.
|
|
3
|
+
Version: 3.9.0
|
|
4
4
|
Summary: A collection of tools for data science and machine learning projects.
|
|
5
5
|
Author-email: Karl Loza <luigiloza@gmail.com>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -15,8 +15,8 @@ License-File: LICENSE-THIRD-PARTY.md
|
|
|
15
15
|
Requires-Dist: numpy<2.0
|
|
16
16
|
Requires-Dist: scikit-learn
|
|
17
17
|
Requires-Dist: openpyxl
|
|
18
|
-
Requires-Dist: miceforest
|
|
19
|
-
Requires-Dist: plotnine
|
|
18
|
+
Requires-Dist: miceforest>=6.0.0
|
|
19
|
+
Requires-Dist: plotnine>=0.12
|
|
20
20
|
Requires-Dist: matplotlib
|
|
21
21
|
Requires-Dist: seaborn
|
|
22
22
|
Requires-Dist: pandas
|
|
@@ -129,6 +129,7 @@ ML_callbacks
|
|
|
129
129
|
ML_evaluation
|
|
130
130
|
ML_trainer
|
|
131
131
|
ML_tutorial
|
|
132
|
+
path_manager
|
|
132
133
|
PSO_optimization
|
|
133
134
|
RNN_forecast
|
|
134
135
|
utilities
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
dragon_ml_toolbox-3.
|
|
2
|
-
dragon_ml_toolbox-3.
|
|
1
|
+
dragon_ml_toolbox-3.9.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
|
|
2
|
+
dragon_ml_toolbox-3.9.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=6cfpIeQ6D4Mcs10nkogQrkVyq1T7i2qXjjNHFoUMOyE,1892
|
|
3
3
|
ml_tools/ETL_engineering.py,sha256=yeZsW_7zRvEcuMZbM4E2GV1dxwBoWIeJAcFFk2AK0fY,39502
|
|
4
|
-
ml_tools/GUI_tools.py,sha256=
|
|
4
|
+
ml_tools/GUI_tools.py,sha256=ABR1cqV09iZ2DbLfLZB7jaQVRVDbvCmj09pNkr3TDZk,18800
|
|
5
5
|
ml_tools/MICE_imputation.py,sha256=rYqvwQDVtoAJJ0agXWoGzoZEHedWiA6QzcEKEIkiZ08,11388
|
|
6
6
|
ml_tools/ML_callbacks.py,sha256=OT2zwORLcn49megBEgXsSUxDHoW0Ft0_v7hLEVF3jHM,13063
|
|
7
7
|
ml_tools/ML_evaluation.py,sha256=oiDV6HItQloUUKCUpltV-2pogubWLBieGpc-VUwosAQ,10106
|
|
@@ -15,11 +15,12 @@ ml_tools/_particle_swarm_optimization.py,sha256=b_eNNkA89Y40hj76KauivT8KLScH1B9w
|
|
|
15
15
|
ml_tools/_pytorch_models.py,sha256=bpWZsrSwCvHJQkR6UfoPpElsMv9AvmiNErNHC8NYB_I,10132
|
|
16
16
|
ml_tools/data_exploration.py,sha256=M7bn2q5XN9zJZJGAmMMFSFFZh8LGzC2arFelrXw3N6Q,25241
|
|
17
17
|
ml_tools/datasetmaster.py,sha256=S3PKHNQZ9cyAOck8xQltVLZhaD1gFLfgHFL-aRjz4JU,30077
|
|
18
|
-
ml_tools/ensemble_learning.py,sha256=
|
|
18
|
+
ml_tools/ensemble_learning.py,sha256=p9PZwGY2OGSrJhXNzvMS_kCjK-I2JVcqiJBaVzb0GrM,42616
|
|
19
19
|
ml_tools/handle_excel.py,sha256=lwds7rDLlGSCWiWGI7xNg-Z7kxAepogp0lstSFa0590,12949
|
|
20
20
|
ml_tools/logger.py,sha256=UkbiU9ihBhw9VKyn3rZzisdClWV94EBV6B09_D0iUU0,6026
|
|
21
|
-
ml_tools/
|
|
22
|
-
|
|
23
|
-
dragon_ml_toolbox-3.
|
|
24
|
-
dragon_ml_toolbox-3.
|
|
25
|
-
dragon_ml_toolbox-3.
|
|
21
|
+
ml_tools/path_manager.py,sha256=OCpESgdftbi6mOxetDMIaHhazt4N-W8pJx11X3-yNOs,8305
|
|
22
|
+
ml_tools/utilities.py,sha256=HR36Q_vYnaRcpSjpNISnA7lOZ36TouHop38lPLG_twY,23146
|
|
23
|
+
dragon_ml_toolbox-3.9.0.dist-info/METADATA,sha256=2R3xIuefuR9O_h71q3S49xUm2MLKQtn12jjwNFKl2mE,3273
|
|
24
|
+
dragon_ml_toolbox-3.9.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
25
|
+
dragon_ml_toolbox-3.9.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
|
|
26
|
+
dragon_ml_toolbox-3.9.0.dist-info/RECORD,,
|
ml_tools/GUI_tools.py
CHANGED
|
@@ -8,13 +8,14 @@ from typing import Any, Dict, Tuple, List, Literal
|
|
|
8
8
|
from .utilities import _script_info
|
|
9
9
|
import numpy as np
|
|
10
10
|
from .logger import _LOGGER
|
|
11
|
+
from abc import ABC, abstractmethod
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
__all__ = [
|
|
14
15
|
"ConfigManager",
|
|
15
16
|
"GUIFactory",
|
|
16
17
|
"catch_exceptions",
|
|
17
|
-
"
|
|
18
|
+
"BaseFeatureHandler",
|
|
18
19
|
"update_target_fields"
|
|
19
20
|
]
|
|
20
21
|
|
|
@@ -351,68 +352,93 @@ def catch_exceptions(show_popup: bool = True):
|
|
|
351
352
|
return decorator
|
|
352
353
|
|
|
353
354
|
|
|
354
|
-
# --- Inference
|
|
355
|
-
|
|
355
|
+
# --- Inference Helper ---
|
|
356
|
+
class BaseFeatureHandler(ABC):
|
|
356
357
|
"""
|
|
357
|
-
|
|
358
|
-
Returns a list containing a single float.
|
|
359
|
-
"""
|
|
360
|
-
return [1.0] if str(chosen_value) == 'True' else [0.0]
|
|
361
|
-
|
|
362
|
-
def prepare_feature_vector(
|
|
363
|
-
window_values: Dict[str, Any],
|
|
364
|
-
gui_feature_order: List[str],
|
|
365
|
-
continuous_features: List[str],
|
|
366
|
-
categorical_features: List[str],
|
|
367
|
-
categorical_processor: Optional[Callable[[str, Any], List[float]]] = None
|
|
368
|
-
) -> np.ndarray:
|
|
369
|
-
"""
|
|
370
|
-
Validates and converts GUI values into a numpy array for a model.
|
|
371
|
-
This function supports label encoding and one-hot encoding via the processor.
|
|
358
|
+
An abstract base class that defines the template for preparing a model input feature vector to perform inference, from GUI inputs.
|
|
372
359
|
|
|
373
|
-
|
|
374
|
-
window_values (dict): The values dictionary from a `window.read()` call.
|
|
375
|
-
gui_feature_order (list): A list of all feature names that have a GUI element.
|
|
376
|
-
For one-hot encoding, this should be the name of the
|
|
377
|
-
single GUI element (e.g., 'material_type'), not the
|
|
378
|
-
expanded feature names (e.g., 'material_is_steel').
|
|
379
|
-
continuous_features (list): A list of names for continuous features.
|
|
380
|
-
categorical_features (list): A list of names for categorical features.
|
|
381
|
-
categorical_processor (callable, optional): A function to process categorical
|
|
382
|
-
values. It should accept (feature_name, chosen_value) and return a
|
|
383
|
-
list of floats (e.g., [1.0] for label encoding, [0.0, 1.0, 0.0] for one-hot).
|
|
384
|
-
If None, a default 'True'/'False' processor is used.
|
|
385
|
-
|
|
386
|
-
Returns:
|
|
387
|
-
A 1D numpy array ready for model inference.
|
|
360
|
+
A subclass must implement the `gui_input_map` property and the `process_categorical` method.
|
|
388
361
|
"""
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
362
|
+
def __init__(self, expected_columns_in_order: list[str]):
|
|
363
|
+
"""
|
|
364
|
+
Validates and stores the feature names in the order the model expects.
|
|
365
|
+
|
|
366
|
+
Args:
|
|
367
|
+
expected_columns_in_order (List[str]): A list of strings with the feature names in the correct order.
|
|
368
|
+
"""
|
|
369
|
+
# --- Validation Logic ---
|
|
370
|
+
if not isinstance(expected_columns_in_order, list):
|
|
371
|
+
raise TypeError("Input 'expected_columns_in_order' must be a list.")
|
|
372
|
+
|
|
373
|
+
if not all(isinstance(col, str) for col in expected_columns_in_order):
|
|
374
|
+
raise TypeError("All elements in the 'expected_columns_in_order' list must be strings.")
|
|
375
|
+
# -----------------------
|
|
376
|
+
|
|
377
|
+
self._model_feature_order = expected_columns_in_order
|
|
400
378
|
|
|
401
|
-
|
|
402
|
-
|
|
379
|
+
@property
|
|
380
|
+
@abstractmethod
|
|
381
|
+
def gui_input_map(self) -> Dict[str, Literal["continuous","categorical"]]:
|
|
382
|
+
"""
|
|
383
|
+
Must be implemented by the subclass.
|
|
403
384
|
|
|
404
|
-
|
|
405
|
-
try:
|
|
406
|
-
processed_values.append(float(chosen_value))
|
|
407
|
-
except (ValueError, TypeError):
|
|
408
|
-
raise ValueError(f"Invalid input for '{name}'. Please enter a valid number.")
|
|
385
|
+
Should return a dictionary mapping each GUI input name to its type ('continuous' or 'categorical').
|
|
409
386
|
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
387
|
+
```python
|
|
388
|
+
#Example:
|
|
389
|
+
{'temperature': 'continuous', 'material_type': 'categorical'}
|
|
390
|
+
```
|
|
391
|
+
"""
|
|
392
|
+
pass
|
|
393
|
+
|
|
394
|
+
@abstractmethod
|
|
395
|
+
def process_categorical(self, feature_name: str, chosen_value: Any) -> Dict[str, float]:
|
|
396
|
+
"""
|
|
397
|
+
Must be implemented by the subclass.
|
|
398
|
+
|
|
399
|
+
Should take a GUI categorical feature name and its chosen value, and return a dictionary mapping the one-hot-encoded feature names to their
|
|
400
|
+
float values (as expected by the inference model).
|
|
401
|
+
"""
|
|
402
|
+
pass
|
|
403
|
+
|
|
404
|
+
def __call__(self, window_values: Dict[str, Any]) -> np.ndarray:
|
|
405
|
+
"""
|
|
406
|
+
Performs the full vector preparation, returning a 1D numpy array.
|
|
407
|
+
|
|
408
|
+
Should not be overridden by subclasses.
|
|
409
|
+
"""
|
|
410
|
+
# Stage 1: Process GUI inputs into a dictionary
|
|
411
|
+
processed_features: Dict[str, float] = {}
|
|
412
|
+
for gui_name, feature_type in self.gui_input_map.items():
|
|
413
|
+
chosen_value = window_values.get(gui_name)
|
|
414
|
+
|
|
415
|
+
if chosen_value is None or str(chosen_value) == '':
|
|
416
|
+
raise ValueError(f"GUI input '{gui_name}' is missing a value.")
|
|
417
|
+
|
|
418
|
+
if feature_type == 'continuous':
|
|
419
|
+
try:
|
|
420
|
+
processed_features[gui_name] = float(chosen_value)
|
|
421
|
+
except (ValueError, TypeError):
|
|
422
|
+
raise ValueError(f"Invalid number '{chosen_value}' for '{gui_name}'.")
|
|
423
|
+
|
|
424
|
+
elif feature_type == 'categorical':
|
|
425
|
+
feature_dict = self.process_categorical(gui_name, chosen_value)
|
|
426
|
+
processed_features.update(feature_dict)
|
|
427
|
+
|
|
428
|
+
# Stage 2: Assemble the final vector using the model's required order
|
|
429
|
+
final_vector: List[float] = []
|
|
430
|
+
|
|
431
|
+
try:
|
|
432
|
+
for feature_name in self._model_feature_order:
|
|
433
|
+
final_vector.append(processed_features[feature_name])
|
|
434
|
+
except KeyError as e:
|
|
435
|
+
raise RuntimeError(
|
|
436
|
+
f"Configuration Error: Implemented methods failed to generate "
|
|
437
|
+
f"the required model feature: '{e}'"
|
|
438
|
+
f"Check the gui_input_map and process_categorical logic."
|
|
439
|
+
)
|
|
414
440
|
|
|
415
|
-
|
|
441
|
+
return np.array(final_vector, dtype=np.float32)
|
|
416
442
|
|
|
417
443
|
|
|
418
444
|
def update_target_fields(window: sg.Window, results_dict: Dict[str, Any]):
|
ml_tools/ensemble_learning.py
CHANGED
|
@@ -6,7 +6,7 @@ from matplotlib.colors import Colormap
|
|
|
6
6
|
from matplotlib import rcdefaults
|
|
7
7
|
|
|
8
8
|
from pathlib import Path
|
|
9
|
-
from typing import Literal, Union, Optional, Iterator, Tuple
|
|
9
|
+
from typing import Literal, Union, Optional, Iterator, Tuple, Dict, Any, List
|
|
10
10
|
|
|
11
11
|
from imblearn.over_sampling import ADASYN, SMOTE, RandomOverSampler
|
|
12
12
|
from imblearn.under_sampling import RandomUnderSampler
|
|
@@ -19,7 +19,7 @@ from sklearn.model_selection import train_test_split
|
|
|
19
19
|
from sklearn.metrics import accuracy_score, classification_report, ConfusionMatrixDisplay, mean_absolute_error, mean_squared_error, r2_score, roc_curve, roc_auc_score
|
|
20
20
|
import shap
|
|
21
21
|
|
|
22
|
-
from .utilities import yield_dataframes_from_dir, sanitize_filename, _script_info, serialize_object, make_fullpath
|
|
22
|
+
from .utilities import yield_dataframes_from_dir, sanitize_filename, _script_info, serialize_object, make_fullpath, list_files_by_extension, deserialize_object
|
|
23
23
|
from .logger import _LOGGER
|
|
24
24
|
|
|
25
25
|
import warnings # Ignore warnings
|
|
@@ -38,7 +38,8 @@ __all__ = [
|
|
|
38
38
|
"evaluate_model_regression",
|
|
39
39
|
"get_shap_values",
|
|
40
40
|
"train_test_pipeline",
|
|
41
|
-
"run_ensemble_pipeline"
|
|
41
|
+
"run_ensemble_pipeline",
|
|
42
|
+
"InferenceHandler"
|
|
42
43
|
]
|
|
43
44
|
|
|
44
45
|
## Type aliases
|
|
@@ -937,5 +938,124 @@ def run_ensemble_pipeline(datasets_dir: Union[str,Path], save_dir: Union[str,Pat
|
|
|
937
938
|
_LOGGER.info("✅ Training and evaluation complete.")
|
|
938
939
|
|
|
939
940
|
|
|
941
|
+
###### 6. Inference ######
|
|
942
|
+
class InferenceHandler:
|
|
943
|
+
"""
|
|
944
|
+
Handles loading ensemble models and performing inference for either regression or classification tasks.
|
|
945
|
+
"""
|
|
946
|
+
def __init__(self,
|
|
947
|
+
models_dir: Union[str,Path],
|
|
948
|
+
task: TaskType,
|
|
949
|
+
verbose: bool=True) -> None:
|
|
950
|
+
"""
|
|
951
|
+
Initializes the handler by loading all models from a directory.
|
|
952
|
+
|
|
953
|
+
Args:
|
|
954
|
+
models_dir (Path): The directory containing the saved .joblib model files.
|
|
955
|
+
task ("regression" | "classification"): The type of task the models perform.
|
|
956
|
+
"""
|
|
957
|
+
self.models: Dict[str, Any] = dict()
|
|
958
|
+
self.task: str = task
|
|
959
|
+
self.verbose = verbose
|
|
960
|
+
self._feature_names: Optional[List[str]] = None
|
|
961
|
+
|
|
962
|
+
model_files = list_files_by_extension(directory=models_dir, extension="joblib")
|
|
963
|
+
|
|
964
|
+
for fname, fpath in model_files.items():
|
|
965
|
+
try:
|
|
966
|
+
full_object: dict
|
|
967
|
+
full_object = deserialize_object(filepath=fpath,
|
|
968
|
+
verbose=self.verbose,
|
|
969
|
+
raise_on_error=True) # type: ignore
|
|
970
|
+
|
|
971
|
+
model: Any = full_object["model"]
|
|
972
|
+
target_name: str = full_object["target_name"]
|
|
973
|
+
feature_names_list: List[str] = full_object["feature_names"]
|
|
974
|
+
|
|
975
|
+
# Check that feature names match
|
|
976
|
+
if self._feature_names is None:
|
|
977
|
+
# Store the feature names from the first model loaded.
|
|
978
|
+
self._feature_names = feature_names_list
|
|
979
|
+
elif self._feature_names != feature_names_list:
|
|
980
|
+
# Add a warning if subsequent models have different feature names.
|
|
981
|
+
_LOGGER.warning(f"⚠️ Mismatched feature names in {fname}. Using feature order from the first model loaded.")
|
|
982
|
+
|
|
983
|
+
self.models[target_name] = model
|
|
984
|
+
if self.verbose:
|
|
985
|
+
_LOGGER.info(f"✅ Loaded model for target: {target_name}")
|
|
986
|
+
|
|
987
|
+
except Exception as e:
|
|
988
|
+
_LOGGER.warning(f"⚠️ Failed to load or parse {fname}: {e}")
|
|
989
|
+
|
|
990
|
+
@property
|
|
991
|
+
def feature_names(self) -> List[str]:
|
|
992
|
+
"""
|
|
993
|
+
Getter for the list of feature names the models expect.
|
|
994
|
+
Returns an empty list if no models were loaded.
|
|
995
|
+
"""
|
|
996
|
+
return self._feature_names if self._feature_names is not None else []
|
|
997
|
+
|
|
998
|
+
def predict(self, features: np.ndarray) -> Dict[str, Any]:
|
|
999
|
+
"""
|
|
1000
|
+
Predicts on a single feature vector.
|
|
1001
|
+
|
|
1002
|
+
Args:
|
|
1003
|
+
features (np.ndarray): A 1D or 2D NumPy array for a single sample.
|
|
1004
|
+
|
|
1005
|
+
Returns:
|
|
1006
|
+
Dict[str, Any]: A dictionary where keys are target names.
|
|
1007
|
+
- For regression: The value is the single predicted float.
|
|
1008
|
+
- For classification: The value is another dictionary {'label': ..., 'probabilities': ...}.
|
|
1009
|
+
"""
|
|
1010
|
+
if features.ndim == 1:
|
|
1011
|
+
features = features.reshape(1, -1)
|
|
1012
|
+
|
|
1013
|
+
if features.shape[0] != 1:
|
|
1014
|
+
raise ValueError("The predict() method is for a single sample. Use predict_batch() for multiple samples.")
|
|
1015
|
+
|
|
1016
|
+
results: Dict[str, Any] = dict()
|
|
1017
|
+
for target_name, model in self.models.items():
|
|
1018
|
+
if self.task == "regression":
|
|
1019
|
+
prediction = model.predict(features)
|
|
1020
|
+
results[target_name] = prediction.item()
|
|
1021
|
+
else: # Classification
|
|
1022
|
+
label = model.predict(features)[0]
|
|
1023
|
+
probabilities = model.predict_proba(features)[0]
|
|
1024
|
+
results[target_name] = {"label": label, "probabilities": probabilities}
|
|
1025
|
+
|
|
1026
|
+
if self.verbose:
|
|
1027
|
+
_LOGGER.info("✅ Inference process complete.")
|
|
1028
|
+
return results
|
|
1029
|
+
|
|
1030
|
+
def predict_batch(self, features: np.ndarray) -> Dict[str, Any]:
|
|
1031
|
+
"""
|
|
1032
|
+
Predicts on a batch of feature vectors.
|
|
1033
|
+
|
|
1034
|
+
Args:
|
|
1035
|
+
features (np.ndarray): A 2D NumPy array where each row is a sample.
|
|
1036
|
+
|
|
1037
|
+
Returns:
|
|
1038
|
+
Dict[str, Any]: A dictionary where keys are target names.
|
|
1039
|
+
- For regression: The value is a NumPy array of predictions.
|
|
1040
|
+
- For classification: The value is another dictionary {'labels': ..., 'probabilities': ...}.
|
|
1041
|
+
"""
|
|
1042
|
+
if features.ndim != 2:
|
|
1043
|
+
raise ValueError("Input for batch prediction must be a 2D array.")
|
|
1044
|
+
|
|
1045
|
+
results: Dict[str, Any] = dict()
|
|
1046
|
+
for target_name, model in self.models.items():
|
|
1047
|
+
if self.task == "regression":
|
|
1048
|
+
results[target_name] = model.predict(features)
|
|
1049
|
+
else: # Classification
|
|
1050
|
+
labels = model.predict(features)
|
|
1051
|
+
probabilities = model.predict_proba(features)
|
|
1052
|
+
results[target_name] = {"labels": labels, "probabilities": probabilities}
|
|
1053
|
+
|
|
1054
|
+
if self.verbose:
|
|
1055
|
+
_LOGGER.info("✅ Inference process complete.")
|
|
1056
|
+
|
|
1057
|
+
return results
|
|
1058
|
+
|
|
1059
|
+
|
|
940
1060
|
def info():
|
|
941
1061
|
_script_info(__all__)
|
ml_tools/path_manager.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
from pprint import pprint
|
|
2
|
+
from typing import Optional, List, Dict, Callable, Union
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from .utilities import _script_info
|
|
5
|
+
from .logger import _LOGGER
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
__all__ = [
|
|
9
|
+
"PathManager"
|
|
10
|
+
]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class PathManager:
|
|
14
|
+
"""
|
|
15
|
+
Manages and stores a project's file paths, acting as a centralized
|
|
16
|
+
"path database". It supports both development mode and applications
|
|
17
|
+
bundled with Briefcase.
|
|
18
|
+
|
|
19
|
+
Supports python dictionary syntax.
|
|
20
|
+
"""
|
|
21
|
+
def __init__(
|
|
22
|
+
self,
|
|
23
|
+
anchor_file: str,
|
|
24
|
+
base_directories: Optional[List[str]] = None
|
|
25
|
+
):
|
|
26
|
+
"""
|
|
27
|
+
The initializer determines the project's root directory and can pre-register
|
|
28
|
+
a list of base directories relative to that root.
|
|
29
|
+
|
|
30
|
+
Args:
|
|
31
|
+
anchor_file (str): The absolute path to a file whose parent directory will be considered the package root and name. Typically, `__file__`.
|
|
32
|
+
base_directories (Optional[List[str]]): A list of directory names located at the same level as the anchor file to be registered immediately.
|
|
33
|
+
"""
|
|
34
|
+
resolved_anchor_path = Path(anchor_file).resolve()
|
|
35
|
+
self._package_name = resolved_anchor_path.parent.name
|
|
36
|
+
self._is_bundled, self._resource_path_func = self._check_bundle_status()
|
|
37
|
+
self._paths: Dict[str, Path] = {}
|
|
38
|
+
|
|
39
|
+
if self._is_bundled:
|
|
40
|
+
# In a bundle, resource_path gives the absolute path to the 'app_packages' dir
|
|
41
|
+
# when given the package name.
|
|
42
|
+
package_root = self._resource_path_func(self._package_name) # type: ignore
|
|
43
|
+
else:
|
|
44
|
+
# In dev mode, the package root is the directory containing the anchor file.
|
|
45
|
+
package_root = resolved_anchor_path.parent
|
|
46
|
+
|
|
47
|
+
# Register the root of the package itself
|
|
48
|
+
self._paths["ROOT"] = package_root
|
|
49
|
+
|
|
50
|
+
# Register all the base directories
|
|
51
|
+
if base_directories:
|
|
52
|
+
for dir_name in base_directories:
|
|
53
|
+
# In dev mode, this is simple. In a bundle, we must resolve
|
|
54
|
+
# each path from the package root.
|
|
55
|
+
if self._is_bundled:
|
|
56
|
+
self._paths[dir_name] = self._resource_path_func(self._package_name, dir_name) # type: ignore
|
|
57
|
+
else:
|
|
58
|
+
self._paths[dir_name] = package_root / dir_name
|
|
59
|
+
|
|
60
|
+
# A helper function to find the briefcase-injected resource function
|
|
61
|
+
def _check_bundle_status(self) -> tuple[bool, Optional[Callable]]:
|
|
62
|
+
"""Checks if the app is running in a Briefcase bundle."""
|
|
63
|
+
try:
|
|
64
|
+
# This function is injected by Briefcase into the global scope
|
|
65
|
+
from briefcase.platforms.base import resource_path # type: ignore
|
|
66
|
+
return True, resource_path
|
|
67
|
+
except (ImportError, NameError):
|
|
68
|
+
return False, None
|
|
69
|
+
|
|
70
|
+
def get(self, key: str) -> Path:
|
|
71
|
+
"""
|
|
72
|
+
Retrieves a stored path by its key.
|
|
73
|
+
|
|
74
|
+
Args:
|
|
75
|
+
key (str): The key of the path to retrieve.
|
|
76
|
+
|
|
77
|
+
Returns:
|
|
78
|
+
Path: The resolved, absolute Path object.
|
|
79
|
+
|
|
80
|
+
Raises:
|
|
81
|
+
KeyError: If the key is not found in the manager.
|
|
82
|
+
"""
|
|
83
|
+
try:
|
|
84
|
+
return self._paths[key]
|
|
85
|
+
except KeyError:
|
|
86
|
+
_LOGGER.error(f"❌ Path key '{key}' not found.")
|
|
87
|
+
raise
|
|
88
|
+
|
|
89
|
+
def update(self, new_paths: Dict[str, Union[str, Path]], overwrite: bool = False) -> None:
|
|
90
|
+
"""
|
|
91
|
+
Adds new paths or overwrites existing ones in the manager.
|
|
92
|
+
|
|
93
|
+
Args:
|
|
94
|
+
new_paths (Dict[str, Union[str, Path]]): A dictionary where keys are
|
|
95
|
+
the identifiers and values are the
|
|
96
|
+
Path objects or strings to store.
|
|
97
|
+
overwrite (bool): If False (default), raises a KeyError if any
|
|
98
|
+
key in new_paths already exists. If True,
|
|
99
|
+
allows overwriting existing keys.
|
|
100
|
+
"""
|
|
101
|
+
if not overwrite:
|
|
102
|
+
for key in new_paths:
|
|
103
|
+
if key in self._paths:
|
|
104
|
+
raise KeyError(
|
|
105
|
+
f"Path key '{key}' already exists in the manager. To replace it, call update() with overwrite=True."
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
# Resolve any string paths to Path objects before storing
|
|
109
|
+
resolved_new_paths = {k: Path(v) for k, v in new_paths.items()}
|
|
110
|
+
self._paths.update(resolved_new_paths)
|
|
111
|
+
|
|
112
|
+
def make_dirs(self, keys: Optional[List[str]] = None, verbose: bool = False) -> None:
|
|
113
|
+
"""
|
|
114
|
+
Creates directory structures for registered paths in writable locations.
|
|
115
|
+
|
|
116
|
+
This method identifies paths that are directories (no file suffix) and creates them on the filesystem.
|
|
117
|
+
|
|
118
|
+
In a bundled application, this method will NOT attempt to create directories inside the read-only app package, preventing crashes. It
|
|
119
|
+
will only operate on paths outside of the package (e.g., user data dirs).
|
|
120
|
+
|
|
121
|
+
Args:
|
|
122
|
+
keys (Optional[List[str]]): If provided, only the directories
|
|
123
|
+
corresponding to these keys will be
|
|
124
|
+
created. If None (default), all
|
|
125
|
+
registered directory paths are used.
|
|
126
|
+
verbose (bool): If True, prints a message for each action.
|
|
127
|
+
"""
|
|
128
|
+
path_items = []
|
|
129
|
+
if keys:
|
|
130
|
+
for key in keys:
|
|
131
|
+
if key in self._paths:
|
|
132
|
+
path_items.append((key, self._paths[key]))
|
|
133
|
+
elif verbose:
|
|
134
|
+
_LOGGER.warning(f"⚠️ Key '{key}' not found in PathManager, skipping.")
|
|
135
|
+
else:
|
|
136
|
+
path_items = self._paths.items()
|
|
137
|
+
|
|
138
|
+
# Get the package root to check against.
|
|
139
|
+
package_root = self._paths.get("ROOT")
|
|
140
|
+
|
|
141
|
+
for key, path in path_items:
|
|
142
|
+
if path.suffix: # It's a file, not a directory
|
|
143
|
+
continue
|
|
144
|
+
|
|
145
|
+
# --- THE CRITICAL CHECK ---
|
|
146
|
+
# Determine if the path is inside the main application package.
|
|
147
|
+
is_internal_path = package_root and path.is_relative_to(package_root)
|
|
148
|
+
|
|
149
|
+
if self._is_bundled and is_internal_path:
|
|
150
|
+
if verbose:
|
|
151
|
+
_LOGGER.warning(f"⚠️ Skipping internal directory '{key}' in bundled app (read-only).")
|
|
152
|
+
continue
|
|
153
|
+
# -------------------------
|
|
154
|
+
|
|
155
|
+
if verbose:
|
|
156
|
+
_LOGGER.info(f"📁 Ensuring directory exists for key '{key}': {path}")
|
|
157
|
+
|
|
158
|
+
path.mkdir(parents=True, exist_ok=True)
|
|
159
|
+
|
|
160
|
+
def status(self) -> None:
|
|
161
|
+
"""
|
|
162
|
+
Checks the status of all registered paths on the filesystem and prints a formatted report.
|
|
163
|
+
"""
|
|
164
|
+
report = {}
|
|
165
|
+
for key, path in self.items():
|
|
166
|
+
if path.is_dir():
|
|
167
|
+
report[key] = "📁 Directory"
|
|
168
|
+
elif path.is_file():
|
|
169
|
+
report[key] = "📄 File"
|
|
170
|
+
else:
|
|
171
|
+
report[key] = "❌ Not Found"
|
|
172
|
+
|
|
173
|
+
print("\n--- Path Status Report ---")
|
|
174
|
+
pprint(report)
|
|
175
|
+
|
|
176
|
+
def __repr__(self) -> str:
|
|
177
|
+
"""Provides a string representation of the stored paths."""
|
|
178
|
+
path_list = "\n".join(f" '{k}': '{v}'" for k, v in self._paths.items())
|
|
179
|
+
return f"PathManager(\n{path_list}\n)"
|
|
180
|
+
|
|
181
|
+
# --- Dictionary-Style Methods ---
|
|
182
|
+
def __getitem__(self, key: str) -> Path:
|
|
183
|
+
"""Allows dictionary-style getting, e.g., PM['my_key']"""
|
|
184
|
+
return self.get(key)
|
|
185
|
+
|
|
186
|
+
def __setitem__(self, key: str, value: Union[str, Path]):
|
|
187
|
+
"""Allows dictionary-style setting, does not allow overwriting, e.g., PM['my_key'] = path"""
|
|
188
|
+
self.update({key: value}, overwrite=False)
|
|
189
|
+
|
|
190
|
+
def __contains__(self, key: str) -> bool:
|
|
191
|
+
"""Allows checking for a key's existence, e.g., if 'my_key' in PM"""
|
|
192
|
+
return key in self._paths
|
|
193
|
+
|
|
194
|
+
def __len__(self) -> int:
|
|
195
|
+
"""Allows getting the number of paths, e.g., len(PM)"""
|
|
196
|
+
return len(self._paths)
|
|
197
|
+
|
|
198
|
+
def keys(self):
|
|
199
|
+
"""Returns all registered path keys."""
|
|
200
|
+
return self._paths.keys()
|
|
201
|
+
|
|
202
|
+
def values(self):
|
|
203
|
+
"""Returns all registered Path objects."""
|
|
204
|
+
return self._paths.values()
|
|
205
|
+
|
|
206
|
+
def items(self):
|
|
207
|
+
"""Returns all registered (key, Path) pairs."""
|
|
208
|
+
return self._paths.items()
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def info():
|
|
212
|
+
_script_info(__all__)
|
ml_tools/utilities.py
CHANGED
|
@@ -4,10 +4,9 @@ import pandas as pd
|
|
|
4
4
|
import polars as pl
|
|
5
5
|
from pathlib import Path
|
|
6
6
|
import re
|
|
7
|
-
from typing import Literal, Union, Sequence, Optional, Any, Iterator, Tuple
|
|
7
|
+
from typing import Literal, Union, Sequence, Optional, Any, Iterator, Tuple
|
|
8
8
|
import joblib
|
|
9
9
|
from joblib.externals.loky.process_executor import TerminatedWorkerError
|
|
10
|
-
from pprint import pprint
|
|
11
10
|
|
|
12
11
|
|
|
13
12
|
# Keep track of available tools
|
|
@@ -27,7 +26,6 @@ __all__ = [
|
|
|
27
26
|
"deserialize_object",
|
|
28
27
|
"distribute_datasets_by_target",
|
|
29
28
|
"train_dataset_orchestrator",
|
|
30
|
-
"PathManager"
|
|
31
29
|
]
|
|
32
30
|
|
|
33
31
|
|
|
@@ -645,208 +643,6 @@ def train_dataset_orchestrator(list_of_dirs: list[Union[str,Path]],
|
|
|
645
643
|
print(f"\n✅ {total_saved} single-target datasets were created.")
|
|
646
644
|
|
|
647
645
|
|
|
648
|
-
### Path Manager
|
|
649
|
-
class PathManager:
|
|
650
|
-
"""
|
|
651
|
-
Manages and stores a project's file paths, acting as a centralized
|
|
652
|
-
"path database". It supports both development mode and applications
|
|
653
|
-
bundled with Briefcase.
|
|
654
|
-
|
|
655
|
-
Supports python dictionary syntax.
|
|
656
|
-
"""
|
|
657
|
-
def __init__(
|
|
658
|
-
self,
|
|
659
|
-
anchor_file: str,
|
|
660
|
-
base_directories: Optional[List[str]] = None
|
|
661
|
-
):
|
|
662
|
-
"""
|
|
663
|
-
The initializer determines the project's root directory and can pre-register
|
|
664
|
-
a list of base directories relative to that root.
|
|
665
|
-
|
|
666
|
-
Args:
|
|
667
|
-
anchor_file (str): The absolute path to a file whose parent directory will be considered the package root and name. Typically, `__file__`.
|
|
668
|
-
base_directories (Optional[List[str]]): A list of directory names
|
|
669
|
-
located at the same level as the anchor file's
|
|
670
|
-
parent directory to register immediately.
|
|
671
|
-
"""
|
|
672
|
-
resolved_anchor_path = Path(anchor_file).resolve()
|
|
673
|
-
self._package_name = resolved_anchor_path.parent.name
|
|
674
|
-
self._is_bundled, self._resource_path_func = self._check_bundle_status()
|
|
675
|
-
self._paths: Dict[str, Path] = {}
|
|
676
|
-
|
|
677
|
-
if self._is_bundled:
|
|
678
|
-
# In a bundle, resource_path gives the absolute path to the 'app_packages' dir
|
|
679
|
-
# when given the package name.
|
|
680
|
-
package_root = self._resource_path_func(self._package_name) # type: ignore
|
|
681
|
-
else:
|
|
682
|
-
# In dev mode, the package root is the directory containing the anchor file.
|
|
683
|
-
package_root = resolved_anchor_path.parent
|
|
684
|
-
|
|
685
|
-
# Register the root of the package itself
|
|
686
|
-
self._paths["ROOT"] = package_root
|
|
687
|
-
|
|
688
|
-
# Register all the base directories
|
|
689
|
-
if base_directories:
|
|
690
|
-
for dir_name in base_directories:
|
|
691
|
-
# In dev mode, this is simple. In a bundle, we must resolve
|
|
692
|
-
# each path from the package root.
|
|
693
|
-
if self._is_bundled:
|
|
694
|
-
self._paths[dir_name] = self._resource_path_func(self._package_name, dir_name) # type: ignore
|
|
695
|
-
else:
|
|
696
|
-
self._paths[dir_name] = package_root / dir_name
|
|
697
|
-
|
|
698
|
-
# A helper function to find the briefcase-injected resource function
|
|
699
|
-
def _check_bundle_status(self) -> tuple[bool, Optional[Callable]]:
|
|
700
|
-
"""Checks if the app is running in a Briefcase bundle."""
|
|
701
|
-
try:
|
|
702
|
-
# This function is injected by Briefcase into the global scope
|
|
703
|
-
from briefcase.platforms.base import resource_path # type: ignore
|
|
704
|
-
return True, resource_path
|
|
705
|
-
except (ImportError, NameError):
|
|
706
|
-
return False, None
|
|
707
|
-
|
|
708
|
-
def get(self, key: str) -> Path:
|
|
709
|
-
"""
|
|
710
|
-
Retrieves a stored path by its key.
|
|
711
|
-
|
|
712
|
-
Args:
|
|
713
|
-
key (str): The key of the path to retrieve.
|
|
714
|
-
|
|
715
|
-
Returns:
|
|
716
|
-
Path: The resolved, absolute Path object.
|
|
717
|
-
|
|
718
|
-
Raises:
|
|
719
|
-
KeyError: If the key is not found in the manager.
|
|
720
|
-
"""
|
|
721
|
-
try:
|
|
722
|
-
return self._paths[key]
|
|
723
|
-
except KeyError:
|
|
724
|
-
print(f"❌ Path key '{key}' not found.")
|
|
725
|
-
# Consider suggesting close matches if you want to get fancy
|
|
726
|
-
raise
|
|
727
|
-
|
|
728
|
-
def update(self, new_paths: Dict[str, Union[str, Path]], overwrite: bool = False) -> None:
    """
    Register new paths, optionally replacing existing entries.

    Args:
        new_paths (Dict[str, Union[str, Path]]): Mapping of identifier to a
            path (str or Path). Every value is converted with ``Path(...)``
            before being stored.
        overwrite (bool): When False (default), a KeyError is raised for the
            first key of ``new_paths`` that is already registered, and nothing
            is stored. When True, existing keys are replaced.
    """
    existing = self._paths

    if not overwrite:
        duplicates = [name for name in new_paths if name in existing]
        if duplicates:
            raise KeyError(
                f"Path key '{duplicates[0]}' already exists in the manager. To replace it, call update() with overwrite=True."
            )

    # Convert everything first so a bad value leaves the manager untouched.
    converted = {}
    for name, raw in new_paths.items():
        converted[name] = Path(raw)
    existing.update(converted)
|
|
750
|
-
|
|
751
|
-
def make_dirs(self, keys: Optional[List[str]] = None, verbose: bool = False) -> None:
    """
    Creates directory structures for registered paths in writable locations.

    A registered path is treated as a directory when it has no file suffix
    (``Path.suffix`` is empty). Paths with a suffix are assumed to be files and
    are skipped.

    In a bundled application, this method will NOT attempt to create
    directories inside the read-only app package (the ``"ROOT"`` entry),
    preventing crashes. It only operates on paths outside of the package
    (e.g., user data dirs).

    Args:
        keys (Optional[List[str]]): If provided, only the directories
            corresponding to these keys are created. Unknown keys are skipped,
            with a warning when ``verbose`` is True. An empty list creates
            nothing. If None (default), all registered directory paths are used.
        verbose (bool): If True, prints a message for each action.
    """
    # `is None` rather than truthiness: an explicit empty selection must not
    # fall through to "every registered path".
    if keys is None:
        # Snapshot so the loop is independent of later changes to the dict.
        path_items = list(self._paths.items())
    else:
        path_items = []
        for key in keys:
            if key in self._paths:
                path_items.append((key, self._paths[key]))
            elif verbose:
                print(f"⚠️ Key '{key}' not found in PathManager, skipping.")

    # Get the package root to check against.
    package_root = self._paths.get("ROOT")

    for key, path in path_items:
        # Heuristic: a suffix (e.g. ".csv") marks a file. A directory whose
        # name contains a dot is skipped by this rule as well.
        if path.suffix:
            continue

        # --- THE CRITICAL CHECK ---
        # Determine if the path is inside the main application package. The
        # comparison is lexical (no resolve()), so it assumes registered paths
        # share the same form as the ROOT entry.
        is_internal_path = package_root is not None and path.is_relative_to(package_root)

        if self._is_bundled and is_internal_path:
            if verbose:
                print(f"ℹ️ Skipping internal directory '{key}' in bundled app (read-only).")
            continue
        # -------------------------

        if verbose:
            print(f"📁 Ensuring directory exists for key '{key}': {path}")

        path.mkdir(parents=True, exist_ok=True)
|
|
798
|
-
|
|
799
|
-
def status(self) -> None:
    """
    Print a report classifying every registered path as an existing
    directory, an existing file, or missing.
    """
    def _classify(location):
        # Directory check comes first, then file, otherwise it does not exist.
        if location.is_dir():
            return "📁 Directory"
        if location.is_file():
            return "📄 File"
        return "❌ Not Found"

    report = {name: _classify(location) for name, location in self.items()}

    print("\n--- Path Status Report ---")
    pprint(report)
|
|
814
|
-
|
|
815
|
-
def __repr__(self) -> str:
    """Render each registered key/path pair on its own line inside ``PathManager(...)``."""
    body = "\n".join(f" '{name}': '{location}'" for name, location in self._paths.items())
    return "PathManager(\n" + body + "\n)"
|
|
819
|
-
|
|
820
|
-
# --- Dictionary-Style Methods ---
|
|
821
|
-
def __getitem__(self, key: str) -> Path:
    """Subscript lookup (``PM['my_key']``), delegating to ``get``."""
    lookup = self.get
    return lookup(key)
|
|
824
|
-
|
|
825
|
-
def __setitem__(self, key: str, value: Union[str, Path]):
    """Subscript assignment (``PM['my_key'] = path``); always replaces an existing entry."""
    payload = {key: value}
    self.update(payload, overwrite=True)
|
|
828
|
-
|
|
829
|
-
def __contains__(self, key: str) -> bool:
    """Membership test (``'my_key' in PM``) against the registered keys."""
    return key in self._paths.keys()
|
|
832
|
-
|
|
833
|
-
def __len__(self) -> int:
    """Number of registered paths (``len(PM)``)."""
    return len(self._paths.keys())
|
|
836
|
-
|
|
837
|
-
def keys(self):
    """Live view of every registered path key."""
    registry = self._paths
    return registry.keys()
|
|
840
|
-
|
|
841
|
-
def values(self):
    """Live view of every registered Path object."""
    registry = self._paths
    return registry.values()
|
|
844
|
-
|
|
845
|
-
def items(self):
    """Live view of every registered ``(key, Path)`` pair."""
    registry = self._paths
    return registry.items()
|
|
848
|
-
|
|
849
|
-
|
|
850
646
|
class LogKeys:
|
|
851
647
|
"""
|
|
852
648
|
Used internally for ML scripts.
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|