dragon-ml-toolbox 13.3.0__py3-none-any.whl → 16.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/METADATA +20 -6
  2. dragon_ml_toolbox-16.2.0.dist-info/RECORD +51 -0
  3. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
  4. ml_tools/ETL_cleaning.py +20 -20
  5. ml_tools/ETL_engineering.py +23 -25
  6. ml_tools/GUI_tools.py +20 -20
  7. ml_tools/MICE_imputation.py +207 -5
  8. ml_tools/ML_callbacks.py +43 -26
  9. ml_tools/ML_configuration.py +788 -0
  10. ml_tools/ML_datasetmaster.py +303 -448
  11. ml_tools/ML_evaluation.py +351 -93
  12. ml_tools/ML_evaluation_multi.py +139 -42
  13. ml_tools/ML_inference.py +290 -209
  14. ml_tools/ML_models.py +33 -106
  15. ml_tools/ML_models_advanced.py +323 -0
  16. ml_tools/ML_optimization.py +12 -12
  17. ml_tools/ML_scaler.py +11 -11
  18. ml_tools/ML_sequence_datasetmaster.py +341 -0
  19. ml_tools/ML_sequence_evaluation.py +219 -0
  20. ml_tools/ML_sequence_inference.py +391 -0
  21. ml_tools/ML_sequence_models.py +139 -0
  22. ml_tools/ML_trainer.py +1604 -179
  23. ml_tools/ML_utilities.py +351 -4
  24. ml_tools/ML_vision_datasetmaster.py +1540 -0
  25. ml_tools/ML_vision_evaluation.py +284 -0
  26. ml_tools/ML_vision_inference.py +405 -0
  27. ml_tools/ML_vision_models.py +641 -0
  28. ml_tools/ML_vision_transformers.py +284 -0
  29. ml_tools/PSO_optimization.py +6 -6
  30. ml_tools/SQL.py +4 -4
  31. ml_tools/_keys.py +171 -0
  32. ml_tools/_schema.py +1 -1
  33. ml_tools/custom_logger.py +37 -14
  34. ml_tools/data_exploration.py +502 -93
  35. ml_tools/ensemble_evaluation.py +54 -11
  36. ml_tools/ensemble_inference.py +7 -33
  37. ml_tools/ensemble_learning.py +1 -1
  38. ml_tools/math_utilities.py +1 -1
  39. ml_tools/optimization_tools.py +2 -2
  40. ml_tools/path_manager.py +5 -5
  41. ml_tools/serde.py +2 -2
  42. ml_tools/utilities.py +192 -4
  43. dragon_ml_toolbox-13.3.0.dist-info/RECORD +0 -41
  44. ml_tools/RNN_forecast.py +0 -56
  45. ml_tools/keys.py +0 -87
  46. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/WHEEL +0 -0
  47. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE +0 -0
  48. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/top_level.txt +0 -0
ml_tools/ensemble_evaluation.py CHANGED
@@ -25,7 +25,7 @@ from typing import Union, Optional, Literal
  from .path_manager import sanitize_filename, make_fullpath
  from ._script_info import _script_info
  from ._logger import _LOGGER
- from .keys import SHAPKeys
+ from ._keys import SHAPKeys


  __all__ = [
@@ -112,7 +112,7 @@ def evaluate_model_classification(
  report_df = pd.DataFrame(report_dict).iloc[:-1, :].T
  plt.figure(figsize=figsize)
  sns.heatmap(report_df, annot=True, cmap=heatmap_cmap, fmt='.2f',
- annot_kws={"size": base_fontsize - 4})
+ annot_kws={"size": base_fontsize - 4}, vmin=0.0, vmax=1.0)
  plt.title(f"{model_name} - {target_name}", fontsize=base_fontsize)
  plt.xticks(fontsize=base_fontsize - 2)
  plt.yticks(fontsize=base_fontsize - 2)
@@ -133,6 +133,7 @@ def evaluate_model_classification(
  normalize="true",
  ax=ax
  )
+ disp.im_.set_clim(vmin=0.0, vmax=1.0)

  ax.set_title(f"{model_name} - {target_name}", fontsize=base_fontsize)
  ax.tick_params(axis='both', labelsize=base_fontsize)
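Both changes above pin the color scale to the full [0, 1] range so heatmaps and normalized confusion matrices from different runs are directly comparable. A minimal standalone sketch of the same idea, using toy data rather than the package's functions:

    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt
    from sklearn.metrics import ConfusionMatrixDisplay

    y_true = np.array([0, 1, 1, 0, 1, 0, 1, 1])
    y_pred = np.array([0, 1, 0, 0, 1, 1, 1, 1])

    # Classification-report style heatmap with a fixed color range
    scores = np.random.rand(3, 3)
    sns.heatmap(scores, annot=True, fmt='.2f', vmin=0.0, vmax=1.0)
    plt.show()

    # Row-normalized confusion matrix clamped to the same range
    disp = ConfusionMatrixDisplay.from_predictions(y_true, y_pred, normalize="true")
    disp.im_.set_clim(vmin=0.0, vmax=1.0)
    plt.show()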
@@ -327,7 +328,8 @@ def plot_calibration_curve(
  target_name: str,
  figure_size: tuple = (10, 10),
  base_fontsize: int = 24,
- n_bins: int = 15
+ n_bins: int = 15,
+ line_color: str = 'darkorange'
  ) -> plt.Figure: # type: ignore
  """
  Plots the calibration curve (reliability diagram) for a classifier.
@@ -348,22 +350,63 @@ def plot_calibration_curve(
  """
  fig, ax = plt.subplots(figsize=figure_size)

- disp = CalibrationDisplay.from_estimator(
- model,
- x_test,
- y_test,
- n_bins=n_bins,
- ax=ax
+ # --- Step 1: Get probabilities from the estimator ---
+ # We do this manually so we can pass them to from_predictions
+ try:
+ y_prob = model.predict_proba(x_test)
+ # Use probabilities for the positive class (assuming binary)
+ y_score = y_prob[:, 1]
+ except Exception as e:
+ _LOGGER.error(f"Could not get probabilities from model: {e}")
+ plt.close(fig)
+ return fig # Return empty figure
+
+ # --- Step 2: Get binned data *without* plotting ---
+ with plt.ioff():
+ fig_temp, ax_temp = plt.subplots()
+ cal_display_temp = CalibrationDisplay.from_predictions(
+ y_test,
+ y_score,
+ n_bins=n_bins,
+ ax=ax_temp,
+ name="temp"
+ )
+ line_x, line_y = cal_display_temp.line_.get_data() # type: ignore
+ plt.close(fig_temp)
+
+ # --- Step 3: Build the plot from scratch on ax ---
+
+ # 3a. Plot the ideal diagonal line
+ ax.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
+
+ # 3b. Use regplot for the regression line and its CI
+ sns.regplot(
+ x=line_x,
+ y=line_y,
+ ax=ax,
+ scatter=False, # No scatter dots
+ label=f"Calibration Curve ({n_bins} bins)",
+ line_kws={
+ 'color': line_color,
+ 'linestyle': '--',
+ 'linewidth': 2
+ }
  )

+ # --- Step 4: Apply original formatting ---
  ax.set_title(f"{model_name} - Reliability Curve for {target_name}", fontsize=base_fontsize)
  ax.tick_params(axis='both', labelsize=base_fontsize - 2)
  ax.set_xlabel("Mean Predicted Probability", fontsize=base_fontsize)
  ax.set_ylabel("Fraction of Positives", fontsize=base_fontsize)
- ax.legend(fontsize=base_fontsize - 4)
+
+ # Set limits
+ ax.set_ylim(0.0, 1.0)
+ ax.set_xlim(0.0, 1.0)
+
+ ax.legend(fontsize=base_fontsize - 4, loc='lower right')
  fig.tight_layout()

- # Save figure
+ # --- Step 5: Save figure (using original logic) ---
  save_path = make_fullpath(save_dir, make=True)
  sanitized_target_name = sanitize_filename(target_name)
  full_save_path = save_path / f"Calibration_Plot_{sanitized_target_name}.svg"
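The rewritten function computes the reliability curve itself (probabilities from predict_proba, binned points from CalibrationDisplay.from_predictions drawn on a throwaway axis) and then re-plots those points with seaborn's regplot to get a smoothed line with a confidence band. A self-contained sketch of that pattern using a toy classifier, not the package's own API:

    import matplotlib.pyplot as plt
    import seaborn as sns
    from sklearn.calibration import CalibrationDisplay
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split

    X, y = make_classification(n_samples=500, random_state=0)
    x_train, x_test, y_train, y_test = train_test_split(X, y, random_state=0)
    clf = LogisticRegression(max_iter=1000).fit(x_train, y_train)

    y_score = clf.predict_proba(x_test)[:, 1]  # positive-class probabilities

    # Draw onto a throwaway axis only to extract the binned (x, y) points
    with plt.ioff():
        fig_tmp, ax_tmp = plt.subplots()
        tmp = CalibrationDisplay.from_predictions(y_test, y_score, n_bins=15, ax=ax_tmp)
        line_x, line_y = tmp.line_.get_data()
        plt.close(fig_tmp)

    fig, ax = plt.subplots()
    ax.plot([0, 1], [0, 1], "k--", label="Perfectly calibrated")
    sns.regplot(x=line_x, y=line_y, ax=ax, scatter=False,
                label="Calibration curve (15 bins)",
                line_kws={"color": "darkorange", "linestyle": "--"})
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend(loc="lower right")
    plt.show()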
ml_tools/ensemble_inference.py CHANGED
@@ -1,7 +1,6 @@
  from typing import Union, Literal, Dict, Any, Optional, List
  from pathlib import Path
  import json
- import joblib
  import numpy as np
  # Inference models
  import xgboost
@@ -10,16 +9,17 @@ import lightgbm
  from ._script_info import _script_info
  from ._logger import _LOGGER
  from .path_manager import make_fullpath, list_files_by_extension
- from .keys import EnsembleKeys
+ from ._keys import EnsembleKeys
+ from .serde import deserialize_object


  __all__ = [
- "InferenceHandler",
+ "DragonEnsembleInferenceHandler",
  "model_report"
  ]


- class InferenceHandler:
+ class DragonEnsembleInferenceHandler:
  """
  Handles loading ensemble models and performing inference for either regression or classification tasks.
  """
@@ -44,9 +44,9 @@ class InferenceHandler:
  for fname, fpath in model_files.items():
  try:
  full_object: dict
- full_object = _deserialize_object(filepath=fpath,
+ full_object = deserialize_object(filepath=fpath,
  verbose=self.verbose,
- raise_on_error=True) # type: ignore
+ expected_type=dict)

  model: Any = full_object[EnsembleKeys.MODEL]
  target_name: str = full_object[EnsembleKeys.TARGET]
@@ -170,7 +170,7 @@ def model_report(

  # --- 2. Deserialize and Extract Info ---
  try:
- full_object: dict = _deserialize_object(model_p) # type: ignore
+ full_object: dict = deserialize_object(model_p, expected_type=dict, verbose=verbose) # type: ignore
  model = full_object[EnsembleKeys.MODEL]
  target = full_object[EnsembleKeys.TARGET]
  features = full_object[EnsembleKeys.FEATURES]
@@ -218,31 +218,5 @@ def model_report(
  return report_data


- # Local implementation to avoid calling utilities dependencies
- def _deserialize_object(filepath: Union[str,Path], verbose: bool=True, raise_on_error: bool=True) -> Optional[Any]:
- """
- Loads a serialized object from a .joblib file.
-
- Parameters:
- filepath (str | Path): Full path to the serialized .joblib file.
-
- Returns:
- (Any | None): The deserialized Python object, or None if loading fails.
- """
- true_filepath = make_fullpath(filepath)
-
- try:
- obj = joblib.load(true_filepath)
- except (IOError, OSError, EOFError, TypeError, ValueError) as e:
- _LOGGER.error(f"Failed to deserialize object from '{true_filepath}'.")
- if raise_on_error:
- raise e
- return None
- else:
- if verbose:
- _LOGGER.info(f"Loaded object of type '{type(obj)}'")
- return obj
-
-
  def info():
  _script_info(__all__)
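With the local _deserialize_object helper removed, loading now goes through serde.deserialize_object. A hedged usage sketch based only on the keyword arguments visible in this diff; the artifact path is hypothetical:

    from pathlib import Path
    from ml_tools.serde import deserialize_object
    from ml_tools._keys import EnsembleKeys

    artifact = Path("outputs/regression_model.joblib")  # hypothetical file

    # expected_type lets the loader validate what it unpickles
    full_object = deserialize_object(filepath=artifact, verbose=True, expected_type=dict)

    model = full_object[EnsembleKeys.MODEL]
    target = full_object[EnsembleKeys.TARGET]
    features = full_object[EnsembleKeys.FEATURES]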
ml_tools/ensemble_learning.py CHANGED
@@ -17,7 +17,7 @@ from .utilities import yield_dataframes_from_dir, train_dataset_yielder
  from .serde import serialize_object_filename
  from .path_manager import sanitize_filename, make_fullpath
  from ._script_info import _script_info
- from .keys import EnsembleKeys
+ from ._keys import EnsembleKeys
  from ._logger import _LOGGER
  from .ensemble_evaluation import (evaluate_model_classification,
  plot_roc_curve,
ml_tools/math_utilities.py CHANGED
@@ -219,7 +219,7 @@ def discretize_categorical_values(
  _LOGGER.error(f"'categorical_info' is not a dictionary, or is empty.")
  raise ValueError()

- _, total_features = input_array.shape
+ _, total_features = working_array.shape
  for col_idx, cardinality in categorical_info.items():
  if not isinstance(col_idx, int):
  _LOGGER.error(f"Column index key {col_idx} is not an integer.")
ml_tools/optimization_tools.py CHANGED
@@ -8,7 +8,7 @@ from .path_manager import make_fullpath, list_csv_paths, sanitize_filename
  from .utilities import yield_dataframes_from_dir
  from ._logger import _LOGGER
  from ._script_info import _script_info
- from .SQL import DatabaseManager
+ from .SQL import DragonSQL
  from ._schema import FeatureSchema

@@ -262,7 +262,7 @@ def _save_result(
  result_dict: dict,
  save_format: Literal['csv', 'sqlite', 'both'],
  csv_path: Path,
- db_manager: Optional[DatabaseManager] = None,
+ db_manager: Optional[DragonSQL] = None,
  db_table_name: Optional[str] = None,
  categorical_mappings: Optional[Dict[str, Dict[str, int]]] = None
  ):
ml_tools/path_manager.py CHANGED
@@ -9,7 +9,7 @@ from ._logger import _LOGGER


  __all__ = [
- "PathManager",
+ "DragonPathManager",
  "make_fullpath",
  "sanitize_filename",
  "list_csv_paths",
@@ -18,7 +18,7 @@ __all__ = [
  ]


- class PathManager:
+ class DragonPathManager:
  """
  Manages and stores a project's file paths, acting as a centralized
  "path database". It supports both development mode and applications
@@ -43,7 +43,7 @@ class PathManager:

  Args:
  anchor_file (str): The path to a file within your package, typically
- the `__file__` of the script where PathManager
+ the `__file__` of the script where DragonPathManager
  is instantiated. This is used to locate the
  package root directory.
  base_directories (List[str] | None): An optional list of strings,
@@ -149,7 +149,7 @@ class PathManager:
  if key in self._paths:
  path_items.append((key, self._paths[key]))
  elif verbose:
- _LOGGER.warning(f"Key '{key}' not found in PathManager, skipping.")
+ _LOGGER.warning(f"Key '{key}' not found in DragonPathManager, skipping.")
  else:
  path_items = self._paths.items()

@@ -194,7 +194,7 @@ class PathManager:
  def __repr__(self) -> str:
  """Provides a string representation of the stored paths."""
  path_list = "\n".join(f" '{k}': '{v}'" for k, v in self._paths.items())
- return f"PathManager(\n{path_list}\n)"
+ return f"DragonPathManager(\n{path_list}\n)"

  # --- Dictionary-Style Methods ---
  def __getitem__(self, key: str) -> Path:
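The rename is mechanical; usage is unchanged apart from the class name. A minimal sketch based only on the constructor arguments and dictionary-style access shown in this diff, assuming the base directory names passed at construction become retrievable keys (the "data" key is hypothetical):

    from ml_tools.path_manager import DragonPathManager

    PM = DragonPathManager(anchor_file=__file__, base_directories=["data", "models"])

    data_dir = PM["data"]   # dictionary-style access returns a pathlib.Path
    print(PM)               # __repr__ lists every stored key and path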
ml_tools/serde.py CHANGED
@@ -85,7 +85,7 @@ def serialize_object(obj: Any, file_path: Path, verbose: bool = True, raise_on_e
  return None
  else:
  if verbose:
- if isinstance(obj, _SIMPLE_TYPES):
+ if type(obj) in _SIMPLE_TYPES:
  _LOGGER.info(f"Object of type '{type(obj)}' saved to '{file_path}'")
  else:
  _LOGGER.info(f"Object '{obj}' saved to '{file_path}'")
@@ -140,7 +140,7 @@ def deserialize_object(

  if verbose:
  # log special objects
- if isinstance(obj, _SIMPLE_TYPES):
+ if type(obj) in _SIMPLE_TYPES:
  _LOGGER.info(f"Loaded object of type '{type(obj)}' from '{true_filepath}'.")
  else:
  _LOGGER.info(f"Loaded object '{obj}' from '{true_filepath}'.")
ml_tools/utilities.py CHANGED
@@ -7,16 +7,18 @@ from typing import Literal, Union, Optional, Any, Iterator, Tuple, overload
  from .path_manager import sanitize_filename, make_fullpath, list_csv_paths
  from ._script_info import _script_info
  from ._logger import _LOGGER
+ from ._schema import FeatureSchema


- # Keep track of available tools
  __all__ = [
  "load_dataframe",
  "load_dataframe_greedy",
+ "load_dataframe_with_schema",
  "yield_dataframes_from_dir",
  "merge_dataframes",
  "save_dataframe_filename",
  "save_dataframe",
+ "save_dataframe_with_schema",
  "distribute_dataset_by_target",
  "train_dataset_orchestrator",
  "train_dataset_yielder"
@@ -96,6 +98,7 @@ def load_dataframe(
  elif kind == "polars":
  pl_kwargs: dict[str,Any]
  pl_kwargs = {}
+ pl_kwargs['null_values'] = ["", " "]
  if use_columns:
  pl_kwargs['columns'] = use_columns
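The new default asks polars to treat empty and single-space cells as nulls at read time. An equivalent standalone call (the CSV path is hypothetical):

    import polars as pl

    df = pl.read_csv("data/example.csv", null_values=["", " "])
    print(df.null_count())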
 
@@ -173,6 +176,68 @@ def load_dataframe_greedy(directory: Union[str, Path],
  return df


+ def load_dataframe_with_schema(
+ df_path: Union[str, Path],
+ schema: "FeatureSchema",
+ all_strings: bool = False,
+ ) -> Tuple[pd.DataFrame, str]:
+ """
+ Loads a CSV file into a Pandas DataFrame, strictly validating its
+ feature columns against a FeatureSchema.
+
+ This function wraps `load_dataframe`. After loading, it validates
+ that the first N columns of the DataFrame (where N =
+ len(schema.feature_names)) contain *exactly* the set of features
+ specified in the schema.
+
+ - If the columns are present but out of order, they are reordered.
+ - If any required feature is missing from the first N columns, it fails.
+ - If any extra column is found within the first N columns, it fails.
+
+ Columns *after* the first N are considered target columns and are
+ logged for verification.
+
+ Args:
+ df_path (str, Path):
+ The path to the CSV file.
+ schema (FeatureSchema):
+ The schema object to validate against.
+ all_strings (bool):
+ If True, loads all columns as string data types.
+
+ Returns:
+ (Tuple[pd.DataFrame, str]):
+ A tuple containing the loaded, validated (and possibly
+ reordered) pandas DataFrame and the base name of the file.
+
+ Raises:
+ ValueError:
+ - If the DataFrame is missing columns required by the schema
+ within its first N columns.
+ - If the DataFrame's first N columns contain unexpected
+ columns that are not in the schema.
+ FileNotFoundError:
+ If the file does not exist at the given path.
+ """
+ # Step 1: Load the dataframe using the original function
+ try:
+ df, df_name = load_dataframe(
+ df_path=df_path,
+ use_columns=None, # Load all columns for validation
+ kind="pandas",
+ all_strings=all_strings,
+ verbose=True
+ )
+ except Exception as e:
+ _LOGGER.error(f"Failed during initial load for schema validation: {e}")
+ raise e
+
+ # Step 2: Call the helper to validate and reorder
+ df_validated = _validate_and_reorder_schema(df=df, schema=schema)
+
+ return df_validated, df_name
+
+
  def yield_dataframes_from_dir(datasets_dir: Union[str,Path], verbose: bool=True):
  """
  Iterates over all CSV files in a given directory, loading each into a Pandas DataFrame.
@@ -288,15 +353,25 @@ def save_dataframe_filename(df: Union[pd.DataFrame, pl.DataFrame], save_dir: Uni

  # --- Type-specific saving logic ---
  if isinstance(df, pd.DataFrame):
- df.to_csv(output_path, index=False, encoding='utf-8')
+ # Transform "" to np.nan before saving
+ df_to_save = df.replace(r'^\s*$', np.nan, regex=True)
+ # Save
+ df_to_save.to_csv(output_path, index=False, encoding='utf-8')
  elif isinstance(df, pl.DataFrame):
- df.write_csv(output_path) # Polars defaults to utf8 and no index
+ # Transform empty strings to Null
+ df_to_save = df.with_columns(
+ pl.when(pl.col(pl.Utf8).str.strip() == "") # type: ignore
+ .then(None)
+ .otherwise(pl.col(pl.Utf8))
+ )
+ # Save
+ df_to_save.write_csv(output_path)
  else:
  # This error handles cases where an unsupported type is passed
  _LOGGER.error(f"Unsupported DataFrame type: {type(df)}. Must be pandas or polars.")
  raise TypeError()

- _LOGGER.info(f"Saved dataset: '{filename}' with shape: {df.shape}")
+ _LOGGER.info(f"Saved dataset: '{filename}' with shape: {df_to_save.shape}")


  def save_dataframe(df: Union[pd.DataFrame, pl.DataFrame], full_path: Path):
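For the pandas branch, the regex ^\s*$ converts cells that are empty or whitespace-only into np.nan before writing, so they round-trip as proper missing values. A standalone illustration of that step (toy data):

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({"a": ["1", " ", ""], "b": ["x", "y", "   "]})
    cleaned = df.replace(r'^\s*$', np.nan, regex=True)
    print(cleaned.isna().sum())  # a: 2 missing, b: 1 missing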
@@ -319,6 +394,52 @@ def save_dataframe(df: Union[pd.DataFrame, pl.DataFrame], full_path: Path):
  filename=full_path.name)


+ def save_dataframe_with_schema(
+ df: pd.DataFrame,
+ full_path: Path,
+ schema: "FeatureSchema"
+ ) -> None:
+ """
+ Saves a pandas DataFrame to a CSV, strictly enforcing that the
+ first N columns match the FeatureSchema.
+
+ This function validates that the first N columns of the DataFrame
+ (where N = len(schema.feature_names)) contain *exactly* the set
+ of features specified in the schema.
+
+ - If the columns are present but out of order, they are reordered.
+ - If any required feature is missing from the first N columns, it fails.
+ - If any extra column is found within the first N columns, it fails.
+
+ Columns *after* the first N are considered target columns and are
+ logged for verification.
+
+ Args:
+ df (pd.DataFrame):
+ The DataFrame to save.
+ full_path (Path):
+ The complete file path where the DataFrame will be saved.
+ schema (FeatureSchema):
+ The schema object to validate against.
+
+ Raises:
+ ValueError:
+ - If the DataFrame is missing columns required by the schema
+ within its first N columns.
+ - If the DataFrame's first N columns contain unexpected
+ columns that are not in the schema.
+ """
+ if not isinstance(full_path, Path) or not full_path.suffix.endswith(".csv"):
+ _LOGGER.error('A path object pointing to a .csv file must be provided.')
+ raise ValueError()
+
+ # Call the helper to validate and reorder
+ df_to_save = _validate_and_reorder_schema(df=df, schema=schema)
+
+ # Call the original save function
+ save_dataframe(df=df_to_save, full_path=full_path)
+
+
  def distribute_dataset_by_target(
  df_or_path: Union[pd.DataFrame, str, Path],
  target_columns: list[str],
@@ -431,5 +552,72 @@ def train_dataset_yielder(
  yield (df_features, df_target, feature_names, target_col)


+ def _validate_and_reorder_schema(
+ df: pd.DataFrame,
+ schema: "FeatureSchema"
+ ) -> pd.DataFrame:
+ """
+ Internal helper to validate and reorder a DataFrame against a schema.
+
+ Checks for missing, extra, and out-of-order feature columns
+ (the first N columns). Returns a reordered DataFrame if necessary.
+ Logs all actions.
+
+ Raises:
+ ValueError: If validation fails.
+ """
+ # Get schema and DataFrame column info
+ expected_features = list(schema.feature_names)
+ expected_set = set(expected_features)
+ n_features = len(expected_features)
+
+ all_df_columns = df.columns.to_list()
+
+ # --- Strict Validation ---
+
+ # 0. Check if DataFrame is long enough
+ if len(all_df_columns) < n_features:
+ _LOGGER.error(f"DataFrame has only {len(all_df_columns)} columns, but schema requires {n_features} features.")
+ raise ValueError()
+
+ df_feature_cols = all_df_columns[:n_features]
+ df_feature_set = set(df_feature_cols)
+ df_target_cols = all_df_columns[n_features:]
+
+ # 1. Check for missing features
+ missing_from_df = expected_set - df_feature_set
+ if missing_from_df:
+ _LOGGER.error(f"DataFrame's first {n_features} columns are missing required schema features: {missing_from_df}")
+ raise ValueError()
+
+ # 2. Check for extra (unexpected) features
+ extra_in_df = df_feature_set - expected_set
+ if extra_in_df:
+ _LOGGER.error(f"DataFrame's first {n_features} columns contain unexpected columns: {extra_in_df}")
+ raise ValueError()
+
+ # --- Reordering ---
+
+ df_to_process = df
+
+ # If we pass validation, the sets are equal. Now check order.
+ if df_feature_cols == expected_features:
+ _LOGGER.info("DataFrame feature columns already match schema order.")
+ else:
+ _LOGGER.warning("DataFrame feature columns do not match schema order. Reordering...")
+
+ # Rebuild the DataFrame with the correct feature order + target columns
+ new_order = expected_features + df_target_cols
+ df_to_process = df[new_order]
+
+ # Log the presumed target columns for user verification
+ if not df_target_cols:
+ _LOGGER.warning(f"No target columns were found after index {n_features-1}.")
+ else:
+ _LOGGER.info(f"Presumed Target Columns: {df_target_cols}")
+
+ return df_to_process # type: ignore
+
+
  def info():
  _script_info(__all__)
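The helper's contract is that the first N columns must be exactly the schema's features (in any order), with everything after them treated as target columns. The reorder step itself is plain pandas indexing; a toy illustration of the idea outside the package:

    import pandas as pd

    expected_features = ["f1", "f2", "f3"]  # stand-in for schema.feature_names
    df = pd.DataFrame(columns=["f2", "f1", "f3", "target"], data=[[1, 2, 3, 0]])

    feature_cols = list(df.columns[:len(expected_features)])
    target_cols = list(df.columns[len(expected_features):])

    assert set(feature_cols) == set(expected_features)  # missing/extra -> error
    df = df[expected_features + target_cols]             # reorder features, keep targets
    print(list(df.columns))  # ['f1', 'f2', 'f3', 'target']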
dragon_ml_toolbox-13.3.0.dist-info/RECORD DELETED
@@ -1,41 +0,0 @@
- dragon_ml_toolbox-13.3.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
- dragon_ml_toolbox-13.3.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=iy2r_R7wjzsCbz_Q_jMsp_jfZ6oP8XW9QhwzRBH0mGY,1904
- ml_tools/ETL_cleaning.py,sha256=2VBRllV8F-ZiPylPp8Az2gwn5ztgazN0BH5OKnRUhV0,20402
- ml_tools/ETL_engineering.py,sha256=KfYqgsxupAx6e_TxwO1LZXeu5mFkIhVXJrNjP3CzIZc,54927
- ml_tools/GUI_tools.py,sha256=Va6ig-dHULPVRwQYYtH3fvY5XPIoqRcJpRW8oXC55Hw,45413
- ml_tools/MICE_imputation.py,sha256=X273Qlgoqqg7KTmoKd75YDyAPB0UIbTzGP3xsCmRh3E,11717
- ml_tools/ML_callbacks.py,sha256=elD2Yr030sv_6gX_m9GVd6HTyrbmt34nFS8lrgS4HtM,15808
- ml_tools/ML_datasetmaster.py,sha256=7QJnOM6GWFklKt2fiukITM3DK49i3ThK8wazb5szwpE,34396
- ml_tools/ML_evaluation.py,sha256=3u5dOhS77gn3kAshKr2GwSa5xZBF0YM77ZkFevqNPvA,18528
- ml_tools/ML_evaluation_multi.py,sha256=L6Ub_uObXsI7ToVCF6DtmAFekHRcga5wWMOnRYRR-BY,16121
- ml_tools/ML_inference.py,sha256=yq2gdN6s_OUYC5ZLQrIJC5BA5H33q8UKODXwb-_0M2c,23549
- ml_tools/ML_models.py,sha256=4Kb23pSusPMRH8h-R9ztK6JoH1lMuckxq7ihorll-H8,29965
- ml_tools/ML_optimization.py,sha256=P0zkhKAwTpkorIBtR0AOIDcyexo5ngmvFUzo3DfNO-E,22692
- ml_tools/ML_scaler.py,sha256=tw6onj9o8_kk3FQYb930HUzvv1zsFZe2YZJdF3LtHkU,7538
- ml_tools/ML_trainer.py,sha256=9BP6JFClqGfe7GL-FGG3n5e-no9ssjEOLol7P6baGrI,29019
- ml_tools/ML_utilities.py,sha256=EnKpPTnJ2qjZmz7kvows4Uu5CfSA7ByRmI1v2-KarKw,9337
- ml_tools/PSO_optimization.py,sha256=T-HWHMRJUnPvPwixdU5jif3_rnnI36TzcL8u3oSCwuA,22960
- ml_tools/RNN_forecast.py,sha256=Qa2KoZfdAvSjZ4yE78N4BFXtr3tTr0Gx7tQJZPotsh0,1967
- ml_tools/SQL.py,sha256=vXLPGfVVg8bfkbBE3HVfyEclVbdJy0TBhuQONtMwSCQ,11234
- ml_tools/VIF_factor.py,sha256=at5IVqPvicja2-DNSTSIIy3SkzDWCmLzo3qTG_qr5n8,10422
- ml_tools/__init__.py,sha256=kJiankjz9_qXu7gU92mYqYg_anLvt-B6RtW0mMH8uGo,76
- ml_tools/_logger.py,sha256=dlp5cGbzooK9YSNSZYB4yjZrOaQUGW8PTrM411AOvL8,4717
- ml_tools/_schema.py,sha256=yu6aWmn_2Z4_AxAtJGDDCIa96y6JcUp-vgnCS013Qmw,3908
- ml_tools/_script_info.py,sha256=21r83LV3RubsNZ_RTEUON6RbDf7Mh4_udweNcvdF_Fk,212
- ml_tools/constants.py,sha256=3br5Rk9cL2IUo638eJuMOGdbGQaWssaUecYEvSeRBLM,3322
- ml_tools/custom_logger.py,sha256=7tSAgRL7e-Ekm7rS1FLDocaPLCnaoKc7VSrtfwCtCEg,10067
- ml_tools/data_exploration.py,sha256=-BbWO7BBFapPi_7ZuWo65VqguJXaBfgFSptrXyoWrDk,51902
- ml_tools/ensemble_evaluation.py,sha256=FGHSe8LBI8_w8LjNeJWOcYQ1UK_mc6fVah8gmSvNVGg,26853
- ml_tools/ensemble_inference.py,sha256=0yLmLNj45RVVoSCLH1ZYJG9IoAhTkWUqEZmLOQTFGTY,9348
- ml_tools/ensemble_learning.py,sha256=vsIED7nlheYI4w2SBzP6SC1AnNeMfn-2A1Gqw5EfxsM,21964
- ml_tools/handle_excel.py,sha256=pfdAPb9ywegFkM9T54bRssDOsX-K7rSeV0RaMz7lEAo,14006
- ml_tools/keys.py,sha256=oykUVLB4Wos3AZomowjtI8AFFC5xnMUH-icNHydRpOk,2275
- ml_tools/math_utilities.py,sha256=PxoOrnuj6Ntp7_TJqyDWi0JX03WpAO5iaFNK2Oeq5I4,8800
- ml_tools/optimization_tools.py,sha256=TYFQ2nSnp7xxs-VyoZISWgnGJghFbsWasHjruegyJRs,12763
- ml_tools/path_manager.py,sha256=CyDU16pOKmC82jPubqJPT6EBt-u-3rGVbxyPIZCvDDY,18432
- ml_tools/serde.py,sha256=CmdJmQCPdrm2RQA1hWLsGxU_B3aClQoQ9B4vcQtIrEs,6951
- ml_tools/utilities.py,sha256=OcAyV1tEcYAfOWlGjRgopsjDLxU3DcI5EynzvWV4q3A,15754
- dragon_ml_toolbox-13.3.0.dist-info/METADATA,sha256=m2RVQa8YeN6e4hnsg6TwAMjymhTrburFXbmw-yB8JeQ,6166
- dragon_ml_toolbox-13.3.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- dragon_ml_toolbox-13.3.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
- dragon_ml_toolbox-13.3.0.dist-info/RECORD,,
ml_tools/RNN_forecast.py DELETED
@@ -1,56 +0,0 @@
- import torch
- from torch import nn
- import numpy as np
-
- from ._script_info import _script_info
-
-
- __all__ = [
- "rnn_forecast"
- ]
-
- def rnn_forecast(model: nn.Module, start_sequence: torch.Tensor, steps: int, device: str = 'cpu'):
- """
- Runs a sequential forecast for a trained RNN-based model.
-
- This function iteratively predicts future time steps, where each new prediction
- is generated by feeding the previous prediction back into the model.
-
- Args:
- model (nn.Module): The trained PyTorch RNN model (e.g., LSTM, GRU).
- start_sequence (torch.Tensor): The initial sequence to start the forecast from.
- Shape should be (sequence_length, num_features).
- steps (int): The number of future time steps to predict.
- device (str, optional): The device to run the forecast on ('cpu', 'cuda', 'mps').
- Defaults to 'cpu'.
-
- Returns:
- np.ndarray: A numpy array containing the forecasted values.
- """
- model.eval()
- model.to(device)
-
- predictions = []
- current_sequence = start_sequence.to(device)
-
- with torch.no_grad():
- for _ in range(steps):
- # Get the model's prediction for the current sequence
- output = model(current_sequence.unsqueeze(0)) # Add batch dimension
-
- # The prediction is the last element of the output sequence
- next_pred = output[0, -1, :].view(1, -1)
-
- # Store the prediction
- predictions.append(next_pred.cpu().numpy())
-
- # Update the sequence for the next iteration:
- # Drop the first element and append the new prediction
- current_sequence = torch.cat([current_sequence[1:], next_pred], dim=0)
-
- # Concatenate all predictions and flatten the array for easy use
- return np.concatenate(predictions).flatten()
-
-
- def info():
- _script_info(__all__)