dragon_ml_toolbox-3.7.0-py3-none-any.whl → dragon_ml_toolbox-3.9.0-py3-none-any.whl

This diff shows the contents of two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in the public registry.

Potentially problematic release.


This version of dragon-ml-toolbox may be problematic. Further details were linked from the original page; that link is not preserved in this text.

dragon_ml_toolbox-3.9.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dragon-ml-toolbox
3
- Version: 3.7.0
3
+ Version: 3.9.0
4
4
  Summary: A collection of tools for data science and machine learning projects.
5
5
  Author-email: Karl Loza <luigiloza@gmail.com>
6
6
  License-Expression: MIT
@@ -15,8 +15,8 @@ License-File: LICENSE-THIRD-PARTY.md
15
15
  Requires-Dist: numpy<2.0
16
16
  Requires-Dist: scikit-learn
17
17
  Requires-Dist: openpyxl
18
- Requires-Dist: miceforest<7.0.0,>=6.0.0
19
- Requires-Dist: plotnine<0.13,>=0.12
18
+ Requires-Dist: miceforest>=6.0.0
19
+ Requires-Dist: plotnine>=0.12
20
20
  Requires-Dist: matplotlib
21
21
  Requires-Dist: seaborn
22
22
  Requires-Dist: pandas
@@ -129,6 +129,7 @@ ML_callbacks
129
129
  ML_evaluation
130
130
  ML_trainer
131
131
  ML_tutorial
132
+ path_manager
132
133
  PSO_optimization
133
134
  RNN_forecast
134
135
  utilities
dragon_ml_toolbox-3.9.0.dist-info/RECORD CHANGED
@@ -1,7 +1,7 @@
1
- dragon_ml_toolbox-3.7.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
2
- dragon_ml_toolbox-3.7.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=6cfpIeQ6D4Mcs10nkogQrkVyq1T7i2qXjjNHFoUMOyE,1892
1
+ dragon_ml_toolbox-3.9.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
2
+ dragon_ml_toolbox-3.9.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=6cfpIeQ6D4Mcs10nkogQrkVyq1T7i2qXjjNHFoUMOyE,1892
3
3
  ml_tools/ETL_engineering.py,sha256=yeZsW_7zRvEcuMZbM4E2GV1dxwBoWIeJAcFFk2AK0fY,39502
4
- ml_tools/GUI_tools.py,sha256=3kRxok-QCN5S0q1i7yK137Bsr6c2N4M4nIvgPVAuZU0,20371
4
+ ml_tools/GUI_tools.py,sha256=ABR1cqV09iZ2DbLfLZB7jaQVRVDbvCmj09pNkr3TDZk,18800
5
5
  ml_tools/MICE_imputation.py,sha256=rYqvwQDVtoAJJ0agXWoGzoZEHedWiA6QzcEKEIkiZ08,11388
6
6
  ml_tools/ML_callbacks.py,sha256=OT2zwORLcn49megBEgXsSUxDHoW0Ft0_v7hLEVF3jHM,13063
7
7
  ml_tools/ML_evaluation.py,sha256=oiDV6HItQloUUKCUpltV-2pogubWLBieGpc-VUwosAQ,10106
@@ -15,11 +15,12 @@ ml_tools/_particle_swarm_optimization.py,sha256=b_eNNkA89Y40hj76KauivT8KLScH1B9w
15
15
  ml_tools/_pytorch_models.py,sha256=bpWZsrSwCvHJQkR6UfoPpElsMv9AvmiNErNHC8NYB_I,10132
16
16
  ml_tools/data_exploration.py,sha256=M7bn2q5XN9zJZJGAmMMFSFFZh8LGzC2arFelrXw3N6Q,25241
17
17
  ml_tools/datasetmaster.py,sha256=S3PKHNQZ9cyAOck8xQltVLZhaD1gFLfgHFL-aRjz4JU,30077
18
- ml_tools/ensemble_learning.py,sha256=CDSIygnHaNe92aJ46Fofevd7q6lowTnE98yWuIV3Y6w,37462
18
+ ml_tools/ensemble_learning.py,sha256=p9PZwGY2OGSrJhXNzvMS_kCjK-I2JVcqiJBaVzb0GrM,42616
19
19
  ml_tools/handle_excel.py,sha256=lwds7rDLlGSCWiWGI7xNg-Z7kxAepogp0lstSFa0590,12949
20
20
  ml_tools/logger.py,sha256=UkbiU9ihBhw9VKyn3rZzisdClWV94EBV6B09_D0iUU0,6026
