dragon-ml-toolbox 3.8.0__py3-none-any.whl → 3.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dragon-ml-toolbox might be problematic. Click here for more details.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dragon-ml-toolbox
3
- Version: 3.8.0
3
+ Version: 3.9.1
4
4
  Summary: A collection of tools for data science and machine learning projects.
5
5
  Author-email: Karl Loza <luigiloza@gmail.com>
6
6
  License-Expression: MIT
@@ -15,8 +15,8 @@ License-File: LICENSE-THIRD-PARTY.md
15
15
  Requires-Dist: numpy<2.0
16
16
  Requires-Dist: scikit-learn
17
17
  Requires-Dist: openpyxl
18
- Requires-Dist: miceforest<7.0.0,>=6.0.0
19
- Requires-Dist: plotnine<0.13,>=0.12
18
+ Requires-Dist: miceforest>=6.0.0
19
+ Requires-Dist: plotnine>=0.12
20
20
  Requires-Dist: matplotlib
21
21
  Requires-Dist: seaborn
22
22
  Requires-Dist: pandas
@@ -129,6 +129,7 @@ ML_callbacks
129
129
  ML_evaluation
130
130
  ML_trainer
131
131
  ML_tutorial
132
+ path_manager
132
133
  PSO_optimization
133
134
  RNN_forecast
134
135
  utilities
@@ -1,7 +1,7 @@
1
- dragon_ml_toolbox-3.8.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
2
- dragon_ml_toolbox-3.8.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=6cfpIeQ6D4Mcs10nkogQrkVyq1T7i2qXjjNHFoUMOyE,1892
1
+ dragon_ml_toolbox-3.9.1.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
2
+ dragon_ml_toolbox-3.9.1.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=6cfpIeQ6D4Mcs10nkogQrkVyq1T7i2qXjjNHFoUMOyE,1892
3
3
  ml_tools/ETL_engineering.py,sha256=yeZsW_7zRvEcuMZbM4E2GV1dxwBoWIeJAcFFk2AK0fY,39502
4
- ml_tools/GUI_tools.py,sha256=z0CbN8zOC9bSGxOwcf539gSmvXyn-xP5xXHPxWiywMI,17920
4
+ ml_tools/GUI_tools.py,sha256=ayLwQMkpkFPoun7TxT2Llq5whVIWDjcHXU_ljENvueM,19118
5
5
  ml_tools/MICE_imputation.py,sha256=rYqvwQDVtoAJJ0agXWoGzoZEHedWiA6QzcEKEIkiZ08,11388
6
6
  ml_tools/ML_callbacks.py,sha256=OT2zwORLcn49megBEgXsSUxDHoW0Ft0_v7hLEVF3jHM,13063
7
7
  ml_tools/ML_evaluation.py,sha256=oiDV6HItQloUUKCUpltV-2pogubWLBieGpc-VUwosAQ,10106
@@ -15,11 +15,12 @@ ml_tools/_particle_swarm_optimization.py,sha256=b_eNNkA89Y40hj76KauivT8KLScH1B9w
15
15
  ml_tools/_pytorch_models.py,sha256=bpWZsrSwCvHJQkR6UfoPpElsMv9AvmiNErNHC8NYB_I,10132
16
16
  ml_tools/data_exploration.py,sha256=M7bn2q5XN9zJZJGAmMMFSFFZh8LGzC2arFelrXw3N6Q,25241
17
17
  ml_tools/datasetmaster.py,sha256=S3PKHNQZ9cyAOck8xQltVLZhaD1gFLfgHFL-aRjz4JU,30077
18
- ml_tools/ensemble_learning.py,sha256=CDSIygnHaNe92aJ46Fofevd7q6lowTnE98yWuIV3Y6w,37462
18
+ ml_tools/ensemble_learning.py,sha256=p9PZwGY2OGSrJhXNzvMS_kCjK-I2JVcqiJBaVzb0GrM,42616
19
19
  ml_tools/handle_excel.py,sha256=lwds7rDLlGSCWiWGI7xNg-Z7kxAepogp0lstSFa0590,12949
20
20
  ml_tools/logger.py,sha256=UkbiU9ihBhw9VKyn3rZzisdClWV94EBV6B09_D0iUU0,6026
