dragon-ml-toolbox 12.0.1__py3-none-any.whl → 12.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dragon-ml-toolbox might be problematic. Click here for more details.

@@ -0,0 +1,413 @@
1
+ import pandas # logger
2
+ import torch
3
+ import numpy #handling torch to numpy
4
+ import evotorch
5
+ from evotorch.algorithms import SNES, CEM, GeneticAlgorithm
6
+ from evotorch.logging import PandasLogger
7
+ from evotorch.operators import SimulatedBinaryCrossOver, GaussianMutation
8
+ from typing import Literal, Union, Tuple, List, Optional, Any, Callable
9
+ from pathlib import Path
10
+ from tqdm.auto import trange
11
+ from contextlib import nullcontext
12
+ from functools import partial
13
+
14
+ from .path_manager import make_fullpath, sanitize_filename
15
+ from ._logger import _LOGGER
16
+ from ._script_info import _script_info
17
+ from .ML_inference import PyTorchInferenceHandler
18
+ from .keys import PyTorchInferenceKeys
19
+ from .SQL import DatabaseManager
20
+ from .optimization_tools import _save_result
21
+ from .utilities import save_dataframe
22
+ from .math_utilities import threshold_binary_values
23
+
24
+ """
25
+ DEPRECATED
26
+ """
27
+
28
+ __all__ = [
29
+ "s_MLOptimizer",
30
+ "s_create_pytorch_problem",
31
+ "s_run_optimization"
32
+ ]
33
+
34
+
35
class s_MLOptimizer:
    """
    Convenience wrapper that prepares an EvoTorch optimization once and runs it on demand.

    Construction delegates to `s_create_pytorch_problem`, and `run` delegates to
    `s_run_optimization`, reusing the stored problem, searcher factory and
    binary-feature count.

    SNES and CEM algorithms do not accept bounds, the given bounds will be used as an initial starting point.

    Example:
        >>> # 1. Build the optimizer from a model handler and the search settings
        >>> optimizer = s_MLOptimizer(
        ...     inference_handler=my_handler,
        ...     bounds=(lower_bounds, upper_bounds),
        ...     number_binary_features=2,
        ...     task="max",
        ...     algorithm="Genetic"
        ... )
        >>> # 2. Run it and persist the outcome
        >>> best_result = optimizer.run(
        ...     num_generations=100,
        ...     target_name="my_target",
        ...     feature_names=my_feature_names,
        ...     save_dir="/path/to/results",
        ...     save_format="csv"
        ... )
    """
    def __init__(self,
                 inference_handler: PyTorchInferenceHandler,
                 bounds: Tuple[List[float], List[float]],
                 number_binary_features: int,
                 task: Literal["min", "max"],
                 algorithm: Literal["SNES", "CEM", "Genetic"] = "Genetic",
                 population_size: int = 200,
                 **searcher_kwargs):
        """
        Builds the EvoTorch problem and the searcher factory for later runs.

        Args:
            inference_handler (PyTorchInferenceHandler): An initialized inference handler containing the model and weights.
            bounds (tuple[list[float], list[float]]): Lower and upper bounds for the solution features.
            number_binary_features (int): Number of binary features located at the END of the feature vector.
            task (str): The optimization goal, either "min" or "max".
            algorithm (str): The search algorithm to use ("SNES", "CEM", "Genetic").
            population_size (int): Population size for CEM and GeneticAlgorithm.
            **searcher_kwargs: Extra keyword arguments forwarded to the selected search algorithm's constructor.
        """
        problem_and_factory = s_create_pytorch_problem(
            inference_handler=inference_handler,
            bounds=bounds,
            binary_features=number_binary_features,
            task=task,
            algorithm=algorithm,
            population_size=population_size,
            **searcher_kwargs
        )
        self.problem = problem_and_factory[0]
        self.searcher_factory = problem_and_factory[1]

        # Kept so `run` can forward the binary-feature count without asking for it again.
        self._binary_features = number_binary_features

    def run(self,
            num_generations: int,
            target_name: str,
            save_dir: Union[str, Path],
            feature_names: Optional[List[str]],
            save_format: Literal['csv', 'sqlite', 'both'],
            repetitions: int = 1,
            verbose: bool = True) -> Optional[dict]:
        """
        Executes the evolutionary search with the configuration stored at construction time.

        Args:
            num_generations (int): The total number of generations for each repetition.
            target_name (str): Target name used for the CSV filename and/or SQL table.
            save_dir (str | Path): The directory where result files will be saved.
            feature_names (List[str] | None): Names of the solution features for labeling output. If None, generic names like 'feature_0', 'feature_1', ... , will be created.
            save_format (Literal['csv', 'sqlite', 'both']): The format for saving results.
            repetitions (int): The number of independent times to run the optimization.
            verbose (bool): If True, enables detailed logging.

        Returns:
            Optional[dict]: Whatever `s_run_optimization` returns: a dictionary with the best result for a single repetition, otherwise None.
        """
        outcome = s_run_optimization(
            # state captured in __init__
            problem=self.problem,
            searcher_factory=self.searcher_factory,
            binary_features=self._binary_features,
            # per-run settings supplied by the caller
            num_generations=num_generations,
            repetitions=repetitions,
            target_name=target_name,
            feature_names=feature_names,
            save_dir=save_dir,
            save_format=save_format,
            verbose=verbose
        )
        return outcome