21
- ml_tools/utilities.py,sha256=0w0vka0Aj9IYOHJ6crWIb6gwpQIJnPyj3v2_dnVxHrs,23138
22
- dragon_ml_toolbox-3.7.0.dist-info/METADATA,sha256=kvgFjd_BRwob7xycC5rbROCkq4C6FVq3J5-VdCXEPrI,3273
23
- dragon_ml_toolbox-3.7.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
24
- dragon_ml_toolbox-3.7.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
25
- dragon_ml_toolbox-3.7.0.dist-info/RECORD,,
21
+ ml_tools/path_manager.py,sha256=OCpESgdftbi6mOxetDMIaHhazt4N-W8pJx11X3-yNOs,8305
22
+ ml_tools/utilities.py,sha256=HR36Q_vYnaRcpSjpNISnA7lOZ36TouHop38lPLG_twY,23146
23
+ dragon_ml_toolbox-3.9.0.dist-info/METADATA,sha256=2R3xIuefuR9O_h71q3S49xUm2MLKQtn12jjwNFKl2mE,3273
24
+ dragon_ml_toolbox-3.9.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
25
+ dragon_ml_toolbox-3.9.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
26
+ dragon_ml_toolbox-3.9.0.dist-info/RECORD,,
ml_tools/GUI_tools.py CHANGED
@@ -4,83 +4,21 @@ from typing import Optional, Callable, Any
4
4
  import traceback
5
5
  import FreeSimpleGUI as sg
6
6
  from functools import wraps
7
- from typing import Any, Dict, Tuple, List
7
+ from typing import Any, Dict, Tuple, List, Literal
8
8
  from .utilities import _script_info
9
9
  import numpy as np
10
10
  from .logger import _LOGGER
11
+ from abc import ABC, abstractmethod
11
12
 
12
13
 
13
14
  __all__ = [
14
- "PathManager",
15
15
  "ConfigManager",
16
16
  "GUIFactory",
17
17
  "catch_exceptions",
18
- "prepare_feature_vector",
18
+ "BaseFeatureHandler",
19
19
  "update_target_fields"
20
20
  ]
21
21
 
22
-
23
- # --- Path Management ---
24
- class PathManager:
25
- """
26
- Manages paths for a Python application, supporting both development mode and bundled mode via Briefcase.
27
- """
28
- def __init__(self, anchor_file: str):
29
- """
30
- Initializes the PathManager. The package name is automatically inferred
31
- from the parent directory of the anchor file.
32
-
33
- Args:
34
- anchor_file (str): The absolute path to a file within the project's
35
- package, typically `__file__` from a module inside
36
- that package (paths.py).
37
-
38
- Note:
39
- This inference assumes that the anchor file's parent directory
40
- has the same name as the package (e.g., `.../src/my_app/paths.py`).
41
- This is a standard and recommended project structure.
42
- """
43
- resolved_anchor_path = Path(anchor_file).resolve()
44
- self.package_name = resolved_anchor_path.parent.name
45
- self._is_bundled, self._resource_path_func = self._check_bundle_status()
46
-
47
- if self._is_bundled:
48
- # In a Briefcase bundle, resource_path gives an absolute path
49
- # to the resource directory.
50
- self.package_root = self._resource_path_func(self.package_name, "") # type: ignore
51
- else:
52
- # In development mode, the package root is the directory
53
- # containing the anchor file.
54
- self.package_root = resolved_anchor_path.parent
55
-
56
- def _check_bundle_status(self) -> tuple[bool, Optional[Callable]]:
57
- """Checks if the app is running in a bundled environment."""
58
- try:
59
- # This is the function Briefcase provides in a bundled app
60
- from briefcase.platforms.base import resource_path # type: ignore
61
- return True, resource_path
62
- except ImportError:
63
- return False, None
64
-
65
- def get_path(self, relative_path: str | Path) -> Path:
66
- """
67
- Gets the absolute path for a given resource file or directory
68
- relative to the package root.
69
-
70
- Args:
71
- relative_path (str | Path): The path relative to the package root (e.g., 'helpers/icon.png').
72
-
73
- Returns:
74
- Path: The absolute path to the resource.
75
- """
76
- if self._is_bundled:
77
- # Briefcase's resource_path handles resolving the path within the app bundle
78
- return self._resource_path_func(self.package_name, str(relative_path)) # type: ignore
79
- else:
80
- # In dev mode, join package root with the relative path.
81
- return self.package_root / relative_path
82
-
83
-
84
22
  # --- Configuration Management ---
85
23
  class _SectionProxy:
86
24
  """A helper class to represent a section of the .ini file as an object."""