21
- ml_tools/utilities.py,sha256=ghEOhN5-eozgfDjJ0r8qOBlnDYahn3jaYzfnitL-GDU,31375
22
- dragon_ml_toolbox-3.8.0.dist-info/METADATA,sha256=FBhxslY5Lx2HlauipzYsoPovFSdGqlYjgaN0oRVxfLk,3273
23
- dragon_ml_toolbox-3.8.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
24
- dragon_ml_toolbox-3.8.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
25
- dragon_ml_toolbox-3.8.0.dist-info/RECORD,,
21
+ ml_tools/path_manager.py,sha256=OCpESgdftbi6mOxetDMIaHhazt4N-W8pJx11X3-yNOs,8305
22
+ ml_tools/utilities.py,sha256=HR36Q_vYnaRcpSjpNISnA7lOZ36TouHop38lPLG_twY,23146
23
+ dragon_ml_toolbox-3.9.1.dist-info/METADATA,sha256=u0yfIE9prmFgf-ZVkXsGj67w_B9FiK6zZS6voAfWHAg,3273
24
+ dragon_ml_toolbox-3.9.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
25
+ dragon_ml_toolbox-3.9.1.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
26
+ dragon_ml_toolbox-3.9.1.dist-info/RECORD,,
ml_tools/GUI_tools.py CHANGED
@@ -1,20 +1,20 @@
1
1
  import configparser
2
2
  from pathlib import Path
3
- from typing import Optional, Callable, Any
4
3
  import traceback
5
4
  import FreeSimpleGUI as sg
6
5
  from functools import wraps
7
- from typing import Any, Dict, Tuple, List, Literal
6
+ from typing import Any, Dict, Tuple, List, Literal, Union, Any, Optional
8
7
  from .utilities import _script_info
9
8
  import numpy as np
10
9
  from .logger import _LOGGER
10
+ from abc import ABC, abstractmethod
11
11
 
12
12
 
13
13
  __all__ = [
14
14
  "ConfigManager",
15
15
  "GUIFactory",
16
16
  "catch_exceptions",
17
- "prepare_feature_vector",
17
+ "BaseFeatureHandler",
18
18
  "update_target_fields"
19
19
  ]
20
20
 
@@ -184,7 +184,7 @@ class GUIFactory:
184
184
  }
185
185
  return sg.Button(text.title(), key=key, **style_args)
186
186
 
187
- def make_frame(self, title: str, layout: List[List[sg.Element]], **kwargs) -> sg.Frame:
187
+ def make_frame(self, title: str, layout: List[List[Union[sg.Element, sg.Column]]], **kwargs) -> sg.Frame:
188
188
  """
189
189
  Creates a styled frame around a given layout.
190
190
 
@@ -208,7 +208,7 @@ class GUIFactory:
208
208
  # --- General-Purpose Layout Generators ---
209
209
  def generate_continuous_layout(
210
210
  self,
211
- data_dict: Dict[str, Tuple[float, float]],
211
+ data_dict: Dict[str, Optional[Tuple[Union[int,float], Union[int,float]]]],
212
212
  is_target: bool = False,
213
213
  layout_mode: Literal["grid", "row"] = 'grid',
214
214
  features_per_column: int = 4
@@ -230,7 +230,13 @@ class GUIFactory:
230
230
  label_font = (cfg.fonts.font_family, cfg.fonts.label_size, cfg.fonts.label_style) # type: ignore
231
231
 
232
232
  columns = []
233
- for name, (val_min, val_max) in data_dict.items():
233
+ for name, value in data_dict.items():
234
+ if value is None:
235
+ val_min, val_max = None, None
236
+ if not is_target:
237
+ raise ValueError(f"Feature '{name}' was assigned a 'None' value. It is not defined as a target.")
238
+ else:
239
+ val_min, val_max = value
234
240
  key = name
235
241
  default_text = "" if is_target else str(val_max)
236
242
 
@@ -247,7 +253,7 @@ class GUIFactory:
247
253
  layout = [[label], [element]]
248
254
  else:
249
255
  range_font = (cfg.fonts.font_family, cfg.fonts.range_size) # type: ignore
250
- range_text = sg.Text(f"Range: {int(val_min)}-{int(val_max)}", font=range_font, background_color=bg_color)
256
+ range_text = sg.Text(f"Range: {int(val_min)}-{int(val_max)}", font=range_font, background_color=bg_color) # type: ignore
251
257
  layout = [[label], [element], [range_text]]
252
258
 
253
259
  # each feature is wrapped as a column element
@@ -351,68 +357,93 @@ def catch_exceptions(show_popup: bool = True):
351
357
  return decorator
352
358
 
353
359
 
354
- # --- Inference Helpers ---
355
- def _default_categorical_processor(feature_name: str, chosen_value: Any) -> List[float]:
356
- """
357
- Default processor for binary 'True'/'False' strings.
358
- Returns a list containing a single float.
360
+ # --- Inference Helper ---
361
+ class BaseFeatureHandler(ABC):
359
362
  """
360
- return [1.0] if str(chosen_value) == 'True' else [0.0]
361
-
362
- def prepare_feature_vector(
363
- window_values: Dict[str, Any],
364
- gui_feature_order: List[str],
365
- continuous_features: List[str],
366
- categorical_features: List[str],
367
- categorical_processor: Optional[Callable[[str, Any], List[float]]] = None
368
- ) -> np.ndarray:
369
- """
370
- Validates and converts GUI values into a numpy array for a model.
371
- This function supports label encoding and one-hot encoding via the processor.
363
+ An abstract base class that defines the template for preparing a model input feature vector to perform inference, from GUI inputs.
372
364
 
