dragon-ml-toolbox 2.1.0__py3-none-any.whl → 2.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dragon-ml-toolbox
-Version: 2.1.0
+Version: 2.2.0
 Summary: A collection of tools for data science and machine learning projects
 Author-email: Karl Loza <luigiloza@gmail.com>
 License-Expression: MIT
@@ -1,7 +1,8 @@
-dragon_ml_toolbox-2.1.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
-dragon_ml_toolbox-2.1.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=6cfpIeQ6D4Mcs10nkogQrkVyq1T7i2qXjjNHFoUMOyE,1892
+dragon_ml_toolbox-2.2.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
+dragon_ml_toolbox-2.2.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=6cfpIeQ6D4Mcs10nkogQrkVyq1T7i2qXjjNHFoUMOyE,1892
+ml_tools/ETL_engineering.py,sha256=9Lg-anXhggtdzvRPgVVSiAUGu5sb-LAZDfLDFXJlHns,21328
 ml_tools/MICE_imputation.py,sha256=1fovHycZMdZ6OgVh_bk8-r3wGi4rqf6rS10LOEWYaQo,11177
-ml_tools/PSO_optimization.py,sha256=vty1dZDY7P2iGUuE_oojyGdgM1EkDj5kXCfCxRMdk28,20957
+ml_tools/PSO_optimization.py,sha256=T-wnB94DcRWuRd2M3loDVT4POtIP0MOhs-VilAf1L4E,20974
 ml_tools/VIF_factor.py,sha256=lpM3Z2X_iZfXUWbCbURoeI0Tb196lU0bAsRo7q6AzBM,10235
 ml_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 ml_tools/_particle_swarm_optimization.py,sha256=b_eNNkA89Y40hj76KauivT8KLScH1B9wF2IXptOqkOw,22220
@@ -12,9 +13,9 @@ ml_tools/handle_excel.py,sha256=Uasx-DX7RNVQSzGHVJhX7UQ9RgBbX5H1ud1Hw_y8Kp4,1294
 ml_tools/logger.py,sha256=_k7WJdpFJj3IsjOgvjLJgUFZyF8RK3Jlgp5tAu_dLQU,4767
 ml_tools/pytorch_models.py,sha256=bpWZsrSwCvHJQkR6UfoPpElsMv9AvmiNErNHC8NYB_I,10132
 ml_tools/trainer.py,sha256=WAZ4EdrZuTOAnGXRWV3XcLNce4s7EKGf2-qchLC08Ik,15702
-ml_tools/utilities.py,sha256=5vVXqIH-jiY4PHUAoDI1o26mZYPsmrWO6I97Fs3oC90,18661
+ml_tools/utilities.py,sha256=A7Wm1ArpqFG80WKmnkYdtSzIRLvg5x-9nPNidZIbpPA,20671
 ml_tools/vision_helpers.py,sha256=idQ-Ugp1IdsvwXiYyhYa9G3rTRTm37YRpkQDLEpANHM,7701
-dragon_ml_toolbox-2.1.0.dist-info/METADATA,sha256=LDXrXkR1nm6WiEVHudCy7wI0dwkMejT0NzPuYptGSmw,2974
-dragon_ml_toolbox-2.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-dragon_ml_toolbox-2.1.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
-dragon_ml_toolbox-2.1.0.dist-info/RECORD,,
+dragon_ml_toolbox-2.2.0.dist-info/METADATA,sha256=oTLE1Q6BzsIwicQM7XCumt89XAjHZcV6CxDTfyteP_w,2974
+dragon_ml_toolbox-2.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dragon_ml_toolbox-2.2.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+dragon_ml_toolbox-2.2.0.dist-info/RECORD,,
ml_tools/ETL_engineering.py ADDED
@@ -0,0 +1,543 @@
+import polars as pl
+import re
+from typing import Literal, Union, Optional, Any, Callable, List, Dict
+from .utilities import _script_info
+
+
+__all__ = [
+    "TransformationRecipe",
+    "DataProcessor",
+    "KeywordDummifier",
+    "NumberExtractor",
+    "MultiNumberExtractor",
+    "CategoryMapper",
+    "ValueBinner",
+    "DateFeatureExtractor"
+]
+
+# Magic word for rename-only transformation
+_RENAME = "rename"
+
+class TransformationRecipe:
+    """
+    A builder class for creating a data transformation recipe.
+
+    This class provides a structured way to define a series of transformation
+    steps, with validation performed at the time of addition. It is designed
+    to be passed to a `DataProcessor`.
+
+    Use the method `add()` to add steps.
+    """
+    def __init__(self):
+        self._steps: List[Dict[str, Any]] = []
+
+    def add(
+        self,
+        input_col_name: str,
+        output_col_names: Union[str, List[str]],
+        transform: Union[str, Callable],
+    ) -> "TransformationRecipe":
+        """
+        Adds a new transformation step to the recipe.
+
+        Args:
+            input_col_name: The name of the column from the source DataFrame.
+            output_col_names: The desired name(s) for the output column(s).
+                A string for a 1-to-1 mapping, or a list of strings
+                for a 1-to-many mapping.
+            transform: The transformation to apply:
+                - Use "rename" for simple column renaming.
+                - If callable, must accept a `pl.Series` as the only parameter and return either a `pl.Series` or a `pl.DataFrame`.
+
+        Returns:
+            The instance of the recipe itself to allow for method chaining.
+        """
+        # --- Validation ---
+        if not isinstance(input_col_name, str) or not input_col_name:
+            raise TypeError("'input_col' must be a non-empty string.")
+
+        if transform == _RENAME:
+            if not isinstance(output_col_names, str):
+                raise TypeError("For a RENAME operation, 'output_col' must be a string.")
+        elif not isinstance(transform, Callable):
+            raise TypeError(f"'transform' must be a callable function or the string '{_RENAME}'.")
+
+        if isinstance(output_col_names, list) and transform == _RENAME:
+            raise ValueError("A RENAME operation cannot have a list of output columns.")
+
+        # --- Add Step ---
+        step = {
+            "input_col": input_col_name,
+            "output_col": output_col_names,
+            "transform": transform,
+        }
+        self._steps.append(step)
+        return self  # Allow chaining: recipe.add(...).add(...)
+
+    def __iter__(self):
+        """Allows the class to be iterated over, like a list."""
+        return iter(self._steps)
+
+    def __len__(self):
+        """Allows the len() function to be used on an instance."""
+        return len(self._steps)
+
+
+class DataProcessor:
+    """
+    Transforms a Polars DataFrame based on a provided `TransformationRecipe` object.
+
+    Use the method `transform()`.
+    """
+    def __init__(self, recipe: TransformationRecipe):
+        """
+        Initializes the DataProcessor with a transformation recipe.
+
+        Args:
+            recipe: An instance of the `TransformationRecipe` class that has
+                been populated with transformation steps.
+        """
+        if not isinstance(recipe, TransformationRecipe):
+            raise TypeError("The recipe must be an instance of TransformationRecipe.")
+        if len(recipe) == 0:
+            raise ValueError("The recipe cannot be empty.")
+        self.recipe = recipe
+
+    def transform(self, df: pl.DataFrame) -> pl.DataFrame:
+        """
+        Applies the transformation recipe to the input DataFrame.
+        """
+        processed_columns = []
+        # Recipe object is iterable
+        for step in self.recipe:
+            input_col_name = step["input_col"]
+            output_col_spec = step["output_col"]
+            transform_action = step["transform"]
+
+            if input_col_name not in df.columns:
+                raise ValueError(f"Input column '{input_col_name}' not found in DataFrame.")
+
+            input_series = df.get_column(input_col_name)
+
+            if transform_action == _RENAME:
+                processed_columns.append(input_series.alias(output_col_spec))
+                continue
+
+            if isinstance(transform_action, Callable):
+                result = transform_action(input_series)
+
+                if isinstance(result, pl.Series):
+                    if not isinstance(output_col_spec, str):
+                        raise TypeError(f"Function for '{input_col_name}' returned a Series but 'output_col' is not a string.")
+                    processed_columns.append(result.alias(output_col_spec))
+
+                elif isinstance(result, pl.DataFrame):
+                    if not isinstance(output_col_spec, list):
+                        raise TypeError(f"Function for '{input_col_name}' returned a DataFrame but 'output_col' is not a list.")
+                    if len(result.columns) != len(output_col_spec):
+                        raise ValueError(
+                            f"Mismatch in '{input_col_name}': function produced {len(result.columns)} columns, "
+                            f"but recipe specifies {len(output_col_spec)} output names."
+                        )
+
+                    renamed_df = result.rename(dict(zip(result.columns, output_col_spec)))
+                    processed_columns.extend(renamed_df.get_columns())
+
+                else:
+                    raise TypeError(f"Function for '{input_col_name}' returned an unexpected type: {type(result)}.")
+
+            else:  # This case is now unlikely due to builder validation.
+                raise TypeError(f"Invalid 'transform' action for '{input_col_name}': {transform_action}")
+
+        if not processed_columns:
+            print("Warning: The transformation resulted in an empty DataFrame.")
+            return pl.DataFrame()
+
+        return pl.DataFrame(processed_columns)
+
+
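As a quick illustration of how the two classes above are meant to be combined, here is a minimal usage sketch; the DataFrame, column names, and the lambda are hypothetical examples, not part of the package:

    import polars as pl
    from ml_tools.ETL_engineering import TransformationRecipe, DataProcessor

    df = pl.DataFrame({"id": ["a1", "a2"], "price": ["10", "12.5"]})

    recipe = (
        TransformationRecipe()
        .add(input_col_name="id", output_col_names="sample_id", transform="rename")
        # Any callable that maps a pl.Series to a pl.Series works as a 1-to-1 transform.
        .add(input_col_name="price", output_col_names="price_float",
             transform=lambda s: s.cast(pl.Float64, strict=False))
    )

    processed = DataProcessor(recipe).transform(df)  # columns: sample_id, price_float
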
+class KeywordDummifier:
+    """
+    A configurable transformer that creates one-hot encoded columns based on
+    keyword matching in a Polars Series.
+
+    Instantiate this class with keyword configurations. The instance can be used as a 'transform' callable compatible with the `TransformationRecipe`.
+
+    Args:
+        group_names (List[str]):
+            A list of strings, where each string is the name of a category.
+            This defines the matching priority and the base column names of the
+            DataFrame returned by the transformation.
+        group_keywords (List[List[str]]):
+            A list of lists of strings. Each inner list corresponds to a
+            `group_name` at the same index and contains the keywords to search for.
+    """
+    def __init__(self, group_names: List[str], group_keywords: List[List[str]]):
+        if len(group_names) != len(group_keywords):
+            raise ValueError("Initialization failed: 'group_names' and 'group_keywords' must have the same length.")
+
+        self.group_names = group_names
+        self.group_keywords = group_keywords
+
+    def __call__(self, column: pl.Series) -> pl.DataFrame:
+        """
+        Executes the one-hot encoding logic.
+
+        Args:
+            column (pl.Series): The input Polars Series to transform.
+
+        Returns:
+            pl.DataFrame: A DataFrame with one-hot encoded columns.
+        """
+        column = column.cast(pl.Utf8)
+
+        categorize_expr = pl.when(pl.lit(False)).then(pl.lit(None))
+        for name, keywords in zip(self.group_names, self.group_keywords):
+            pattern = "|".join(re.escape(k) for k in keywords)
+            categorize_expr = categorize_expr.when(
+                column.str.contains(pattern)
+            ).then(pl.lit(name))
+
+        categorize_expr = categorize_expr.otherwise(None).alias("category")
+
+        temp_df = pl.DataFrame(categorize_expr)
+        df_with_dummies = temp_df.to_dummies(columns=["category"])
+
+        final_columns = []
+        for name in self.group_names:
+            dummy_col_name = f"category_{name}"
+            if dummy_col_name in df_with_dummies.columns:
+                # The alias here uses the group name as the temporary column name
+                final_columns.append(
+                    df_with_dummies.get_column(dummy_col_name).alias(name)
+                )
+            else:
+                final_columns.append(pl.lit(0, dtype=pl.UInt8).alias(name))
+
+        return pl.DataFrame(final_columns)
+
+
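A minimal sketch of using the dummifier on its own (hypothetical data); inside a recipe the instance would be passed as `transform=` together with a list of output column names, since it returns a DataFrame:

    dummifier = KeywordDummifier(
        group_names=["steel", "aluminium"],
        group_keywords=[["steel", "st37"], ["alu"]],
    )
    # Rows matching neither keyword group fall outside every one-hot column.
    one_hot = dummifier(pl.Series("material", ["cold-rolled steel", "alu sheet", "brass"]))
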
+class NumberExtractor:
+    """
+    A configurable transformer that extracts a single number from a Polars string series using a regular expression.
+
+    An instance can be used as a 'transform' callable within the
+    `DataProcessor` pipeline.
+
+    Args:
+        regex_pattern (str):
+            The regular expression used to find the number. This pattern
+            MUST contain exactly one capturing group `(...)`. Defaults to a standard pattern for integers and floats.
+        dtype (str):
+            The desired data type for the output column. Defaults to "float".
+        round_digits (int | None):
+            If the dtype is 'float', you can specify the number of decimal
+            places to round the result to. This parameter is ignored if
+            dtype is 'int'. Defaults to None (no rounding).
+    """
+    def __init__(
+        self,
+        regex_pattern: str = r"(\d+\.?\d*)",
+        dtype: Literal["float", "int"] = "float",
+        round_digits: Optional[int] = None,
+    ):
+        # --- Validation ---
+        if not isinstance(regex_pattern, str):
+            raise TypeError("regex_pattern must be a string.")
+
+        # Validate that the regex has exactly one capturing group
+        try:
+            if re.compile(regex_pattern).groups != 1:
+                raise ValueError("regex_pattern must contain exactly one capturing group '(...)'")
+        except re.error as e:
+            raise ValueError(f"Invalid regex pattern provided: {e}") from e
+
+        if dtype not in ["float", "int"]:
+            raise ValueError("dtype must be either 'float' or 'int'.")
+
+        if round_digits is not None:
+            if not isinstance(round_digits, int):
+                raise TypeError("round_digits must be an integer.")
+            if dtype == "int":
+                print("Warning: 'round_digits' is specified but dtype is 'int'. Rounding will be ignored.")
+
+        self.regex_pattern = regex_pattern
+        self.dtype = dtype
+        self.round_digits = round_digits
+        self.polars_dtype = pl.Float64 if dtype == "float" else pl.Int64
+
+    def __call__(self, column: pl.Series) -> pl.Series:
+        """
+        Executes the number extraction logic.
+
+        Args:
+            column (pl.Series): The input Polars Series to transform.
+
+        Returns:
+            pl.Series: A new Series containing the extracted numbers.
+        """
+        # Extract the first (and only) capturing group
+        extracted = column.str.extract(self.regex_pattern, 1)
+
+        # Cast to the desired numeric type. Non-matching strings become null.
+        casted = extracted.cast(self.polars_dtype, strict=False)
+
+        # Apply rounding only if it's a float and round_digits is set
+        if self.dtype == "float" and self.round_digits is not None:
+            return casted.round(self.round_digits)
+
+        return casted
+
+
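A short sketch of the extractor on a hypothetical Series; with the default pattern, rows that contain no number come back as null:

    extractor = NumberExtractor(dtype="float", round_digits=1)
    doses = pl.Series("dose", ["25.54 mg", "100 mg", "n/a"])
    extractor(doses)  # expected roughly: [25.5, 100.0, null]
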
+class MultiNumberExtractor:
+    """
+    Extracts multiple numbers from a single polars string column into several new columns.
+
+    This transformer is designed for one-to-many mappings, such as parsing
+    ratios (100:30) or coordinates (10, 25) into separate columns.
+
+    Args:
+        num_outputs (int):
+            Number of numeric columns to create.
+        regex_pattern (str):
+            The regex pattern to find all numbers. Must contain one
+            capturing group around the number part.
+            Defaults to a standard pattern for integers and floats.
+        dtype (str):
+            The desired data type for the output columns. Defaults to "float".
+        fill_value (int | float | None):
+            A value to fill in when a number is not found at a given position in an otherwise matching string.
+            - For example, if `num_outputs=2` and only one number is found in a string, the second output column is filled with this value. If None, it is filled with null.
+    """
+    def __init__(
+        self,
+        num_outputs: int,
+        regex_pattern: str = r"(\d+\.?\d*)",
+        dtype: Literal["float", "int"] = "float",
+        fill_value: Optional[Union[int, float]] = None
+    ):
+        # --- Validation ---
+        if not isinstance(num_outputs, int) or num_outputs <= 0:
+            raise ValueError("num_outputs must be a positive integer.")
+
+        if not isinstance(regex_pattern, str):
+            raise TypeError("regex_pattern must be a string.")
+
+        # Validate that the regex has exactly one capturing group
+        try:
+            if re.compile(regex_pattern).groups != 1:
+                raise ValueError("regex_pattern must contain exactly one capturing group '(...)'")
+        except re.error as e:
+            raise ValueError(f"Invalid regex pattern provided: {e}") from e
+
+        # Validate dtype
+        if dtype not in ["float", "int"]:
+            raise ValueError("dtype must be either 'float' or 'int'.")
+
+        self.num_outputs = num_outputs
+        self.regex_pattern = regex_pattern
+        self.fill_value = fill_value
+        self.polars_dtype = pl.Float64 if dtype == "float" else pl.Int64
+
+    def __call__(self, column: pl.Series) -> pl.DataFrame:
+        """
+        Executes the multi-number extraction logic. Preserves nulls from the input column.
+        """
+        output_expressions = []
+        for i in range(self.num_outputs):
+            # Define the core extraction logic for the i-th number
+            extraction_expr = (
+                column.str.extract_all(self.regex_pattern)
+                .list.get(i)
+                .cast(self.polars_dtype, strict=False)
+            )
+
+            # Apply the fill value if provided
+            if self.fill_value is not None:
+                extraction_expr = extraction_expr.fill_null(self.fill_value)
+
+            # Only apply the logic when the input is not null.
+            # Otherwise, the result should also be null.
+            final_expr = (
+                pl.when(column.is_not_null())
+                .then(extraction_expr)
+                .otherwise(None)
+                .alias(f"col_{i}")  # Name the final output expression
+            )
+
+            output_expressions.append(final_expr)
+
+        return pl.select(output_expressions)
+
+
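A sketch with hypothetical ratio strings; in a recipe, `output_col_names` must be a list with `num_outputs` entries:

    ratio_extractor = MultiNumberExtractor(num_outputs=2, dtype="int", fill_value=0)
    ratios = pl.Series("ratio", ["100:30", "50", None])
    ratio_extractor(ratios)  # expected roughly: col_0 = [100, 50, null], col_1 = [30, 0, null]
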
+class CategoryMapper:
+    """
+    A transformer that maps string categories to specified numerical values using a dictionary.
+
+    Ideal for ordinal encoding.
+
+    Args:
+        mapping (Dict[str, int | float]):
+            A dictionary that defines the mapping from a string category (key)
+            to a numerical value (value).
+        unseen_value (int | float | None):
+            The numerical value to use for categories that are present in the
+            data but not in the mapping dictionary. If not provided or set
+            to None, unseen categories will be mapped to a null value.
+    """
+    def __init__(
+        self,
+        mapping: Dict[str, Union[int, float]],
+        unseen_value: Optional[Union[int, float]] = None,
+    ):
+        if not isinstance(mapping, dict):
+            raise TypeError("The 'mapping' argument must be a dictionary.")
+
+        self.mapping = mapping
+        self.default_value = unseen_value
+
+    def __call__(self, column: pl.Series) -> pl.Series:
+        """
+        Applies the dictionary mapping to the input column.
+
+        Args:
+            column (pl.Series): The input Polars Series of categories.
+
+        Returns:
+            pl.Series: A new Series with categories mapped to numbers.
+        """
+        # Ensure the column is treated as a string for matching keys
+        return column.cast(pl.Utf8).map_dict(self.mapping, default=self.default_value)
+
+
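A sketch with a hypothetical ordinal mapping. This relies on `Series.map_dict`, so it assumes a polars version where that method is still available (later polars releases renamed it), and the expected output is indicative only:

    mapper = CategoryMapper(mapping={"low": 0, "medium": 1, "high": 2}, unseen_value=-1)
    mapper(pl.Series("grade", ["low", "high", "unknown"]))  # expected roughly: [0, 2, -1]
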
+class ValueBinner:
+    """
+    A transformer that discretizes a continuous numerical column into a finite number of bins.
+
+    Each bin is assigned an integer label (0, 1, 2, ...).
+
+    Args:
+        breaks (List[int | float]):
+            A list of numbers defining the boundaries of the bins. The list
+            must be sorted in ascending order and contain at least two values.
+            For example, `breaks=[0, 18, 40, 65]` creates three bins.
+        left_closed (bool):
+            Determines which side of the interval is inclusive.
+            - If `False` (default): Intervals are (lower, upper].
+            - If `True`: Intervals are [lower, upper).
+    """
+    def __init__(
+        self,
+        breaks: List[Union[int, float]],
+        left_closed: bool = False,
+    ):
+        # --- Validation ---
+        if not isinstance(breaks, list) or len(breaks) < 2:
+            raise ValueError("The 'breaks' argument must be a list of at least two numbers.")
+
+        # Check if the list is sorted
+        if not all(breaks[i] <= breaks[i+1] for i in range(len(breaks)-1)):
+            raise ValueError("The 'breaks' list must be sorted in ascending order.")
+
+        self.breaks = breaks
+        self.left_closed = left_closed
+        # Generate numerical labels [0, 1, 2, ...] for the bins
+        self.labels = [str(i) for i in range(len(breaks) - 1)]
+
+    def __call__(self, column: pl.Series) -> pl.Series:
+        """
+        Applies the binning logic to the input column.
+
+        Args:
+            column (pl.Series): The input Polars Series of numerical data.
+
+        Returns:
+            pl.Series: A new Series of integer labels for the bins. Values
+                outside the specified breaks will become null.
+        """
+        # `cut` creates a new column of type Categorical
+        binned_column = column.cut(
+            breaks=self.breaks,
+            labels=self.labels,
+            left_closed=self.left_closed
+        )
+
+        # to_physical() converts the Categorical type to its underlying
+        # integer representation (u32), which is perfect for ML.
+        return binned_column.to_physical()
+
+
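A sketch with hypothetical ages; the exact edge handling and label count follow `pl.Series.cut` in whichever polars version the package pins:

    binner = ValueBinner(breaks=[0, 18, 40, 65], left_closed=False)
    binner(pl.Series("age", [10, 25, 70]))  # integer bin labels; values outside the breaks become null
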
+class DateFeatureExtractor:
+    """
+    A one-to-many transformer that extracts multiple numerical features from a date or datetime column.
+
+    It can handle columns that are already in a Polars Date/Datetime format,
+    or it can parse string columns if a format is provided.
+
+    Args:
+        features (List[str]):
+            A list of the date/time features to extract. Supported features are:
+            'year', 'month', 'day', 'hour', 'minute', 'second', 'millisecond',
+            'microsecond', 'nanosecond', 'ordinal_day' (day of year),
+            'weekday' (Mon=1, Sun=7), 'week' (week of year), and 'timestamp'.
+        format (str | None):
+            The format code used to parse string dates (e.g., "%Y-%m-%d %H:%M:%S").
+            Use if the input column is not a Date or Datetime type.
+    """
+
+    ALLOWED_FEATURES = {
+        'year', 'month', 'day', 'hour', 'minute', 'second', 'millisecond',
+        'microsecond', 'nanosecond', 'ordinal_day', 'weekday', 'week', 'timestamp'
+    }
+
+    def __init__(
+        self,
+        features: List[str],
+        format: Optional[str] = None,
+    ):
+        # --- Validation ---
+        if not isinstance(features, list) or not features:
+            raise ValueError("'features' must be a non-empty list of strings.")
+
+        for feature in features:
+            if feature not in self.ALLOWED_FEATURES:
+                raise ValueError(
+                    f"Feature '{feature}' is not supported. "
+                    f"Allowed features are: {self.ALLOWED_FEATURES}"
+                )
+
+        self.features = features
+        self.format = format
+
+    def __call__(self, column: pl.Series) -> pl.DataFrame:
+        """
+        Applies the feature extraction logic to the input column.
+
+        Args:
+            column (pl.Series): The input Polars Series of dates.
+
+        Returns:
+            pl.DataFrame: A DataFrame with columns for each extracted feature.
+        """
+        date_col = column
+        # First, parse strings into a datetime object if a format is given
+        if self.format is not None:
+            date_col = date_col.str.to_datetime(format=self.format, strict=False)
+
+        output_expressions = []
+        for i, feature in enumerate(self.features):
+            # Build the expression based on the feature name
+            if feature == 'timestamp':
+                expr = date_col.dt.timestamp(time_unit="ms")
+            else:
+                # getattr is a clean way to call methods like .dt.year(), .dt.month(), etc.
+                expr = getattr(date_col.dt, feature)()
+
+            # Alias with a generic name for the processor to handle
+            output_expressions.append(expr.alias(f"col_{i}"))
+
+        return pl.select(output_expressions)
+
+
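A sketch with hypothetical date strings; in a recipe, `output_col_names` would be a list with one name per requested feature:

    date_features = DateFeatureExtractor(features=["year", "month", "weekday"], format="%Y-%m-%d")
    date_features(pl.Series("date", ["2024-01-15", "2023-12-31"]))  # col_0=year, col_1=month, col_2=weekday
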
+def info():
+    _script_info(__all__)
ml_tools/PSO_optimization.py CHANGED
@@ -340,8 +340,8 @@ def _pso(func: ObjectiveFunction,
          lb: np.ndarray,
          ub: np.ndarray,
          device: torch.device,
-         swarmsize=100,
-         maxiter=100,
+         swarmsize: int,
+         maxiter: int,
          omega = 0.729, # Clerc and Kennedy’s constriction coefficient
          phip = 1.49445, # Clerc and Kennedy’s constriction coefficient
          phig = 1.49445, # Clerc and Kennedy’s constriction coefficient
@@ -391,7 +391,7 @@ def _pso(func: ObjectiveFunction,
         If True, returns the full history of particle positions and objective scores at each iteration.
 
     seed : int or None, default=None
-        Random seed for reproducibility. If None, defaults to 42.
+        Random seed for reproducibility. If None, the random state is not fixed.
 
     Returns
     -------
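Per the docstring change above, an unseeded run is no longer silently pinned to seed 42. A generic illustration of that contract (not the package's actual implementation, which this hunk does not show):

    import numpy as np

    def _make_rng(seed=None):
        # seed=None leaves the generator unseeded, so repeated runs differ;
        # an integer seed makes the swarm initialization reproducible.
        return np.random.default_rng(seed)
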
ml_tools/utilities.py CHANGED
@@ -144,23 +144,61 @@ def list_files_by_extension(directory: Union[str,Path], extension: str) -> dict[
     return name_path_dict
 
 
-def load_dataframe(df_path: Union[str,Path]) -> tuple[pd.DataFrame, str]:
+def load_dataframe(
+    df_path: Union[str, Path],
+    kind: Literal["pandas", "polars"] = "pandas",
+    all_strings: bool = False
+) -> Tuple[Union[pd.DataFrame, pl.DataFrame], str]:
     """
-    Load a CSV file into a pandas DataFrame and extract the base name (without extension) from the file path.
+    Load a CSV file into a DataFrame and extract its base name.
+
+    Can load data as either a pandas or a polars DataFrame. Allows for loading all
+    columns as string types to prevent type inference errors.
 
     Args:
-        df_path (str | Path): The path to the CSV file.
+        df_path (Union[str, Path]):
+            The path to the CSV file.
+        kind (Literal["pandas", "polars"], optional):
+            The type of DataFrame to load. Defaults to "pandas".
+        all_strings (bool, optional):
+            If True, loads all columns as string data types. This is useful for
+            ETL tasks and to avoid type-inference errors. Defaults to False.
 
     Returns:
-        Tuple ([pd.DataFrame, str]):
-            A tuple containing the loaded pandas DataFrame and the base name of the file.
+        (Tuple[DataFrameType, str]):
+            A tuple containing the loaded DataFrame (either pandas or polars)
+            and the base name of the file (without extension).
+
+    Raises:
+        FileNotFoundError: If the file does not exist at the given path.
+        ValueError: If the DataFrame is empty or an invalid 'kind' is provided.
     """
     path = make_fullpath(df_path)
-    df = pd.read_csv(path, encoding='utf-8')
+
     df_name = path.stem
-    if df.empty:
-        raise ValueError(f"DataFrame '{df_name}' is empty.")
-    print(f"\n💿 Loaded dataset: '{df_name}' with shape: {df.shape}")
+
+    if kind == "pandas":
+        if all_strings:
+            df = pd.read_csv(path, encoding='utf-8', dtype=str)
+        else:
+            df = pd.read_csv(path, encoding='utf-8')
+
+    elif kind == "polars":
+        if all_strings:
+            df = pl.read_csv(path, infer_schema=False)
+        else:
+            # Default behavior: infer the schema.
+            df = pl.read_csv(path, infer_schema_length=1000)
+
+    else:
+        raise ValueError(f"Invalid kind '{kind}'. Must be one of 'pandas' or 'polars'.")
+
+    # This check works for both pandas and polars DataFrames
+    if df.shape[0] == 0:
+        raise ValueError(f"DataFrame '{df_name}' loaded from '{path}' is empty.")
+
+    print(f"\n💿 Loaded {kind} dataset: '{df_name}' with shape: {df.shape}")
+
     return df, df_name
 
@@ -247,29 +285,42 @@ def merge_dataframes(
     return merged_df
 
 
-def save_dataframe(df: pd.DataFrame, save_dir: Union[str,Path], filename: str) -> None:
+def save_dataframe(df: Union[pd.DataFrame, pl.DataFrame], save_dir: Union[str,Path], filename: str) -> None:
     """
-    Save a pandas DataFrame to a CSV file.
+    Saves a pandas or polars DataFrame to a CSV file.
 
-    Parameters:
-        df (pd.DataFrame): Dataframe to save.
-        save_dir (str | Path): Directory where the CSV file will be saved.
-        filename (str): CSV filename, extension will be added if missing.
+    Args:
+        df (Union[pd.DataFrame, pl.DataFrame]):
+            The DataFrame to save.
+        save_dir (Union[str, Path]):
+            The directory where the CSV file will be saved.
+        filename (str):
+            The CSV filename. The '.csv' extension will be added if missing.
     """
-    if df.empty:
+    # This check works for both pandas and polars
+    if df.shape[0] == 0:
         print(f"⚠️ Attempting to save an empty DataFrame: '{filename}'. Process Skipped.")
         return
 
+    # Create the directory if it doesn't exist
     save_path = make_fullpath(save_dir, make=True)
 
+    # Clean the filename
     filename = sanitize_filename(filename)
-
     if not filename.endswith('.csv'):
         filename += '.csv'
 
     output_path = save_path / filename
 
-    df.to_csv(output_path, index=False, encoding='utf-8')
+    # --- Type-specific saving logic ---
+    if isinstance(df, pd.DataFrame):
+        df.to_csv(output_path, index=False, encoding='utf-8')
+    elif isinstance(df, pl.DataFrame):
+        df.write_csv(output_path) # Polars defaults to utf8 and no index
+    else:
+        # This error handles cases where an unsupported type is passed
+        raise TypeError(f"Unsupported DataFrame type: {type(df)}. Must be pandas or polars.")
+
     print(f"✅ Saved dataset: '{filename}' with shape: {df.shape}")
 
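A usage sketch (hypothetical directory and filename); the same call now accepts either DataFrame flavor:

    import polars as pl
    df = pl.DataFrame({"a": [1, 2]})
    save_dataframe(df, save_dir="outputs", filename="clean_data")  # writes outputs/clean_data.csv
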
 
@@ -446,7 +497,7 @@ def threshold_binary_values_batch(
     return np.hstack([cont_part, bin_part])
 
 
-def serialize_object(obj: Any, save_dir: Union[str,Path], filename: str, verbose: bool=True, raise_on_error: bool=False) -> Optional[str]:
+def serialize_object(obj: Any, save_dir: Union[str,Path], filename: str, verbose: bool=True, raise_on_error: bool=False) -> Optional[Path]:
     """
     Serializes a Python object using joblib; suitable for Python built-ins, numpy, and pandas.
 
@@ -456,7 +507,7 @@ def serialize_object(obj: Any, save_dir: Union[str,Path], filename: str, verbose
        filename (str) : Name for the output file, extension will be appended if needed.
 
     Returns:
-       (str | None) : The full file path where the object was saved if successful; otherwise, None.
+       (Path | None) : The full file path where the object was saved if successful; otherwise, None.
     """
     try:
        save_path = make_fullpath(save_dir, make=True)
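With the return type now `Optional[Path]`, callers get a pathlib object back; a small sketch with hypothetical arguments:

    saved = serialize_object({"scaler": "minmax"}, save_dir="artifacts", filename="preprocess_config")
    if saved is not None:
        print(saved.name)  # pathlib.Path API is available on the returned value
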
@@ -540,7 +591,7 @@ def distribute_datasets_by_target(
     feature_columns = [col for col in df.columns if col not in valid_targets]
 
     for target in valid_targets:
-        subset = df[feature_columns + [target]].dropna(subset=[target])
+        subset = df[feature_columns + [target]].dropna(subset=[target]) # type: ignore
         if verbose:
             print(f"Target: '{target}' - Dataframe shape: {subset.shape}")
         yield target, subset