@@ -273,8 +211,8 @@ class GUIFactory:
273
211
  self,
274
212
  data_dict: Dict[str, Tuple[float, float]],
275
213
  is_target: bool = False,
276
- layout_mode: str = 'grid',
277
- columns_per_row: int = 4
214
+ layout_mode: Literal["grid", "row"] = 'grid',
215
+ features_per_column: int = 4
278
216
  ) -> List[List[sg.Column]]:
279
217
  """
280
218
  Generates a layout for continuous features or targets.
@@ -283,7 +221,7 @@ class GUIFactory:
283
221
  data_dict (dict): Keys are feature names, values are (min, max) tuples.
284
222
  is_target (bool): If True, creates disabled inputs for displaying results.
285
223
  layout_mode (str): 'grid' for a multi-row grid layout, or 'row' for a single horizontal row.
286
- columns_per_row (int): Number of feature columns per row when layout_mode is 'grid'.
224
+ features_per_column (int): Number of features per column when `layout_mode` is 'grid'.
287
225
 
288
226
  Returns:
289
227
  A list of lists of sg.Column elements, ready to be used in a window layout.
@@ -294,7 +232,7 @@ class GUIFactory:
294
232
 
295
233
  columns = []
296
234
  for name, (val_min, val_max) in data_dict.items():
297
- key = f"TARGET_{name}" if is_target else name
235
+ key = name
298
236
  default_text = "" if is_target else str(val_max)
299
237
 
300
238
  label = sg.Text(name, font=label_font, background_color=bg_color, key=f"_text_{name}")
@@ -313,6 +251,7 @@ class GUIFactory:
313
251
  range_text = sg.Text(f"Range: {int(val_min)}-{int(val_max)}", font=range_font, background_color=bg_color)
314
252
  layout = [[label], [element], [range_text]]
315
253
 
254
+ # each feature is wrapped as a column element
316
255
  layout.append([sg.Text(" ", font=(cfg.fonts.font_family, 2), background_color=bg_color)]) # type: ignore
317
256
  columns.append(sg.Column(layout, background_color=bg_color))
318
257
 
@@ -320,13 +259,13 @@ class GUIFactory:
320
259
  return [columns] # A single row containing all columns
321
260
 
322
261
  # Default to 'grid' layout
323
- return [columns[i:i + columns_per_row] for i in range(0, len(columns), columns_per_row)]
262
+ return [columns[i:i + features_per_column] for i in range(0, len(columns), features_per_column)]
324
263
 
325
264
  def generate_combo_layout(
326
265
  self,
327
266
  data_dict: Dict[str, List[Any]],
328
- layout_mode: str = 'grid',
329
- columns_per_row: int = 4
267
+ layout_mode: Literal["grid", "row"] = 'grid',
268
+ features_per_column: int = 4
330
269
  ) -> List[List[sg.Column]]:
331
270
  """
332
271
  Generates a layout for categorical or binary features using Combo boxes.
@@ -334,7 +273,7 @@ class GUIFactory:
334
273
  Args:
335
274
  data_dict (dict): Keys are feature names, values are lists of options.
336
275
  layout_mode (str): 'grid' for a multi-row grid layout, or 'row' for a single horizontal row.
337
- columns_per_row (int): Number of feature columns per row when layout_mode is 'grid'.
276
+ features_per_column (int): Number of features per column when `layout_mode` is 'grid'.
338
277
 
339
278
  Returns:
340
279
  A list of lists of sg.Column elements, ready to be used in a window layout.
@@ -352,13 +291,14 @@ class GUIFactory:
352
291
  )
353
292
  layout = [[label], [element]]
354
293
  layout.append([sg.Text(" ", font=(cfg.fonts.font_family, 2), background_color=bg_color)]) # type: ignore
294
+ # each feature is wrapped in a Column element
355
295
  columns.append(sg.Column(layout, background_color=bg_color))
356
296
 
357
297
  if layout_mode == 'row':
358
298
  return [columns] # A single row containing all columns
359
299
 
360
300
  # Default to 'grid' layout
361
- return [columns[i:i + columns_per_row] for i in range(0, len(columns), columns_per_row)]
301
+ return [columns[i:i + features_per_column] for i in range(0, len(columns), features_per_column)]
362
302
 
363
303
  # --- Window Creation ---
364
304
  def create_window(self, title: str, layout: List[List[sg.Element]], **kwargs) -> sg.Window:
@@ -412,68 +352,93 @@ def catch_exceptions(show_popup: bool = True):
412
352
  return decorator
413
353
 
414
354
 