373
- Args:
374
- window_values (dict): The values dictionary from a `window.read()` call.
375
- gui_feature_order (list): A list of all feature names that have a GUI element.
376
- For one-hot encoding, this should be the name of the
377
- single GUI element (e.g., 'material_type'), not the
378
- expanded feature names (e.g., 'material_is_steel').
379
- continuous_features (list): A list of names for continuous features.
380
- categorical_features (list): A list of names for categorical features.
381
- categorical_processor (callable, optional): A function to process categorical
382
- values. It should accept (feature_name, chosen_value) and return a
383
- list of floats (e.g., [1.0] for label encoding, [0.0, 1.0, 0.0] for one-hot).
384
- If None, a default 'True'/'False' processor is used.
385
-
386
- Returns:
387
- A 1D numpy array ready for model inference.
365
+ A subclass must implement the `gui_input_map` property and the `process_categorical` method.
388
366
  """
389
- processed_values: List[float] = []
390
-
391
- # Use the provided processor or the default one
392
- processor = categorical_processor or _default_categorical_processor
393
-
394
- # Create sets for faster lookups
395
- cont_set = set(continuous_features)
396
- cat_set = set(categorical_features)
397
-
398
- for name in gui_feature_order:
399
- chosen_value = window_values.get(name)
367
+ def __init__(self, expected_columns_in_order: list[str]):
368
+ """
369
+ Validates and stores the feature names in the order the model expects.
400
370
 
401
- if chosen_value is None or chosen_value == '':
402
- raise ValueError(f"Feature '{name}' is missing a value.")
371
+ Args:
372
+ expected_columns_in_order (List[str]): A list of strings with the feature names in the correct order.
373
+ """
374
+ # --- Validation Logic ---
375
+ if not isinstance(expected_columns_in_order, list):
376
+ raise TypeError("Input 'expected_columns_in_order' must be a list.")
377
+
378
+ if not all(isinstance(col, str) for col in expected_columns_in_order):
379
+ raise TypeError("All elements in the 'expected_columns_in_order' list must be strings.")
380
+ # -----------------------
381
+
382
+ self._model_feature_order = expected_columns_in_order
383
+
384
+ @property
385
+ @abstractmethod
386
+ def gui_input_map(self) -> Dict[str, Literal["continuous","categorical"]]:
387
+ """
388
+ Must be implemented by the subclass.
403
389
 
404
- if name in cont_set:
405
- try:
406
- processed_values.append(float(chosen_value))
407
- except (ValueError, TypeError):
408
- raise ValueError(f"Invalid input for '{name}'. Please enter a valid number.")
390
+ Should return a dictionary mapping each GUI input name to its type ('continuous' or 'categorical').
409
391
 
