dragon-ml-toolbox 3.12.6__py3-none-any.whl → 4.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of dragon-ml-toolbox might be problematic.

ml_tools/_logger.py ADDED
@@ -0,0 +1,36 @@
+ import logging
+ import sys
+
+
+ def _get_logger(name: str = "ml_tools", level: int = logging.INFO):
+     """
+     Initializes and returns a configured logger instance.
+
+     - `logger.info()`
+     - `logger.warning()`
+     - `logger.error()` the program can potentially recover.
+     - `logger.critical()` the program is going to crash.
+     """
+     logger = logging.getLogger(name)
+     logger.setLevel(level)
+
+     # Prevents adding handlers multiple times if the function is called again
+     if not logger.handlers:
+         handler = logging.StreamHandler(sys.stdout)
+
+         # Define the format string and the date format separately
+         log_format = '\n🐉%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+         date_format = '%Y-%m-%d %H:%M' # Format: Year-Month-Day Hour:Minute
+
+         # Pass both the format and the date format to the Formatter
+         formatter = logging.Formatter(log_format, datefmt=date_format)
+
+         handler.setFormatter(formatter)
+         logger.addHandler(handler)
+
+     logger.propagate = False
+
+     return logger
+
+ # Create a single logger instance to be imported by other modules
+ _LOGGER = _get_logger()
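
The new module above centralizes the logger that previously lived inside the logging utilities (removed further down in this diff). A minimal usage sketch (hypothetical caller code, not part of the release) showing how other ml_tools modules consume the shared instance:

from ml_tools._logger import _LOGGER

_LOGGER.info("Dataset loaded.")                 # routine progress message
_LOGGER.warning("Column 'age' contains NaNs.")  # recoverable issue
_LOGGER.error("Could not parse config file.")   # the program may still recover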
@@ -1,6 +1,6 @@
  import torch
  from torch import nn
- from .utilities import _script_info
+ from ._script_info import _script_info


  __all__ = [
@@ -0,0 +1,8 @@
+
+ def _script_info(all_data: list[str]):
+     """
+     List available names.
+     """
+     print("Available functions and objects:")
+     for i, name in enumerate(all_data, start=1):
+         print(f"{i} - {name}")
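
For reviewers unfamiliar with the helper, it only enumerates a module's public names. A hypothetical call (not part of the diff) behaves like this:

_script_info(["custom_logger", "model_report"])
# Output:
# Available functions and objects:
# 1 - custom_logger
# 2 - model_report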
@@ -2,12 +2,11 @@ from pathlib import Path
  from datetime import datetime
  from typing import Union, List, Dict, Any
  import pandas as pd
- from openpyxl.styles import Font, PatternFill
  import traceback
  import json
- from .utilities import sanitize_filename, _script_info, make_fullpath
- import logging
- import sys
+ from .path_manager import sanitize_filename, make_fullpath
+ from ._script_info import _script_info
+ from ._logger import _LOGGER


  __all__ = [
@@ -38,9 +37,6 @@ def custom_logger(
      - dict[str, scalar] → .json
        Dictionary is treated as structured data and serialized as JSON.

-     - pandas.DataFrame → .xlsx
-       Written to an Excel file with styled headers.
-
      - str → .log
        Plain text string is written to a .log file.

@@ -72,9 +68,6 @@ def custom_logger(
          else:
              _log_dict_to_json(data, base_path.with_suffix(".json"))

-     elif isinstance(data, pd.DataFrame):
-         _log_dataframe_to_xlsx(data, base_path.with_suffix(".xlsx"))
-
      elif isinstance(data, str):
          _log_string_to_log(data, base_path.with_suffix(".log"))

@@ -117,27 +110,6 @@ def _log_dict_to_csv(data: Dict[Any, List[Any]], path: Path) -> None:
      df.to_csv(path, index=False)


- def _log_dataframe_to_xlsx(data: pd.DataFrame, path: Path) -> None:
-     writer = pd.ExcelWriter(path, engine='openpyxl')
-     data.to_excel(writer, index=True, sheet_name='Data')
-
-     workbook = writer.book
-     worksheet = writer.sheets['Data']
-
-     header_font = Font(bold=True)
-     header_fill = PatternFill(
-         start_color="ADD8E6", # Light blue
-         end_color="ADD8E6",
-         fill_type="solid"
-     )
-
-     for cell in worksheet[1]:
-         cell.font = header_font
-         cell.fill = header_fill
-
-     writer.close()
-
-
  def _log_string_to_log(data: str, path: Path) -> None:
      with open(path, 'w', encoding='utf-8') as f:
          f.write(data.strip() + '\n')
@@ -155,38 +127,4 @@ def _log_dict_to_json(data: Dict[Any, Any], path: Path) -> None:


  def info():
-     _script_info(__all__)
-
-
- def _get_logger(name: str = "ml_tools", level: int = logging.INFO):
-     """
-     Initializes and returns a configured logger instance.
-
-     - `logger.info()`
-     - `logger.warning()`
-     - `logger.error()` the program can potentially recover.
-     - `logger.critical()` the program is going to crash.
-     """
-     logger = logging.getLogger(name)
-     logger.setLevel(level)
-
-     # Prevents adding handlers multiple times if the function is called again
-     if not logger.handlers:
-         handler = logging.StreamHandler(sys.stdout)
-
-         # Define the format string and the date format separately
-         log_format = '\n🐉%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-         date_format = '%Y-%m-%d %H:%M' # Format: Year-Month-Day Hour:Minute
-
-         # Pass both the format and the date format to the Formatter
-         formatter = logging.Formatter(log_format, datefmt=date_format)
-
-         handler.setFormatter(formatter)
-         logger.addHandler(handler)
-
-     logger.propagate = False
-
-     return logger
-
- # Create a single logger instance to be imported by other modules
- _LOGGER = _get_logger()
+     _script_info(__all__)
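
With the pandas.DataFrame branch and `_log_dataframe_to_xlsx` removed, `custom_logger` no longer writes Excel files and the openpyxl styling imports are gone. Callers that still need an .xlsx dump can export directly with pandas; a minimal sketch, assuming openpyxl is installed separately:

import pandas as pd

df = pd.DataFrame({"feature": [1, 2, 3], "target": [0.1, 0.2, 0.3]})
# Plain export; the bold, light-blue header styling of the removed helper is not reproduced.
df.to_excel("experiment_log.xlsx", index=True, sheet_name="Data", engine="openpyxl")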
@@ -3,12 +3,10 @@ from pandas.api.types import is_numeric_dtype
  import numpy as np
  import matplotlib.pyplot as plt
  import seaborn as sns
- from IPython import get_ipython
- from IPython.display import clear_output
- import time
  from typing import Union, Literal, Dict, Tuple, List, Optional
  from pathlib import Path
- from .utilities import sanitize_filename, _script_info, make_fullpath
+ from .path_manager import sanitize_filename, make_fullpath
+ from ._script_info import _script_info
  import re


@@ -22,7 +20,6 @@ __all__ = [
      "drop_columns_with_missing_data",
      "split_continuous_binary",
      "plot_correlation_heatmap",
-     "check_value_distributions",
      "plot_value_distributions",
      "clip_outliers_single",
      "clip_outliers_multi",
@@ -343,63 +340,6 @@ def plot_correlation_heatmap(df: pd.DataFrame,
      plt.close()


- def check_value_distributions(df: pd.DataFrame, view_frequencies: bool=True, bin_threshold: int=10, skip_cols_with_key: Union[str, None]=None):
-     """
-     Analyzes value counts for each column in a DataFrame, optionally plots distributions,
-     and saves them as .png files in the specified directory.
-
-     Args:
-         df (pd.DataFrame): The dataset to analyze.
-         view_frequencies (bool): Print relative frequencies instead of value counts.
-         bin_threshold (int): Threshold of unique values to start using bins.
-         skip_cols_with_key (str | None): Skip column names containing the key. If None, don't skip any column.
-
-     Notes:
-         - Binning is adaptive: if quantile binning results in ≤ 2 unique bins, raw values are used instead.
-     """
-     # cherry-pick columns
-     if skip_cols_with_key is not None:
-         columns = [col for col in df.columns if skip_cols_with_key not in col]
-     else:
-         columns = df.columns.to_list()
-
-     for col in columns:
-         if _is_notebook():
-             clear_output(wait=False)
-         if pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() > bin_threshold:
-             bins_number = 10
-             binned = pd.qcut(df[col], q=bins_number, duplicates='drop')
-             while binned.nunique() <= 2:
-                 bins_number -= 1
-                 binned = pd.qcut(df[col], q=bins_number, duplicates='drop')
-                 if bins_number <= 2:
-                     break
-
-             if binned.nunique() <= 2:
-                 view_std = df[col].value_counts(ascending=False)
-             else:
-                 view_std = binned.value_counts(sort=False)
-
-         else:
-             view_std = df[col].value_counts(ascending=False)
-
-         view_std.name = col
-
-         # unlikely scenario where the series is empty
-         if view_std.sum() == 0:
-             view_freq = view_std
-         else:
-             view_freq = view_std / view_std.sum()
-             # view_freq = df[col].value_counts(normalize=True, bins=10) # relative percentages
-         view_freq.name = col
-
-         # Print value counts
-         print(view_freq if view_frequencies else view_std)
-
-         time.sleep(1)
-         user_input_ = input("Press enter to continue")
-
-
  def plot_value_distributions(df: pd.DataFrame, save_dir: Union[str, Path], bin_threshold: int=10, skip_cols_with_key: Union[str, None]=None):
      """
      Plots and saves the value distributions for all (or selected) columns in a DataFrame,
@@ -691,9 +631,5 @@ def standardize_percentages(
      return df_copy


- def _is_notebook():
-     return get_ipython() is not None
-
-
  def info():
      _script_info(__all__)
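
With `check_value_distributions` and the IPython-dependent `_is_notebook` removed, `plot_value_distributions` is the remaining entry point for inspecting distributions; it writes plots to disk instead of prompting interactively. A sketch of the surviving call, based only on the signature shown above (the DataFrame `df` and the output path are placeholders, and the import path is omitted because this diff does not show the module's filename):

plot_value_distributions(
    df,                                # pandas DataFrame to analyze
    save_dir="reports/distributions",  # plots are saved under this directory
    bin_threshold=10,                  # numeric columns with more unique values get binned
    skip_cols_with_key="_id",          # skip columns whose name contains this key
)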
ml_tools/datasetmaster.py CHANGED
@@ -13,8 +13,9 @@ from torchvision.datasets import ImageFolder
  from torchvision import transforms
  import matplotlib.pyplot as plt
  from pathlib import Path
- from .utilities import _script_info, make_fullpath
- from .logger import _LOGGER
+ from .path_manager import make_fullpath
+ from ._logger import _LOGGER
+ from ._script_info import _script_info


  # --- public-facing API ---
@@ -0,0 +1,249 @@
+ from ._script_info import _script_info
+ from ._logger import _LOGGER
+ from .path_manager import make_fullpath, list_files_by_extension
+ from .keys import ModelSaveKeys
+
+ from typing import Union, Literal, Dict, Any, Optional, List
+ from pathlib import Path
+ import json
+
+ import joblib
+ import numpy as np
+ # Inference models
+ import xgboost
+ import lightgbm
+
+
+ __all__ = [
+     "InferenceHandler",
+     "model_report"
+ ]
+
+
+ class InferenceHandler:
+     """
+     Handles loading ensemble models and performing inference for either regression or classification tasks.
+     """
+     def __init__(self,
+                  models_dir: Union[str,Path],
+                  task: Literal["classification", "regression"],
+                  verbose: bool=True) -> None:
+         """
+         Initializes the handler by loading all models from a directory.
+
+         Args:
+             models_dir (Path): The directory containing the saved .joblib model files.
+             task ("regression" | "classification"): The type of task the models perform.
+         """
+         self.models: Dict[str, Any] = dict()
+         self.task: str = task
+         self.verbose = verbose
+         self._feature_names: Optional[List[str]] = None
+
+         model_files = list_files_by_extension(directory=models_dir, extension="joblib")
+
+         for fname, fpath in model_files.items():
+             try:
+                 full_object: dict
+                 full_object = _deserialize_object(filepath=fpath,
+                                                   verbose=self.verbose,
+                                                   raise_on_error=True) # type: ignore
+
+                 model: Any = full_object[ModelSaveKeys.MODEL]
+                 target_name: str = full_object[ModelSaveKeys.TARGET]
+                 feature_names_list: List[str] = full_object[ModelSaveKeys.FEATURES]
+
+                 # Check that feature names match
+                 if self._feature_names is None:
+                     # Store the feature names from the first model loaded.
+                     self._feature_names = feature_names_list
+                 elif self._feature_names != feature_names_list:
+                     # Add a warning if subsequent models have different feature names.
+                     _LOGGER.warning(f"⚠️ Mismatched feature names in {fname}. Using feature order from the first model loaded.")
+
+                 self.models[target_name] = model
+                 if self.verbose:
+                     _LOGGER.info(f"✅ Loaded model for target: {target_name}")
+
+             except Exception as e:
+                 _LOGGER.warning(f"⚠️ Failed to load or parse {fname}: {e}")
+
+     @property
+     def feature_names(self) -> List[str]:
+         """
+         Getter for the list of feature names the models expect.
+         Returns an empty list if no models were loaded.
+         """
+         return self._feature_names if self._feature_names is not None else []
+
+     def predict(self, features: np.ndarray) -> Dict[str, Any]:
+         """
+         Predicts on a single feature vector.
+
+         Args:
+             features (np.ndarray): A 1D or 2D NumPy array for a single sample.
+
+         Returns:
+             Dict[str, Any]: A dictionary where keys are target names.
+             - For regression: The value is the single predicted float.
+             - For classification: The value is another dictionary {'label': ..., 'probabilities': ...}.
+         """
+         if features.ndim == 1:
+             features = features.reshape(1, -1)
+
+         if features.shape[0] != 1:
+             raise ValueError("The predict() method is for a single sample. Use predict_batch() for multiple samples.")
+
+         results: Dict[str, Any] = dict()
+         for target_name, model in self.models.items():
+             if self.task == "regression":
+                 prediction = model.predict(features)
+                 results[target_name] = prediction.item()
+             else: # Classification
+                 label = model.predict(features)[0]
+                 probabilities = model.predict_proba(features)[0]
+                 results[target_name] = {ModelSaveKeys.CLASSIFICATION_LABEL: label,
+                                         ModelSaveKeys.CLASSIFICATION_PROBABILITIES: probabilities}
+
+         if self.verbose:
+             _LOGGER.info("✅ Inference process complete.")
+         return results
+
+     def predict_batch(self, features: np.ndarray) -> Dict[str, Any]:
+         """
+         Predicts on a batch of feature vectors.
+
+         Args:
+             features (np.ndarray): A 2D NumPy array where each row is a sample.
+
+         Returns:
+             Dict[str, Any]: A dictionary where keys are target names.
+             - For regression: The value is a NumPy array of predictions.
+             - For classification: The value is another dictionary {'labels': ..., 'probabilities': ...}.
+         """
+         if features.ndim != 2:
+             raise ValueError("Input for batch prediction must be a 2D array.")
+
+         results: Dict[str, Any] = dict()
+         for target_name, model in self.models.items():
+             if self.task == "regression":
+                 results[target_name] = model.predict(features)
+             else: # Classification
+                 labels = model.predict(features)
+                 probabilities = model.predict_proba(features)
+                 results[target_name] = {"labels": labels, "probabilities": probabilities}
+
+         if self.verbose:
+             _LOGGER.info("✅ Inference process complete.")
+
+         return results
+
+
+ def model_report(
+     model_path: Union[str,Path],
+     output_dir: Optional[Union[str,Path]] = None,
+     verbose: bool = True
+ ) -> Dict[str, Any]:
+     """
+     Deserializes a model and generates a summary report.
+
+     This function loads a serialized model object (joblib), prints a summary to the
+     console (if verbose), and saves a detailed JSON report.
+
+     Args:
+         model_path (str): The path to the serialized model file.
+         output_dir (str, optional): Directory to save the JSON report.
+             If None, it defaults to the same directory as the model file.
+         verbose (bool, optional): If True, prints summary information
+             to the console. Defaults to True.
+
+     Returns:
+         (Dict[str, Any]): A dictionary containing the model metadata.
+
+     Raises:
+         FileNotFoundError: If the model_path does not exist.
+         KeyError: If the deserialized object is missing required keys from `ModelSaveKeys`.
+     """
+     # 1. Convert to Path object
+     model_p = make_fullpath(model_path)
+
+     # --- 2. Deserialize and Extract Info ---
+     try:
+         full_object: dict = _deserialize_object(model_p) # type: ignore
+         model = full_object[ModelSaveKeys.MODEL]
+         target = full_object[ModelSaveKeys.TARGET]
+         features = full_object[ModelSaveKeys.FEATURES]
+     except FileNotFoundError:
+         _LOGGER.error(f"❌ Model file not found at '{model_p}'")
+         raise
+     except (KeyError, TypeError) as e:
+         _LOGGER.error(
+             f"❌ The serialized object is missing required keys '{ModelSaveKeys.MODEL}', '{ModelSaveKeys.TARGET}', '{ModelSaveKeys.FEATURES}'"
+         )
+         raise e
+
+     # --- 3. Print Summary to Console (if verbose) ---
+     if verbose:
+         print("\n--- 📝 Model Summary ---")
+         print(f"Source File: {model_p.name}")
+         print(f"Model Type: {type(model).__name__}")
+         print(f"Target: {target}")
+         print(f"Feature Count: {len(features)}")
+         print("-----------------------")
+
+     # --- 4. Generate JSON Report ---
+     report_data = {
+         "source_file": model_p.name,
+         "model_type": str(type(model)),
+         "target_name": target,
+         "feature_count": len(features),
+         "feature_names": features
+     }
+
+     # Determine output path
+     output_p = make_fullpath(output_dir, make=True) if output_dir else model_p.parent
+     json_filename = model_p.stem + "_info.json"
+     json_filepath = output_p / json_filename
+
+     try:
+         with open(json_filepath, 'w') as f:
+             json.dump(report_data, f, indent=4)
+         if verbose:
+             _LOGGER.info(f"✅ JSON report saved to: '{json_filepath}'")
+     except PermissionError:
+         _LOGGER.error(f"❌ Permission denied to write JSON report at '{json_filepath}'")
+
+     # --- 5. Return the extracted data ---
+     return report_data
+
+
+ # Local implementation to avoid calling utilities' dependencies
+ def _deserialize_object(filepath: Union[str,Path], verbose: bool=True, raise_on_error: bool=True) -> Optional[Any]:
+     """
+     Loads a serialized object from a .joblib file.
+
+     Parameters:
+         filepath (str | Path): Full path to the serialized .joblib file.
+
+     Returns:
+         (Any | None): The deserialized Python object, or None if loading fails.
+     """
+     true_filepath = make_fullpath(filepath)
+
+     try:
+         obj = joblib.load(true_filepath)
+     except (IOError, OSError, EOFError, TypeError, ValueError) as e:
+         message = f"❌ Failed to deserialize object from '{true_filepath}': {e}"
+         if raise_on_error:
+             raise Exception(message)
+         else:
+             print(message)
+         return None
+     else:
+         if verbose:
+             print(f"\n✅ Loaded object of type '{type(obj)}'")
+         return obj
+
+
+ def info():
+     _script_info(__all__)
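
Taken together, the new module loads every .joblib bundle found in a directory and returns per-target predictions. A hedged usage sketch; the module's filename is not shown in this diff, so the import path below is an assumption:

import numpy as np
from ml_tools.ensemble_inference import InferenceHandler, model_report  # assumed module path

handler = InferenceHandler(models_dir="saved_models", task="regression")
sample = np.zeros(len(handler.feature_names))   # one feature vector, placeholder values
predictions = handler.predict(sample)           # {target_name: predicted_value, ...}

# Summarize one serialized bundle; writes <stem>_info.json next to the model by default.
report = model_report("saved_models/target_A.joblib")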