131
+
132
+
133
def s_create_pytorch_problem(
    inference_handler: PyTorchInferenceHandler,
    bounds: Tuple[List[float], List[float]],
    binary_features: int,
    task: Literal["min", "max"],
    algorithm: Literal["SNES", "CEM", "Genetic"] = "Genetic",
    population_size: int = 200,
    **searcher_kwargs
) -> Tuple[evotorch.Problem, Callable[[], Any]]:
    """
    Creates an EvoTorch Problem and a searcher factory for a PyTorch model.

    SNES and CEM do not accept bounds, the given bounds will be used as an initial starting point.

    The Genetic Algorithm works directly with the bounds, and operators such as SimulatedBinaryCrossOver and GaussianMutation.

    Args:
        inference_handler (PyTorchInferenceHandler): An initialized inference handler containing the model and weights.
        bounds (tuple[list[float], list[float]]): A tuple containing the lower and upper bounds for the
            continuous solution features. Both sequences must have the same length.
        binary_features (int): Number of binary features located at the END of the feature vector.
            Their bounds are appended automatically.
        task (str): The optimization goal, either "min" or "max".
        algorithm (str): The search algorithm to use: "SNES", "CEM" or "Genetic".
        population_size (int): Default `popsize` for CEM and GeneticAlgorithm (ignored when `popsize`
            is given in `searcher_kwargs`; not applied to SNES).
        **searcher_kwargs: Additional keyword arguments to pass to the selected search algorithm's
            constructor (e.g., `stdev_init` for SNES/CEM). For "Genetic", any `operators` entry is
            replaced by the built-in crossover and mutation operators.

    Returns:
        Tuple: The configured Problem and a zero-argument factory that builds a fresh searcher
        bound to that problem on every call.

    Raises:
        ValueError: If `algorithm` is not supported or the lower/upper bounds differ in length.
    """
    # Fail fast, before any EvoTorch object is built.
    if algorithm not in ("SNES", "CEM", "Genetic"):
        message = f"Unknown algorithm '{algorithm}'."
        _LOGGER.error(message)
        raise ValueError(message)

    # Copy so the caller's lists are never modified.
    lower_bounds = list(bounds[0])
    upper_bounds = list(bounds[1])

    # Without this check a mismatch would be silently truncated by zip() below.
    if len(lower_bounds) != len(upper_bounds):
        message = (
            f"Lower and upper bounds must have the same length, "
            f"got {len(lower_bounds)} and {len(upper_bounds)}."
        )
        _LOGGER.error(message)
        raise ValueError(message)

    # Append one (low, high) pair per binary feature.
    # NOTE(review): 0.48/0.52 looks like a narrow band around a 0.5 threshold - confirm against
    # `threshold_binary_values`.
    if binary_features > 0:
        lower_bounds.extend([0.48] * binary_features)
        upper_bounds.extend([0.52] * binary_features)

    solution_length = len(lower_bounds)
    device = inference_handler.device

    def fitness_func(solution_tensor: torch.Tensor) -> torch.Tensor:
        # The continuous-valued tensor from the optimizer is fed to the model as-is.
        output = inference_handler.predict_batch(solution_tensor)
        return output[PyTorchInferenceKeys.PREDICTIONS].flatten()

    # Settings shared by every algorithm; only the way bounds are passed differs.
    problem_kwargs = dict(
        objective_sense=task,
        objective_func=fitness_func,
        solution_length=solution_length,
        device=device,
        vectorized=True,  # fitness_func receives whole batches
    )
    search_region = (lower_bounds, upper_bounds)

    if algorithm == "Genetic":
        problem = evotorch.Problem(bounds=search_region, **problem_kwargs)

        # Operators need the problem instance, so they are always built here.
        searcher_kwargs["operators"] = [
            SimulatedBinaryCrossOver(problem, tournament_size=3, eta=0.6),
            GaussianMutation(problem, stdev=0.4),
        ]
        searcher_kwargs.setdefault("popsize", population_size)
        SearcherClass = GeneticAlgorithm
    else:
        # SNES / CEM: bounds only define where the search starts.
        problem = evotorch.Problem(initial_bounds=search_region, **problem_kwargs)

        if "stdev_init" not in searcher_kwargs:
            # Default spread: 25% of each parameter's range.
            stdevs = [abs(up - low) * 0.25 for low, up in zip(lower_bounds, upper_bounds)]
            searcher_kwargs["stdev_init"] = torch.tensor(stdevs, dtype=torch.float32, requires_grad=False)

        if algorithm == "SNES":
            SearcherClass = SNES
        else:
            SearcherClass = CEM
            # CEM defaults, applied only when the caller did not provide them.
            searcher_kwargs.setdefault("popsize", population_size)
            searcher_kwargs.setdefault("parenthood_ratio", 0.2)  # float 0.0 - 1.0

    # Pre-fill every argument so callers can create fresh searchers with a bare call.
    searcher_factory = partial(SearcherClass, problem, **searcher_kwargs)

    return problem, searcher_factory