415
- # --- Inference Helpers ---
416
- def _default_categorical_processor(feature_name: str, chosen_value: Any) -> List[float]:
417
- """
418
- Default processor for binary 'True'/'False' strings.
419
- Returns a list containing a single float.
420
- """
421
- return [1.0] if str(chosen_value) == 'True' else [0.0]
422
-
423
- def prepare_feature_vector(
424
- values: Dict[str, Any],
425
- feature_order: List[str],
426
- continuous_features: List[str],
427
- categorical_features: List[str],
428
- categorical_processor: Optional[Callable[[str, Any], List[float]]] = None
429
- ) -> np.ndarray:
355
+ # --- Inference Helper ---
356
+ class BaseFeatureHandler(ABC):
430
357
  """
431
- Validates and converts GUI values into a numpy array for a model.
432
- This function supports label encoding and one-hot encoding via the processor.
358
+ An abstract base class that defines the template for preparing a model input feature vector to perform inference, from GUI inputs.
433
359
 
434
- Args:
435
- values (dict): The values dictionary from a `window.read()` call.
436
- feature_order (list): A list of all feature names that have a GUI element.
437
- For one-hot encoding, this should be the name of the
438
- single GUI element (e.g., 'material_type'), not the
439
- expanded feature names (e.g., 'material_is_steel').
440
- continuous_features (list): A list of names for continuous features.
441
- categorical_features (list): A list of names for categorical features.
442
- categorical_processor (callable, optional): A function to process categorical
443
- values. It should accept (feature_name, chosen_value) and return a
444
- list of floats (e.g., [1.0] for label encoding, [0.0, 1.0, 0.0] for one-hot).
445
- If None, a default 'True'/'False' processor is used.
446
-
447
- Returns:
448
- A 1D numpy array ready for model inference.
360
+ A subclass must implement the `gui_input_map` property and the `process_categorical` method.
449
361
  """
450
- processed_values: List[float] = []
451
-
452
- # Use the provided processor or the default one
453
- processor = categorical_processor or _default_categorical_processor
454
-
455
- # Create sets for faster lookups
456
- cont_set = set(continuous_features)
457
- cat_set = set(categorical_features)
458
-
459
- for name in feature_order:
460
- chosen_value = values.get(name)
362
+ def __init__(self, expected_columns_in_order: list[str]):
363
+ """
364
+ Validates and stores the feature names in the order the model expects.
461
365
 
462
- if chosen_value is None or chosen_value == '':
463
- raise ValueError(f"Feature '{name}' is missing a value.")
366
+ Args:
367
+ expected_columns_in_order (List[str]): A list of strings with the feature names in the correct order.
368
+ """
369
+ # --- Validation Logic ---
370
+ if not isinstance(expected_columns_in_order, list):
371
+ raise TypeError("Input 'expected_columns_in_order' must be a list.")
372
+
373
+ if not all(isinstance(col, str) for col in expected_columns_in_order):
374
+ raise TypeError("All elements in the 'expected_columns_in_order' list must be strings.")
375
+ # -----------------------
376
+
377
+ self._model_feature_order = expected_columns_in_order
378
+
379
+ @property
380
+ @abstractmethod
381
+ def gui_input_map(self) -> Dict[str, Literal["continuous","categorical"]]:
382
+ """
383
+ Must be implemented by the subclass.
464
384
 
465
- if name in cont_set:
466
- try:
467
- processed_values.append(float(chosen_value))
468
- except (ValueError, TypeError):
469
- raise ValueError(f"Invalid input for '{name}'. Please enter a valid number.")
385
+ Should return a dictionary mapping each GUI input name to its type ('continuous' or 'categorical').
386
+
387
+ ```python
388
+ #Example:
389
+ {'temperature': 'continuous', 'material_type': 'categorical'}
390
+ ```
391
+ """
392
+ pass
393
+
394
+ @abstractmethod
395
+ def process_categorical(self, feature_name: str, chosen_value: Any) -> Dict[str, float]:
396
+ """
397
+ Must be implemented by the subclass.
398
+
399
+ Should take a GUI categorical feature name and its chosen value, and return a dictionary mapping the one-hot-encoded feature names to their
400
+ float values (as expected by the inference model).
401
+ """
402
+ pass
403
+
404
+ def __call__(self, window_values: Dict[str, Any]) -> np.ndarray:
405
+ """
406
+ Performs the full vector preparation, returning a 1D numpy array.
407
+
408
+ Should not be overridden by subclasses.
409
+ """
410
+ # Stage 1: Process GUI inputs into a dictionary
411
+ processed_features: Dict[str, float] = {}
412
+ for gui_name, feature_type in self.gui_input_map.items():
413
+ chosen_value = window_values.get(gui_name)
414
+
415
+ if chosen_value is None or str(chosen_value) == '':
416
+ raise ValueError(f"GUI input '{gui_name}' is missing a value.")
417
+
418
+ if feature_type == 'continuous':
419
+ try:
420
+ processed_features[gui_name] = float(chosen_value)
421
+ except (ValueError, TypeError):
422
+ raise ValueError(f"Invalid number '{chosen_value}' for '{gui_name}'.")
423
+
424
+ elif feature_type == 'categorical':
425
+ feature_dict = self.process_categorical(gui_name, chosen_value)
426
+ processed_features.update(feature_dict)
427
+
428
+ # Stage 2: Assemble the final vector using the model's required order
429
+ final_vector: List[float] = []
470
430
 