410
- elif name in cat_set:
411
- # The processor returns a list of values (one for label, multiple for one-hot)
412
- numeric_values = processor(name, chosen_value)
413
- processed_values.extend(numeric_values)
392
+ ```python
393
+ #Example:
394
+ {'temperature': 'continuous', 'material_type': 'categorical'}
395
+ ```
396
+ """
397
+ pass
398
+
399
+ @abstractmethod
400
+ def process_categorical(self, feature_name: str, chosen_value: Any) -> Dict[str, float]:
401
+ """
402
+ Must be implemented by the subclass.
403
+
404
+ Should take a GUI categorical feature name and its chosen value, and return a dictionary mapping the one-hot-encoded feature names to their
405
+ float values (as expected by the inference model).
406
+ """
407
+ pass
408
+
409
+ def __call__(self, window_values: Dict[str, Any]) -> np.ndarray:
410
+ """
411
+ Performs the full vector preparation, returning a 1D numpy array.
412
+
413
+ Should not be overridden by subclasses.
414
+ """
415
+ # Stage 1: Process GUI inputs into a dictionary
416
+ processed_features: Dict[str, float] = {}
417
+ for gui_name, feature_type in self.gui_input_map.items():
418
+ chosen_value = window_values.get(gui_name)
419
+
420
+ if chosen_value is None or str(chosen_value) == '':
421
+ raise ValueError(f"GUI input '{gui_name}' is missing a value.")
422
+
423
+ if feature_type == 'continuous':
424
+ try:
425
+ processed_features[gui_name] = float(chosen_value)
426
+ except (ValueError, TypeError):
427
+ raise ValueError(f"Invalid number '{chosen_value}' for '{gui_name}'.")
428
+
429
+ elif feature_type == 'categorical':
430
+ feature_dict = self.process_categorical(gui_name, chosen_value)
431
+ processed_features.update(feature_dict)
432
+
433
+ # Stage 2: Assemble the final vector using the model's required order
434
+ final_vector: List[float] = []
435
+
436
+ try:
437
+ for feature_name in self._model_feature_order:
438
+ final_vector.append(processed_features[feature_name])
439
+ except KeyError as e:
440
+ raise RuntimeError(
441
+ f"Configuration Error: Implemented methods failed to generate "
442
+ f"the required model feature: '{e}'"
443
+ f"Check the gui_input_map and process_categorical logic."
444
+ )
414
445
 
415
- return np.array(processed_values, dtype=np.float32)
446
+ return np.array(final_vector, dtype=np.float32)
416
447
 
417
448
 
418
449
  def update_target_fields(window: sg.Window, results_dict: Dict[str, Any]):
@@ -6,7 +6,7 @@ from matplotlib.colors import Colormap
6
6
  from matplotlib import rcdefaults
7
7
 
8
8
  from pathlib import Path
9
- from typing import Literal, Union, Optional, Iterator, Tuple
9
+ from typing import Literal, Union, Optional, Iterator, Tuple, Dict, Any, List
10
10
 
11
11
  from imblearn.over_sampling import ADASYN, SMOTE, RandomOverSampler
12
12
  from imblearn.under_sampling import RandomUnderSampler
@@ -19,7 +19,7 @@ from sklearn.model_selection import train_test_split
19
19
  from sklearn.metrics import accuracy_score, classification_report, ConfusionMatrixDisplay, mean_absolute_error, mean_squared_error, r2_score, roc_curve, roc_auc_score
20
20
  import shap
21
21
 
22
- from .utilities import yield_dataframes_from_dir, sanitize_filename, _script_info, serialize_object, make_fullpath
22
+ from .utilities import yield_dataframes_from_dir, sanitize_filename, _script_info, serialize_object, make_fullpath, list_files_by_extension, deserialize_object
23
23
  from .logger import _LOGGER
24
24
 
25
25
  import warnings # Ignore warnings
@@ -38,7 +38,8 @@ __all__ = [
38
38
  "evaluate_model_regression",
39
39
  "get_shap_values",
40
40
  "train_test_pipeline",
41
- "run_ensemble_pipeline"
41
+ "run_ensemble_pipeline",
42
+ "InferenceHandler"
42
43
  ]
43
44
 
44
45
  ## Type aliases
@@ -937,5 +938,124 @@ def run_ensemble_pipeline(datasets_dir: Union[str,Path], save_dir: Union[str,Pat
937
938
  _LOGGER.info("✅ Training and evaluation complete.")
938
939
 
939
940
 
941
+ ###### 6. Inference ######
942
+ class InferenceHandler:
943
+ """
944
+ Handles loading ensemble models and performing inference for either regression or classification tasks.
945
+ """
946
+ def __init__(self,
947
+ models_dir: Union[str,Path],
948
+ task: TaskType,
949
+ verbose: bool=True) -> None:
950
+ """
951
+ Initializes the handler by loading all models from a directory.
952
+
953
+ Args:
954
+ models_dir (Path): The directory containing the saved .joblib model files.
955
+ task ("regression" | "classification"): The type of task the models perform.
956
+ """
957
+ self.models: Dict[str, Any] = dict()
958
+ self.task: str = task
959
+ self.verbose = verbose
960
+ self._feature_names: Optional[List[str]] = None
961
+
962
+ model_files = list_files_by_extension(directory=models_dir, extension="joblib")
963
+
964
+ for fname, fpath in model_files.items():
965
+ try:
966
+ full_object: dict
967
+ full_object = deserialize_object(filepath=fpath,
968
+ verbose=self.verbose,
969
+ raise_on_error=True) # type: ignore
970
+
971
+ model: Any = full_object["model"]
972
+ target_name: str = full_object["target_name"]
973
+ feature_names_list: List[str] = full_object["feature_names"]
974
+
975
+ # Check that feature names match
976
+ if self._feature_names is None:
977
+ # Store the feature names from the first model loaded.
978
+ self._feature_names = feature_names_list
979
+ elif self._feature_names != feature_names_list:
980
+ # Add a warning if subsequent models have different feature names.
981
+ _LOGGER.warning(f"⚠️ Mismatched feature names in {fname}. Using feature order from the first model loaded.")
982
+
983
+ self.models[target_name] = model
984
+ if self.verbose:
985
+ _LOGGER.info(f"✅ Loaded model for target: {target_name}")
986
+
987
+ except Exception as e:
988
+ _LOGGER.warning(f"⚠️ Failed to load or parse {fname}: {e}")
989
+
990
+ @property
991
+ def feature_names(self) -> List[str]:
992
+ """
993
+ Getter for the list of feature names the models expect.
994
+ Returns an empty list if no models were loaded.
995
+ """
996
+ return self._feature_names if self._feature_names is not None else []
997
+
998
+ def predict(self, features: np.ndarray) -> Dict[str, Any]:
999
+ """
1000
+ Predicts on a single feature vector.
1001
+
1002
+ Args:
1003
+ features (np.ndarray): A 1D or 2D NumPy array for a single sample.
1004
+
1005
+ Returns:
1006
+ Dict[str, Any]: A dictionary where keys are target names.
1007
+ - For regression: The value is the single predicted float.
1008
+ - For classification: The value is another dictionary {'label': ..., 'probabilities': ...}.
1009
+ """
1010
+ if features.ndim == 1:
1011
+ features = features.reshape(1, -1)
1012
+
1013
+ if features.shape[0] != 1:
1014
+ raise ValueError("The predict() method is for a single sample. Use predict_batch() for multiple samples.")
1015
+
1016
+ results: Dict[str, Any] = dict()
1017
+ for target_name, model in self.models.items():
1018
+ if self.task == "regression":
1019
+ prediction = model.predict(features)
1020
+ results[target_name] = prediction.item()
1021
+ else: # Classification
1022
+ label = model.predict(features)[0]
1023
+ probabilities = model.predict_proba(features)[0]
1024
+ results[target_name] = {"label": label, "probabilities": probabilities}
1025
+
1026
+ if self.verbose:
1027
+ _LOGGER.info("✅ Inference process complete.")
1028
+ return results
1029
+
1030
+ def predict_batch(self, features: np.ndarray) -> Dict[str, Any]:
1031
+ """
1032
+ Predicts on a batch of feature vectors.
1033
+
1034
+ Args:
1035
+ features (np.ndarray): A 2D NumPy array where each row is a sample.
1036
+
1037
+ Returns:
1038
+ Dict[str, Any]: A dictionary where keys are target names.
1039
+ - For regression: The value is a NumPy array of predictions.
1040
+ - For classification: The value is another dictionary {'labels': ..., 'probabilities': ...}.
1041
+ """
1042
+ if features.ndim != 2:
1043
+ raise ValueError("Input for batch prediction must be a 2D array.")
1044
+
1045
+ results: Dict[str, Any] = dict()
1046
+ for target_name, model in self.models.items():
1047
+ if self.task == "regression":
1048
+ results[target_name] = model.predict(features)
1049
+ else: # Classification
1050
+ labels = model.predict(features)
1051
+ probabilities = model.predict_proba(features)
1052
+ results[target_name] = {"labels": labels, "probabilities": probabilities}
1053
+
1054
+ if self.verbose:
1055
+ _LOGGER.info("✅ Inference process complete.")
1056
+
1057
+ return results
1058
+
1059
+
940
1060
  def info():
941
1061
  _script_info(__all__)
@@ -0,0 +1,212 @@
1
+ from pprint import pprint
2
+ from typing import Optional, List, Dict, Callable, Union
3
+ from pathlib import Path
4
+ from .utilities import _script_info
5
+ from .logger import _LOGGER
6
+
7
+
8
+ __all__ = [
9
+ "PathManager"
10
+ ]
11
+
12
+
13
+ class PathManager:
14
+ """
15
+ Manages and stores a project's file paths, acting as a centralized
16
+ "path database". It supports both development mode and applications
17
+ bundled with Briefcase.
18
+
19
+ Supports python dictionary syntax.
20
+ """
21
+ def __init__(
22
+ self,
23
+ anchor_file: str,
24
+ base_directories: Optional[List[str]] = None
25
+ ):
26
+ """
27
+ The initializer determines the project's root directory and can pre-register
28
+ a list of base directories relative to that root.
29
+
30
+ Args:
31
+ anchor_file (str): The absolute path to a file whose parent directory will be considered the package root and name. Typically, `__file__`.
32
+ base_directories (Optional[List[str]]): A list of directory names located at the same level as the anchor file to be registered immediately.
33
+ """
34
+ resolved_anchor_path = Path(anchor_file).resolve()
35
+ self._package_name = resolved_anchor_path.parent.name
36
+ self._is_bundled, self._resource_path_func = self._check_bundle_status()
37
+ self._paths: Dict[str, Path] = {}
38
+
39
+ if self._is_bundled:
40
+ # In a bundle, resource_path gives the absolute path to the 'app_packages' dir
41
+ # when given the package name.
42
+ package_root = self._resource_path_func(self._package_name) # type: ignore
43
+ else:
44
+ # In dev mode, the package root is the directory containing the anchor file.
45
+ package_root = resolved_anchor_path.parent
46
+
47
+ # Register the root of the package itself
48
+ self._paths["ROOT"] = package_root
49
+
50
+ # Register all the base directories
51
+ if base_directories:
52
+ for dir_name in base_directories:
53
+ # In dev mode, this is simple. In a bundle, we must resolve
54
+ # each path from the package root.
55
+ if self._is_bundled:
56
+ self._paths[dir_name] = self._resource_path_func(self._package_name, dir_name) # type: ignore
57
+ else:
58
+ self._paths[dir_name] = package_root / dir_name
59
+
60
+ # A helper function to find the briefcase-injected resource function
61
+ def _check_bundle_status(self) -> tuple[bool, Optional[Callable]]:
62
+ """Checks if the app is running in a Briefcase bundle."""
63
+ try:
64
+ # This function is injected by Briefcase into the global scope
65
+ from briefcase.platforms.base import resource_path # type: ignore
66
+ return True, resource_path
67
+ except (ImportError, NameError):
68
+ return False, None
69
+
70
+ def get(self, key: str) -> Path:
71
+ """
72
+ Retrieves a stored path by its key.
73
+
74
+ Args:
75
+ key (str): The key of the path to retrieve.
76
+
77
+ Returns:
78
+ Path: The resolved, absolute Path object.
79
+
80
+ Raises:
81
+ KeyError: If the key is not found in the manager.
82
+ """
83
+ try:
84
+ return self._paths[key]
85
+ except KeyError:
86
+ _LOGGER.error(f"❌ Path key '{key}' not found.")
87
+ raise
88
+
89
+ def update(self, new_paths: Dict[str, Union[str, Path]], overwrite: bool = False) -> None:
90
+ """
91
+ Adds new paths or overwrites existing ones in the manager.
92
+
93
+ Args:
94
+ new_paths (Dict[str, Union[str, Path]]): A dictionary where keys are
95
+ the identifiers and values are the
96
+ Path objects or strings to store.
97
+ overwrite (bool): If False (default), raises a KeyError if any
98
+ key in new_paths already exists. If True,
99
+ allows overwriting existing keys.
100
+ """
101
+ if not overwrite:
102
+ for key in new_paths:
103
+ if key in self._paths:
104
+ raise KeyError(
105
+ f"Path key '{key}' already exists in the manager. To replace it, call update() with overwrite=True."
106
+ )
107
+
108
+ # Resolve any string paths to Path objects before storing
109
+ resolved_new_paths = {k: Path(v) for k, v in new_paths.items()}
110
+ self._paths.update(resolved_new_paths)
111
+
112
+ def make_dirs(self, keys: Optional[List[str]] = None, verbose: bool = False) -> None:
113
+ """
114
+ Creates directory structures for registered paths in writable locations.
115
+
116
+ This method identifies paths that are directories (no file suffix) and creates them on the filesystem.
117
+
118
+ In a bundled application, this method will NOT attempt to create directories inside the read-only app package, preventing crashes. It
119
+ will only operate on paths outside of the package (e.g., user data dirs).
120
+
121
+ Args:
122
+ keys (Optional[List[str]]): If provided, only the directories
123
+ corresponding to these keys will be
124
+ created. If None (default), all
125
+ registered directory paths are used.
126
+ verbose (bool): If True, prints a message for each action.
127
+ """
128
+ path_items = []
129
+ if keys:
130
+ for key in keys:
131
+ if key in self._paths:
132
+ path_items.append((key, self._paths[key]))
133
+ elif verbose:
134
+ _LOGGER.warning(f"⚠️ Key '{key}' not found in PathManager, skipping.")
135
+ else:
136
+ path_items = self._paths.items()
137
+
138
+ # Get the package root to check against.
139
+ package_root = self._paths.get("ROOT")
140
+
141
+ for key, path in path_items:
142
+ if path.suffix: # It's a file, not a directory
143
+ continue
144
+
145
+ # --- THE CRITICAL CHECK ---
146
+ # Determine if the path is inside the main application package.
147
+ is_internal_path = package_root and path.is_relative_to(package_root)
148
+
149
+ if self._is_bundled and is_internal_path:
150
+ if verbose:
151
+ _LOGGER.warning(f"⚠️ Skipping internal directory '{key}' in bundled app (read-only).")
152
+ continue
153
+ # -------------------------
154
+
155
+ if verbose:
156
+ _LOGGER.info(f"📁 Ensuring directory exists for key '{key}': {path}")
157
+
158
+ path.mkdir(parents=True, exist_ok=True)
159
+
160
+ def status(self) -> None:
161
+ """
162
+ Checks the status of all registered paths on the filesystem and prints a formatted report.
163
+ """
164
+ report = {}
165
+ for key, path in self.items():
166
+ if path.is_dir():
167
+ report[key] = "📁 Directory"
168
+ elif path.is_file():
169
+ report[key] = "📄 File"
170
+ else:
171
+ report[key] = "❌ Not Found"
172
+
173
+ print("\n--- Path Status Report ---")
174
+ pprint(report)
175
+
176
+ def __repr__(self) -> str:
177
+ """Provides a string representation of the stored paths."""
178
+ path_list = "\n".join(f" '{k}': '{v}'" for k, v in self._paths.items())
179
+ return f"PathManager(\n{path_list}\n)"
180
+
181
+ # --- Dictionary-Style Methods ---
182
+ def __getitem__(self, key: str) -> Path:
183
+ """Allows dictionary-style getting, e.g., PM['my_key']"""
184
+ return self.get(key)
185
+
186
+ def __setitem__(self, key: str, value: Union[str, Path]):
187
+ """Allows dictionary-style setting, does not allow overwriting, e.g., PM['my_key'] = path"""
188
+ self.update({key: value}, overwrite=False)
189
+
190
+ def __contains__(self, key: str) -> bool:
191
+ """Allows checking for a key's existence, e.g., if 'my_key' in PM"""
192
+ return key in self._paths
193
+
194
+ def __len__(self) -> int:
195
+ """Allows getting the number of paths, e.g., len(PM)"""
196
+ return len(self._paths)
197
+
198
+ def keys(self):
199
+ """Returns all registered path keys."""
200
+ return self._paths.keys()
201
+
202
+ def values(self):
203
+ """Returns all registered Path objects."""
204
+ return self._paths.values()
205
+
206
+ def items(self):
207
+ """Returns all registered (key, Path) pairs."""
208
+ return self._paths.items()
209
+
210
+
211
+ def info():
212
+ _script_info(__all__)
ml_tools/utilities.py CHANGED
@@ -4,10 +4,9 @@ import pandas as pd
4
4
  import polars as pl