241
+
242
+
243
def s_run_optimization(
    problem: evotorch.Problem,
    searcher_factory: Callable[[], Any],
    num_generations: int,
    target_name: str,
    binary_features: int,
    save_dir: Union[str, Path],
    save_format: Literal['csv', 'sqlite', 'both'],
    feature_names: Optional[List[str]],
    repetitions: int = 1,
    verbose: bool = True
) -> Optional[dict]:
    """
    Runs the evolutionary optimization process, with support for multiple repetitions.

    A fresh searcher is obtained from `searcher_factory` for every run and executed for
    `num_generations` generations.

    It has two modes of operation:
    1. **Single Run (repetitions <= 1):** Executes the optimization once, saves the
       result to a CSV file (regardless of `save_format`), and returns it as a dictionary.
    2. **Iterative Analysis (repetitions > 1):** Executes the optimization multiple
       times. The result of each run is saved incrementally in the requested
       format(s) (CSV and/or SQLite database). In this mode the function returns None.

    Args:
        problem (evotorch.Problem): The configured problem instance; only its
            `solution_length` is read here, to generate default feature names.
        searcher_factory (Callable): Zero-argument factory producing a fresh search algorithm.
        num_generations (int): Number of generations to run in each repetition.
        target_name (str): Target name; also used for the CSV filename, the SQL table
            and the key holding the fitness value in the result.
        binary_features (int): Number of binary features located at the END of the feature vector.
        save_dir (str | Path): The directory where the result file(s) will be saved.
        save_format (Literal['csv', 'sqlite', 'both']): The format for saving results
            during iterative analysis.
        feature_names (List[str], optional): Names of the solution features for labeling
            the output. If None, generic names like 'feature_0', 'feature_1', etc.,
            will be created.
        repetitions (int, optional): The number of independent runs.
        verbose (bool): Attach an EvoTorch Pandas logger and save its log. Only the
            first repetition is logged.

    Returns:
        Optional[dict]: A dictionary containing the best feature values and the fitness
        score if `repetitions` is 1 (or less). Returns `None` otherwise, as results are
        written to files instead.
    """
    # --- Paths ---
    save_path = make_fullpath(save_dir, make=True, enforce="directory")

    csv_filename = sanitize_filename(target_name)
    if not csv_filename.endswith(".csv"):
        csv_filename = csv_filename + ".csv"
    csv_path = save_path / csv_filename

    db_path = save_path / "Optimization.db"
    db_table_name = target_name

    # --- Feature names ---
    if feature_names is None:
        feature_names = [f"feature_{i}" for i in range(problem.solution_length)]  # type: ignore
    elif len(feature_names) != problem.solution_length:
        # zip() in the result builder stops at the shorter sequence, so a mismatch
        # silently drops names or values; make that visible.
        _LOGGER.warning(
            f"Received {len(feature_names)} feature names for a solution of length "
            f"{problem.solution_length}; the saved results will be truncated to the shorter one."
        )

    # --- SINGLE RUN LOGIC ---
    if repetitions <= 1:
        searcher = searcher_factory()
        _LOGGER.info(f"🤖 Starting optimization with {searcher.__class__.__name__} Algorithm for {num_generations} generations...")

        # The logger has to be attached before the run to record it.
        pandas_logger = PandasLogger(searcher) if verbose else None

        searcher.run(num_generations)

        result_dict = _s_best_result_as_dict(searcher, feature_names, target_name, binary_features)

        _save_result(result_dict, 'csv', csv_path)  # Single run defaults to CSV

        if pandas_logger is not None:
            _handle_pandas_log(pandas_logger, save_path=save_path, target_name=target_name)

        _LOGGER.info(f"Optimization complete. Best solution saved to '{csv_path.name}'")
        return result_dict

    # --- MULTIPLE REPETITIONS LOGIC ---
    _LOGGER.info(f"🏁 Starting optimal solution space analysis with {repetitions} repetitions...")

    # nullcontext() yields None, which doubles as the "no database requested" flag below.
    db_context = DatabaseManager(db_path) if save_format in ['sqlite', 'both'] else nullcontext()

    with db_context as db_manager:
        if db_manager:
            schema = {name: "REAL" for name in feature_names}
            schema[target_name] = "REAL"
            db_manager.create_table(db_table_name, schema)

        print("")
        pandas_logger = None
        for i in trange(repetitions, desc="Repetitions"):
            # CRITICAL: every repetition must start from a fresh searcher.
            searcher = searcher_factory()

            # Only the first repetition is logged.
            if verbose and i == 0:
                pandas_logger = PandasLogger(searcher)

            searcher.run(num_generations)

            result_dict = _s_best_result_as_dict(searcher, feature_names, target_name, binary_features)

            # Save each result incrementally
            _save_result(result_dict, save_format, csv_path, db_manager, db_table_name)

        if pandas_logger is not None:
            _handle_pandas_log(pandas_logger, save_path=save_path, target_name=target_name)

    _LOGGER.info(f"Optimal solution space complete. Results saved to '{save_path}'")
    return None