471
- elif name in cat_set:
472
- # The processor returns a list of values (one for label, multiple for one-hot)
473
- numeric_values = processor(name, chosen_value)
474
- processed_values.extend(numeric_values)
431
+ try:
432
+ for feature_name in self._model_feature_order:
433
+ final_vector.append(processed_features[feature_name])
434
+ except KeyError as e:
435
+ raise RuntimeError(
436
+ f"Configuration Error: Implemented methods failed to generate "
437
+ f"the required model feature: '{e}'"
438
+ f"Check the gui_input_map and process_categorical logic."
439
+ )
475
440
 
476
- return np.array(processed_values, dtype=np.float32)
441
+ return np.array(final_vector, dtype=np.float32)
477
442
 
478
443
 
479
444
  def update_target_fields(window: sg.Window, results_dict: Dict[str, Any]):
@@ -482,12 +447,12 @@ def update_target_fields(window: sg.Window, results_dict: Dict[str, Any]):
482
447
 
483
448
  Args:
484
449
  window (sg.Window): The application's window object.
485
- results_dict (dict): A dictionary where keys are target key names (including 'TARGET_' prefix if necessary) and values are the predicted results.
450
+ results_dict (dict): A dictionary where keys are target element-keys and values are the predicted results to update.
486
451
  """
487
452
  for target_name, result in results_dict.items():
488
453
  # Format numbers to 2 decimal places, leave other types as-is
489
454
  display_value = f"{result:.2f}" if isinstance(result, (int, float)) else result
490
- window[target_name].update(display_value)
455
+ window[target_name].update(display_value) # type: ignore
491
456
 
492
457
 
493
458
  def info():
ml_tools/ensemble_learning.py CHANGED
@@ -6,7 +6,7 @@ from matplotlib.colors import Colormap
6
6
  from matplotlib import rcdefaults
7
7
 
8
8
  from pathlib import Path
9
- from typing import Literal, Union, Optional, Iterator, Tuple
9
+ from typing import Literal, Union, Optional, Iterator, Tuple, Dict, Any, List
10
10
 
11
11
  from imblearn.over_sampling import ADASYN, SMOTE, RandomOverSampler
12
12
  from imblearn.under_sampling import RandomUnderSampler
@@ -19,7 +19,7 @@ from sklearn.model_selection import train_test_split
19
19
  from sklearn.metrics import accuracy_score, classification_report, ConfusionMatrixDisplay, mean_absolute_error, mean_squared_error, r2_score, roc_curve, roc_auc_score
20
20
  import shap
21
21
 
22
- from .utilities import yield_dataframes_from_dir, sanitize_filename, _script_info, serialize_object, make_fullpath
22
+ from .utilities import yield_dataframes_from_dir, sanitize_filename, _script_info, serialize_object, make_fullpath, list_files_by_extension, deserialize_object
23
23
  from .logger import _LOGGER
24
24
 
25
25
  import warnings # Ignore warnings
@@ -38,7 +38,8 @@ __all__ = [
38
38
  "evaluate_model_regression",
39
39
  "get_shap_values",
40
40
  "train_test_pipeline",
41
- "run_ensemble_pipeline"
41
+ "run_ensemble_pipeline",
42
+ "InferenceHandler"
42
43
  ]
43
44
 
44
45
  ## Type aliases
@@ -937,5 +938,124 @@ def run_ensemble_pipeline(datasets_dir: Union[str,Path], save_dir: Union[str,Pat
937
938
  _LOGGER.info("✅ Training and evaluation complete.")
938
939
 
939
940
 
941
+ ###### 6. Inference ######
942
+ class InferenceHandler:
943
+ """
944
+ Handles loading ensemble models and performing inference for either regression or classification tasks.
945
+ """
946
+ def __init__(self,
947
+ models_dir: Union[str,Path],
948
+ task: TaskType,
949
+ verbose: bool=True) -> None:
950
+ """
951
+ Initializes the handler by loading all models from a directory.
952
+
953
+ Args:
954
+ models_dir (Path): The directory containing the saved .joblib model files.
955
+ task ("regression" | "classification"): The type of task the models perform.
956
+ """
957
+ self.models: Dict[str, Any] = dict()
958
+ self.task: str = task
959
+ self.verbose = verbose
960
+ self._feature_names: Optional[List[str]] = None
961
+
962
+ model_files = list_files_by_extension(directory=models_dir, extension="joblib")
963
+
964
+ for fname, fpath in model_files.items():
965
+ try:
966
+ full_object: dict
967
+ full_object = deserialize_object(filepath=fpath,
968
+ verbose=self.verbose,
969
+ raise_on_error=True) # type: ignore
970
+
971
+ model: Any = full_object["model"]
972
+ target_name: str = full_object["target_name"]
973
+ feature_names_list: List[str] = full_object["feature_names"]
974
+
975
+ # Check that feature names match
976
+ if self._feature_names is None:
977
+ # Store the feature names from the first model loaded.
978
+ self._feature_names = feature_names_list
979
+ elif self._feature_names != feature_names_list:
980
+ # Add a warning if subsequent models have different feature names.
981
+ _LOGGER.warning(f"⚠️ Mismatched feature names in {fname}. Using feature order from the first model loaded.")
982
+
983
+ self.models[target_name] = model
984
+ if self.verbose:
985
+ _LOGGER.info(f"✅ Loaded model for target: {target_name}")
986
+
987
+ except Exception as e:
988
+ _LOGGER.warning(f"⚠️ Failed to load or parse {fname}: {e}")
989
+
990
+ @property
991
+ def feature_names(self) -> List[str]:
992
+ """
993
+ Getter for the list of feature names the models expect.
994
+ Returns an empty list if no models were loaded.
995
+ """
996
+ return self._feature_names if self._feature_names is not None else []
997
+
998
+ def predict(self, features: np.ndarray) -> Dict[str, Any]:
999
+ """
1000
+ Predicts on a single feature vector.
1001
+
1002
+ Args:
1003
+ features (np.ndarray): A 1D or 2D NumPy array for a single sample.
1004
+
1005
+ Returns:
1006
+ Dict[str, Any]: A dictionary where keys are target names.
1007
+ - For regression: The value is the single predicted float.
1008
+ - For classification: The value is another dictionary {'label': ..., 'probabilities': ...}.
1009
+ """
1010
+ if features.ndim == 1:
1011
+ features = features.reshape(1, -1)
1012
+
1013
+ if features.shape[0] != 1:
1014
+ raise ValueError("The predict() method is for a single sample. Use predict_batch() for multiple samples.")
1015
+
1016
+ results: Dict[str, Any] = dict()
1017
+ for target_name, model in self.models.items():
1018
+ if self.task == "regression":
1019
+ prediction = model.predict(features)
1020
+ results[target_name] = prediction.item()
1021
+ else: # Classification
1022
+ label = model.predict(features)[0]
1023
+ probabilities = model.predict_proba(features)[0]
1024
+ results[target_name] = {"label": label, "probabilities": probabilities}
1025
+
1026
+ if self.verbose:
1027
+ _LOGGER.info("✅ Inference process complete.")
1028
+ return results
1029
+
1030
+ def predict_batch(self, features: np.ndarray) -> Dict[str, Any]:
1031
+ """
1032
+ Predicts on a batch of feature vectors.
1033
+
1034
+ Args:
1035
+ features (np.ndarray): A 2D NumPy array where each row is a sample.
1036
+
1037
+ Returns:
1038
+ Dict[str, Any]: A dictionary where keys are target names.
1039
+ - For regression: The value is a NumPy array of predictions.
1040
+ - For classification: The value is another dictionary {'labels': ..., 'probabilities': ...}.
1041
+ """
1042
+ if features.ndim != 2:
1043
+ raise ValueError("Input for batch prediction must be a 2D array.")
1044
+
1045
+ results: Dict[str, Any] = dict()
1046
+ for target_name, model in self.models.items():
1047
+ if self.task == "regression":
1048
+ results[target_name] = model.predict(features)
1049
+ else: # Classification
1050
+ labels = model.predict(features)
1051
+ probabilities = model.predict_proba(features)
1052
+ results[target_name] = {"labels": labels, "probabilities": probabilities}
1053
+
1054
+ if self.verbose:
1055
+ _LOGGER.info("✅ Inference process complete.")
1056
+
1057
+ return results
1058
+
1059
+
940
1060
  def info():
941
1061
  _script_info(__all__)
ml_tools/path_manager.py ADDED
@@ -0,0 +1,212 @@
1
+ from pprint import pprint
2
+ from typing import Optional, List, Dict, Callable, Union
3
+ from pathlib import Path
4
+ from .utilities import _script_info
5
+ from .logger import _LOGGER
6
+
7
+
8
+ __all__ = [
9
+ "PathManager"
10
+ ]
11
+
12
+
13
+ class PathManager:
14
+ """
15
+ Manages and stores a project's file paths, acting as a centralized
16
+ "path database". It supports both development mode and applications
17
+ bundled with Briefcase.
18
+
19
+ Supports python dictionary syntax.
20
+ """
21
+ def __init__(
22
+ self,
23
+ anchor_file: str,
24
+ base_directories: Optional[List[str]] = None
25
+ ):
26
+ """
27
+ The initializer determines the project's root directory and can pre-register
28
+ a list of base directories relative to that root.
29
+
30
+ Args:
31
+ anchor_file (str): The absolute path to a file whose parent directory will be considered the package root and name. Typically, `__file__`.
32
+ base_directories (Optional[List[str]]): A list of directory names located at the same level as the anchor file to be registered immediately.
33
+ """
34
+ resolved_anchor_path = Path(anchor_file).resolve()
35
+ self._package_name = resolved_anchor_path.parent.name
36
+ self._is_bundled, self._resource_path_func = self._check_bundle_status()
37
+ self._paths: Dict[str, Path] = {}
38
+
39
+ if self._is_bundled:
40
+ # In a bundle, resource_path gives the absolute path to the 'app_packages' dir
41
+ # when given the package name.
42
+ package_root = self._resource_path_func(self._package_name) # type: ignore
43
+ else:
44
+ # In dev mode, the package root is the directory containing the anchor file.
45
+ package_root = resolved_anchor_path.parent
46
+
47
+ # Register the root of the package itself
48
+ self._paths["ROOT"] = package_root
49
+
50
+ # Register all the base directories
51
+ if base_directories:
52
+ for dir_name in base_directories:
53
+ # In dev mode, this is simple. In a bundle, we must resolve
54
+ # each path from the package root.
55
+ if self._is_bundled:
56
+ self._paths[dir_name] = self._resource_path_func(self._package_name, dir_name) # type: ignore
57
+ else:
58
+ self._paths[dir_name] = package_root / dir_name
59
+
60
+ # A helper function to find the briefcase-injected resource function
61
+ def _check_bundle_status(self) -> tuple[bool, Optional[Callable]]:
62
+ """Checks if the app is running in a Briefcase bundle."""
63
+ try:
64
+ # This function is injected by Briefcase into the global scope
65
+ from briefcase.platforms.base import resource_path # type: ignore
66
+ return True, resource_path
67
+ except (ImportError, NameError):
68
+ return False, None
69
+
70
+ def get(self, key: str) -> Path:
71
+ """
72
+ Retrieves a stored path by its key.
73
+
74
+ Args:
75
+ key (str): The key of the path to retrieve.
76
+
77
+ Returns:
78
+ Path: The resolved, absolute Path object.
79
+
80
+ Raises:
81
+ KeyError: If the key is not found in the manager.
82
+ """
83
+ try:
84
+ return self._paths[key]
85
+ except KeyError:
86
+ _LOGGER.error(f"❌ Path key '{key}' not found.")
87
+ raise
88
+
89
+ def update(self, new_paths: Dict[str, Union[str, Path]], overwrite: bool = False) -> None:
90
+ """
91
+ Adds new paths or overwrites existing ones in the manager.
92
+
93
+ Args:
94
+ new_paths (Dict[str, Union[str, Path]]): A dictionary where keys are
95
+ the identifiers and values are the
96
+ Path objects or strings to store.
97
+ overwrite (bool): If False (default), raises a KeyError if any
98
+ key in new_paths already exists. If True,
99
+ allows overwriting existing keys.
100
+ """
101
+ if not overwrite:
102
+ for key in new_paths:
103
+ if key in self._paths:
104
+ raise KeyError(
105
+ f"Path key '{key}' already exists in the manager. To replace it, call update() with overwrite=True."
106
+ )
107
+
108
+ # Resolve any string paths to Path objects before storing
109
+ resolved_new_paths = {k: Path(v) for k, v in new_paths.items()}
110
+ self._paths.update(resolved_new_paths)
111
+
112
+ def make_dirs(self, keys: Optional[List[str]] = None, verbose: bool = False) -> None:
113
+ """
114
+ Creates directory structures for registered paths in writable locations.
115
+
116
+ This method identifies paths that are directories (no file suffix) and creates them on the filesystem.
117
+
118
+ In a bundled application, this method will NOT attempt to create directories inside the read-only app package, preventing crashes. It
119
+ will only operate on paths outside of the package (e.g., user data dirs).
120
+
121
+ Args:
122
+ keys (Optional[List[str]]): If provided, only the directories
123
+ corresponding to these keys will be
124
+ created. If None (default), all
125
+ registered directory paths are used.
126
+ verbose (bool): If True, prints a message for each action.
127
+ """
128
+ path_items = []
129
+ if keys:
130
+ for key in keys:
131
+ if key in self._paths:
132
+ path_items.append((key, self._paths[key]))
133
+ elif verbose:
134
+ _LOGGER.warning(f"⚠️ Key '{key}' not found in PathManager, skipping.")
135
+ else:
136
+ path_items = self._paths.items()
137
+
138
+ # Get the package root to check against.
139
+ package_root = self._paths.get("ROOT")
140
+
141
+ for key, path in path_items:
142
+ if path.suffix: # It's a file, not a directory
143
+ continue
144
+
145
+ # --- THE CRITICAL CHECK ---
146
+ # Determine if the path is inside the main application package.
147
+ is_internal_path = package_root and path.is_relative_to(package_root)
148
+
149
+ if self._is_bundled and is_internal_path:
150
+ if verbose:
151
+ _LOGGER.warning(f"⚠️ Skipping internal directory '{key}' in bundled app (read-only).")
152
+ continue
153
+ # -------------------------
154
+
155
+ if verbose:
156
+ _LOGGER.info(f"📁 Ensuring directory exists for key '{key}': {path}")
157
+
158
+ path.mkdir(parents=True, exist_ok=True)
159
+
160
+ def status(self) -> None:
161
+ """
162
+ Checks the status of all registered paths on the filesystem and prints a formatted report.
163
+ """
164
+ report = {}
165
+ for key, path in self.items():
166
+ if path.is_dir():
167
+ report[key] = "📁 Directory"
168
+ elif path.is_file():
169
+ report[key] = "📄 File"
170
+ else:
171
+ report[key] = "❌ Not Found"
172
+
173
+ print("\n--- Path Status Report ---")
174
+ pprint(report)
175
+
176
+ def __repr__(self) -> str:
177
+ """Provides a string representation of the stored paths."""
178
+ path_list = "\n".join(f" '{k}': '{v}'" for k, v in self._paths.items())
179
+ return f"PathManager(\n{path_list}\n)"
180
+
181
+ # --- Dictionary-Style Methods ---
182
+ def __getitem__(self, key: str) -> Path:
183
+ """Allows dictionary-style getting, e.g., PM['my_key']"""
184
+ return self.get(key)
185
+
186
+ def __setitem__(self, key: str, value: Union[str, Path]):
187
+ """Allows dictionary-style setting, does not allow overwriting, e.g., PM['my_key'] = path"""
188
+ self.update({key: value}, overwrite=False)
189
+
190
+ def __contains__(self, key: str) -> bool:
191
+ """Allows checking for a key's existence, e.g., if 'my_key' in PM"""
192
+ return key in self._paths
193
+
194
+ def __len__(self) -> int:
195
+ """Allows getting the number of paths, e.g., len(PM)"""
196
+ return len(self._paths)
197
+
198
+ def keys(self):
199
+ """Returns all registered path keys."""
200
+ return self._paths.keys()
201
+
202
+ def values(self):
203
+ """Returns all registered Path objects."""
204
+ return self._paths.values()
205
+
206
+ def items(self):
207
+ """Returns all registered (key, Path) pairs."""
208
+ return self._paths.items()
209
+
210
+
211
+ def info():
212
+ _script_info(__all__)
ml_tools/utilities.py CHANGED
@@ -25,7 +25,7 @@ __all__ = [
25
25
  "serialize_object",
26
26
  "deserialize_object",
27
27
  "distribute_datasets_by_target",
28
- "train_dataset_orchestrator"
28
+ "train_dataset_orchestrator",
29
29
  ]
30
30
 
31
31
 
@@ -645,7 +645,7 @@ def train_dataset_orchestrator(list_of_dirs: list[Union[str,Path]],
645
645
 
646
646
  class LogKeys:
647
647
  """
648
- Used for ML scripts only
648
+ Used internally for ML scripts.
649
649
 
650
650
  Centralized keys for logging and history.
651
651
  """