5
5
  from pathlib import Path
6
6
  import re
7
- from typing import Literal, Union, Sequence, Optional, Any, Iterator, Tuple, Callable, List, Dict
7
+ from typing import Literal, Union, Sequence, Optional, Any, Iterator, Tuple
8
8
  import joblib
9
9
  from joblib.externals.loky.process_executor import TerminatedWorkerError
10
- from pprint import pprint
11
10
 
12
11
 
13
12
  # Keep track of available tools
@@ -27,7 +26,6 @@ __all__ = [
27
26
  "deserialize_object",
28
27
  "distribute_datasets_by_target",
29
28
  "train_dataset_orchestrator",
30
- "PathManager"
31
29
  ]
32
30
 
33
31
 
@@ -645,208 +643,6 @@ def train_dataset_orchestrator(list_of_dirs: list[Union[str,Path]],
645
643
  print(f"\n✅ {total_saved} single-target datasets were created.")
646
644
 
647
645
 
648
- ### Path Manager
649
- class PathManager:
650
- """
651
- Manages and stores a project's file paths, acting as a centralized
652
- "path database". It supports both development mode and applications
653
- bundled with Briefcase.
654
-
655
- Supports python dictionary syntax.
656
- """
657
- def __init__(
658
- self,
659
- anchor_file: str,
660
- base_directories: Optional[List[str]] = None
661
- ):
662
- """
663
- The initializer determines the project's root directory and can pre-register
664
- a list of base directories relative to that root.
665
-
666
- Args:
667
- anchor_file (str): The absolute path to a file whose parent directory will be considered the package root and name. Typically, `__file__`.
668
- base_directories (Optional[List[str]]): A list of directory names
669
- located at the same level as the anchor file's
670
- parent directory to register immediately.
671
- """
672
- resolved_anchor_path = Path(anchor_file).resolve()
673
- self._package_name = resolved_anchor_path.parent.name
674
- self._is_bundled, self._resource_path_func = self._check_bundle_status()
675
- self._paths: Dict[str, Path] = {}
676
-
677
- if self._is_bundled:
678
- # In a bundle, resource_path gives the absolute path to the 'app_packages' dir
679
- # when given the package name.
680
- package_root = self._resource_path_func(self._package_name) # type: ignore
681
- else:
682
- # In dev mode, the package root is the directory containing the anchor file.
683
- package_root = resolved_anchor_path.parent
684
-
685
- # Register the root of the package itself
686
- self._paths["ROOT"] = package_root
687
-
688
- # Register all the base directories
689
- if base_directories:
690
- for dir_name in base_directories:
691
- # In dev mode, this is simple. In a bundle, we must resolve
692
- # each path from the package root.
693
- if self._is_bundled:
694
- self._paths[dir_name] = self._resource_path_func(self._package_name, dir_name) # type: ignore
695
- else:
696
- self._paths[dir_name] = package_root / dir_name
697
-
698
- # A helper function to find the briefcase-injected resource function
699
- def _check_bundle_status(self) -> tuple[bool, Optional[Callable]]:
700
- """Checks if the app is running in a Briefcase bundle."""
701
- try:
702
- # This function is injected by Briefcase into the global scope
703
- from briefcase.platforms.base import resource_path # type: ignore
704
- return True, resource_path
705
- except (ImportError, NameError):
706
- return False, None
707
-
708
- def get(self, key: str) -> Path:
709
- """
710
- Retrieves a stored path by its key.
711
-
712
- Args:
713
- key (str): The key of the path to retrieve.
714
-
715
- Returns:
716
- Path: The resolved, absolute Path object.
717
-
718
- Raises:
719
- KeyError: If the key is not found in the manager.
720
- """
721
- try:
722
- return self._paths[key]
723
- except KeyError:
724
- print(f"❌ Path key '{key}' not found.")
725
- # Consider suggesting close matches if you want to get fancy
726
- raise
727
-
728
- def update(self, new_paths: Dict[str, Union[str, Path]], overwrite: bool = False) -> None:
729
- """
730
- Adds new paths or overwrites existing ones in the manager.
731
-
732
- Args:
733
- new_paths (Dict[str, Union[str, Path]]): A dictionary where keys are
734
- the identifiers and values are the
735
- Path objects or strings to store.
736
- overwrite (bool): If False (default), raises a KeyError if any
737
- key in new_paths already exists. If True,
738
- allows overwriting existing keys.
739
- """
740
- if not overwrite:
741
- for key in new_paths:
742
- if key in self._paths:
743
- raise KeyError(
744
- f"Path key '{key}' already exists in the manager. To replace it, call update() with overwrite=True."
745
- )
746
-
747
- # Resolve any string paths to Path objects before storing
748
- resolved_new_paths = {k: Path(v) for k, v in new_paths.items()}
749
- self._paths.update(resolved_new_paths)
750
-
751
- def make_dirs(self, keys: Optional[List[str]] = None, verbose: bool = False) -> None:
752
- """
753
- Creates directory structures for registered paths in writable locations.
754
-
755
- This method identifies paths that are directories (no file suffix) and creates them on the filesystem.
756
-
757
- In a bundled application, this method will NOT attempt to create directories inside the read-only app package, preventing crashes. It
758
- will only operate on paths outside of the package (e.g., user data dirs).
759
-
760
- Args:
761
- keys (Optional[List[str]]): If provided, only the directories
762
- corresponding to these keys will be
763
- created. If None (default), all
764
- registered directory paths are used.
765
- verbose (bool): If True, prints a message for each action.
766
- """
767
- path_items = []
768
- if keys:
769
- for key in keys:
770
- if key in self._paths:
771
- path_items.append((key, self._paths[key]))
772
- elif verbose:
773
- print(f"⚠️ Key '{key}' not found in PathManager, skipping.")
774
- else:
775
- path_items = self._paths.items()
776
-
777
- # Get the package root to check against.
778
- package_root = self._paths.get("ROOT")
779
-
780
- for key, path in path_items:
781
- if path.suffix: # It's a file, not a directory
782
- continue
783
-
784
- # --- THE CRITICAL CHECK ---
785
- # Determine if the path is inside the main application package.
786
- is_internal_path = package_root and path.is_relative_to(package_root)
787
-
788
- if self._is_bundled and is_internal_path:
789
- if verbose:
790
- print(f"ℹ️ Skipping internal directory '{key}' in bundled app (read-only).")
791
- continue
792
- # -------------------------
793
-
794
- if verbose:
795
- print(f"📁 Ensuring directory exists for key '{key}': {path}")
796
-
797
- path.mkdir(parents=True, exist_ok=True)
798
-
799
- def status(self) -> None:
800
- """
801
- Checks the status of all registered paths on the filesystem and prints a formatted report.
802
- """
803
- report = {}
804
- for key, path in self.items():
805
- if path.is_dir():
806
- report[key] = "📁 Directory"
807
- elif path.is_file():
808
- report[key] = "📄 File"
809
- else:
810
- report[key] = "❌ Not Found"
811
-
812
- print("\n--- Path Status Report ---")
813
- pprint(report)
814
-
815
- def __repr__(self) -> str:
816
- """Provides a string representation of the stored paths."""
817
- path_list = "\n".join(f" '{k}': '{v}'" for k, v in self._paths.items())
818
- return f"PathManager(\n{path_list}\n)"
819
-
820
- # --- Dictionary-Style Methods ---
821
- def __getitem__(self, key: str) -> Path:
822
- """Allows dictionary-style getting, e.g., PM['my_key']"""
823
- return self.get(key)
824
-
825
- def __setitem__(self, key: str, value: Union[str, Path]):
826
- """Allows dictionary-style setting, e.g., PM['my_key'] = path"""
827
- self.update({key: value}, overwrite=True)
828
-
829
- def __contains__(self, key: str) -> bool:
830
- """Allows checking for a key's existence, e.g., if 'my_key' in PM"""
831
- return key in self._paths
832
-
833
- def __len__(self) -> int:
834
- """Allows getting the number of paths, e.g., len(PM)"""
835
- return len(self._paths)
836
-
837
- def keys(self):
838
- """Returns all registered path keys."""
839
- return self._paths.keys()
840
-
841
- def values(self):
842
- """Returns all registered Path objects."""
843
- return self._paths.values()
844
-
845
- def items(self):
846
- """Returns all registered (key, Path) pairs."""
847
- return self._paths.items()
848
-
849
-
850
646
  class LogKeys:
851
647
  """
852
648
  Used internally for ML scripts.