def _s_best_result_as_dict(searcher: Any,
                           feature_names: List[str],
                           target_name: str,
                           binary_features: int) -> dict:
    """Labels the 'pop_best' solution of a finished searcher and adds its fitness under `target_name`."""
    # NOTE(review): 'pop_best' is read from the searcher status after the run; EvoTorch also
    # exposes other status keys (e.g. 'center' for SNES/CEM) - confirm this is the intended one.
    pop_best = searcher.status["pop_best"]
    fitness = pop_best.evals

    solution = pop_best.values.cpu().numpy()

    # Snap the trailing binary features to discrete values.
    if binary_features > 0:
        solution = threshold_binary_values(input_array=solution, binary_values=binary_features)

    labelled = dict(zip(feature_names, solution))
    labelled[target_name] = fitness.item()
    return labelled
405
+
406
+
407
def _handle_pandas_log(logger: PandasLogger, save_path: Path, target_name: str):
    """Converts the attached EvoTorch logger to a DataFrame and saves it under `<save_path>/EvolutionLogs`, named after the target."""
    evolution_history = logger.to_dataframe()
    logs_directory = save_path / "EvolutionLogs"
    save_dataframe(
        df=evolution_history,
        save_dir=logs_directory,
        filename=target_name,
    )
410
+
411
+
412
def info():
    """Passes this module's public names (`__all__`) to `_script_info`."""
    public_names = __all__
    _script_info(public_names)
@@ -29,6 +29,7 @@ __all__ = [
29
29
  "plot_value_distributions",
30
30
  "clip_outliers_single",
31
31
  "clip_outliers_multi",
32
+ "drop_outlier_samples",
32
33
  "match_and_filter_columns_by_regex",
33
34
  "standardize_percentages",
34
35
  "create_transformer_categorical_map",
@@ -358,8 +359,8 @@ def encode_categorical_features(
358
359
  df (pd.DataFrame): The input DataFrame.
359
360
  columns_to_encode (List[str]): A list of column names to be encoded.
360
361
  encode_nulls (bool): If True, encodes Null values as a distinct category
361
- "Other" with a value of 0. Other categories start from 1.
362
- If False, Nulls are ignored.
362
+ "Other" with a value of 0. Other categories start from 1.
363
+ If False, Nulls are ignored and categories start from 0.
363
364
  split_resulting_dataset (bool): If True, returns two separate DataFrames:
364
365
  one with non-categorical columns and one with the encoded columns.
365
366
  If False, returns a single DataFrame with all columns.
@@ -758,7 +759,99 @@ def clip_outliers_multi(
758
759
  if skipped_columns:
759
760
  _LOGGER.warning("Skipped columns:")
760
761
  for col, msg in skipped_columns:
761
- print(f" - {col}: {msg}")
762
+ print(f" - {col}")
763
+
764
+ return new_df
765
+
766
+
767
def drop_outlier_samples(
    df: pd.DataFrame,
    bounds_dict: Dict[str, Tuple[Union[int, float], Union[int, float]]],
    drop_on_nulls: bool = False,
    verbose: bool = True
) -> pd.DataFrame:
    """
    Drops entire rows where values in specified numeric columns fall outside
    a given [min, max] range.

    This function processes a copy of the DataFrame, ensuring the original is
    not modified. It skips columns with invalid specifications.

    Args:
        df (pd.DataFrame): The input DataFrame.
        bounds_dict (dict): A dictionary where keys are column names and values
            are (min_val, max_val) tuples defining the valid (inclusive) range.
        drop_on_nulls (bool): If True, rows with NaN/None in a checked column
            will also be dropped. If False, NaN/None are ignored.
        verbose (bool): If True, prints the number of rows dropped for each column.

    Returns:
        pd.DataFrame: A new DataFrame with the outlier rows removed.

    Notes:
        - Invalid specifications (missing column, non-numeric type, bounds that are
          not a 2-tuple, min greater than max, bounds that cannot be compared with
          the column) are logged as errors and the column is skipped.
    """
    new_df = df.copy()
    skipped_columns: List[str] = []
    initial_rows = len(new_df)

    for col, bounds in bounds_dict.items():
        # --- Validation Checks ---
        problem = _outlier_bounds_problem(new_df, col, bounds)
        if problem is not None:
            _LOGGER.error(problem)
            skipped_columns.append(col)
            continue

        # --- Filtering Logic ---
        min_val, max_val = bounds
        rows_before_drop = len(new_df)

        try:
            # .between() is inclusive and evaluates to False for NaN
            keep_mask = new_df[col].between(min_val, max_val)
        except (ValueError, TypeError) as err:
            # Bounds of a type the column cannot be compared with.
            _LOGGER.error(f"Could not compare column '{col}' with bounds {bounds}: {err}")
            skipped_columns.append(col)
            continue

        if not drop_on_nulls:
            # Nulls fail the range test above; add them back so they survive.
            keep_mask = keep_mask | new_df[col].isnull()

        new_df = new_df[keep_mask]

        if verbose:
            dropped_count = rows_before_drop - len(new_df)
            if dropped_count > 0:
                print(
                    f" - Column '{col}': Dropped {dropped_count} rows with values outside range [{min_val}, {max_val}]."
                )

    total_dropped = initial_rows - len(new_df)
    _LOGGER.info(f"Finished processing. Total rows dropped: {total_dropped}.")

    if skipped_columns:
        _LOGGER.warning("Skipped the following columns due to errors:")
        for col in skipped_columns:
            # Only print the column name for cleaner output as the error was already logged
            print(f" - {col}")

    return new_df


def _outlier_bounds_problem(df: pd.DataFrame, col: str, bounds: object) -> Union[str, None]:
    """Returns a message describing why `col`/`bounds` cannot be used for filtering, or None if valid."""
    if col not in df.columns:
        return f"Column '{col}' not found in DataFrame."

    if not pd.api.types.is_numeric_dtype(df[col]):
        return f"Column '{col}' is not of a numeric data type."

    if not (isinstance(bounds, tuple) and len(bounds) == 2):
        return f"Bounds for '{col}' must be a tuple of (min, max)."

    min_val, max_val = bounds
    try:
        inverted = bool(min_val > max_val)
    except (TypeError, ValueError):
        return f"Bounds for '{col}' cannot be compared with each other: {bounds}."

    # An inverted range matches nothing and would silently drop every row.
    if inverted:
        return f"Bounds for '{col}' are inverted: min ({min_val}) is greater than max ({max_val})."

    return None
764
857
 
@@ -174,16 +174,18 @@ def threshold_binary_values_batch(
174
174
  def discretize_categorical_values(
175
175
  input_array: np.ndarray,
176
176
  categorical_info: dict[int, int],
177
- start_at_zero: bool = False
177
+ start_at_zero: bool = True
178
178
  ) -> np.ndarray:
179
179
  """
180
180
  Rounds specified columns of a 2D NumPy array to the nearest integer and
181
181
  clamps the result to a valid categorical range.
182
+
183
+ If a 1D array is provided, it is treated as a single batch.
182
184
 
183
185
  Parameters
184
186
  ----------
185
187
  input_array : np.ndarray
186
- 2D array with shape (batch_size, n_features) containing continuous values.
188
+ 1D array (n_features,) or 2D array with shape (batch_size, n_features) containing continuous values.
187
189
  categorical_info : dict[int, int]
188
190
  A dictionary mapping column indices to their cardinality (number of categories).
189
191
  Example: {3: 4} means column 3 will be clamped to its 4 valid categories.
@@ -195,10 +197,22 @@ def discretize_categorical_values(
195
197
  -------
196
198
  np.ndarray
197
199
  A new array with the specified columns converted to integer categories.
200
+ Shape matches the input array's original shape.
198
201
  """
199
202
  # --- Input Validation ---
200
- if input_array.ndim != 2:
201
- _LOGGER.error(f"Expected 2D array, got {input_array.ndim}D array.")
203
+ if not isinstance(input_array, np.ndarray):
204
+ _LOGGER.error(f"Expected np.ndarray, got {type(input_array)}.")
205
+ raise ValueError()
206
+
207
+ if input_array.ndim == 1:
208
+ # Reshape 1D array (n_features,) to 2D (1, n_features)
209
+ working_array = input_array.reshape(1, -1)
210
+ original_was_1d = True
211
+ elif input_array.ndim == 2:
212
+ working_array = input_array
213
+ original_was_1d = False
214
+ else:
215
+ _LOGGER.error(f"Expected 1D or 2D array, got {input_array.ndim}D array.")
202
216
  raise ValueError()
203
217
 
204
218
  if not isinstance(categorical_info, dict) or not categorical_info:
@@ -207,6 +221,9 @@ def discretize_categorical_values(
207
221
 
208
222
  _, total_features = input_array.shape
209
223
  for col_idx, cardinality in categorical_info.items():
224
+ if not isinstance(col_idx, int):
225
+ _LOGGER.error(f"Column index key {col_idx} is not an integer.")
226
+ raise TypeError()
210
227
  if not (0 <= col_idx < total_features):
211
228
  _LOGGER.error(f"Column index {col_idx} is out of bounds for an array with {total_features} features.")
212
229
  raise ValueError()
@@ -215,7 +232,7 @@ def discretize_categorical_values(
215
232
  raise ValueError()
216
233
 
217
234
  # --- Core Logic ---
218
- output_array = input_array.copy()
235
+ output_array = working_array.copy()
219
236
 
220
237
  for col_idx, cardinality in categorical_info.items():
221
238
  # 1. Round the column values using "round half up"
@@ -228,7 +245,14 @@ def discretize_categorical_values(
228
245
  # 3. Clamp the values and update the output array
229
246
  output_array[:, col_idx] = np.clip(rounded_col, min_bound, max_bound)
230
247
 
231
- return output_array.astype(np.int32)
248
+ final_output = output_array.astype(np.int32)
249
+
250
+ # --- Output Shape Handling ---
251
+ if original_was_1d:
252
+ # Squeeze the batch dimension to return a 1D array
253
+ return final_output.squeeze(axis=0)
254
+ else:
255
+ return final_output
232
256
 
233
257
 
234
258
  def info():