dragon-ml-toolbox 10.2.0__py3-none-any.whl → 14.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic. Click here for more details.
- {dragon_ml_toolbox-10.2.0.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/METADATA +38 -63
- dragon_ml_toolbox-14.2.0.dist-info/RECORD +48 -0
- {dragon_ml_toolbox-10.2.0.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/licenses/LICENSE +1 -1
- {dragon_ml_toolbox-10.2.0.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +11 -0
- ml_tools/ETL_cleaning.py +72 -34
- ml_tools/ETL_engineering.py +506 -70
- ml_tools/GUI_tools.py +2 -1
- ml_tools/MICE_imputation.py +212 -7
- ml_tools/ML_callbacks.py +73 -40
- ml_tools/ML_datasetmaster.py +267 -284
- ml_tools/ML_evaluation.py +119 -58
- ml_tools/ML_evaluation_multi.py +107 -32
- ml_tools/ML_inference.py +15 -5
- ml_tools/ML_models.py +234 -170
- ml_tools/ML_models_advanced.py +323 -0
- ml_tools/ML_optimization.py +321 -97
- ml_tools/ML_scaler.py +10 -5
- ml_tools/ML_trainer.py +585 -40
- ml_tools/ML_utilities.py +528 -0
- ml_tools/ML_vision_datasetmaster.py +1315 -0
- ml_tools/ML_vision_evaluation.py +260 -0
- ml_tools/ML_vision_inference.py +428 -0
- ml_tools/ML_vision_models.py +627 -0
- ml_tools/ML_vision_transformers.py +58 -0
- ml_tools/PSO_optimization.py +10 -7
- ml_tools/RNN_forecast.py +2 -0
- ml_tools/SQL.py +22 -9
- ml_tools/VIF_factor.py +4 -3
- ml_tools/_ML_vision_recipe.py +88 -0
- ml_tools/__init__.py +1 -0
- ml_tools/_logger.py +0 -2
- ml_tools/_schema.py +96 -0
- ml_tools/constants.py +79 -0
- ml_tools/custom_logger.py +164 -16
- ml_tools/data_exploration.py +1092 -109
- ml_tools/ensemble_evaluation.py +48 -1
- ml_tools/ensemble_inference.py +6 -7
- ml_tools/ensemble_learning.py +4 -3
- ml_tools/handle_excel.py +1 -0
- ml_tools/keys.py +80 -0
- ml_tools/math_utilities.py +259 -0
- ml_tools/optimization_tools.py +198 -24
- ml_tools/path_manager.py +144 -45
- ml_tools/serde.py +192 -0
- ml_tools/utilities.py +287 -227
- dragon_ml_toolbox-10.2.0.dist-info/RECORD +0 -36
- {dragon_ml_toolbox-10.2.0.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-10.2.0.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/top_level.txt +0 -0
ml_tools/ETL_engineering.py
CHANGED
|
@@ -1,8 +1,13 @@
|
|
|
1
1
|
import polars as pl
|
|
2
2
|
import re
|
|
3
|
+
from pathlib import Path
|
|
3
4
|
from typing import Literal, Union, Optional, Any, Callable, List, Dict, Tuple
|
|
5
|
+
|
|
6
|
+
from .utilities import load_dataframe, save_dataframe_filename
|
|
7
|
+
from .path_manager import make_fullpath
|
|
4
8
|
from ._script_info import _script_info
|
|
5
9
|
from ._logger import _LOGGER
|
|
10
|
+
from .constants import CHEMICAL_ELEMENT_SYMBOLS
|
|
6
11
|
|
|
7
12
|
|
|
8
13
|
__all__ = [
|
|
@@ -14,11 +19,15 @@ __all__ = [
|
|
|
14
19
|
"KeywordDummifier",
|
|
15
20
|
"NumberExtractor",
|
|
16
21
|
"MultiNumberExtractor",
|
|
22
|
+
"TemperatureExtractor",
|
|
23
|
+
"MultiTemperatureExtractor",
|
|
17
24
|
"RatioCalculator",
|
|
25
|
+
"TriRatioCalculator",
|
|
18
26
|
"CategoryMapper",
|
|
19
27
|
"RegexMapper",
|
|
20
28
|
"ValueBinner",
|
|
21
|
-
"DateFeatureExtractor"
|
|
29
|
+
"DateFeatureExtractor",
|
|
30
|
+
"MolecularFormulaTransformer"
|
|
22
31
|
]
|
|
23
32
|
|
|
24
33
|
############ TRANSFORM MAIN ####################
|
|
@@ -42,17 +51,20 @@ class TransformationRecipe:
|
|
|
42
51
|
def add(
|
|
43
52
|
self,
|
|
44
53
|
input_col_name: str,
|
|
45
|
-
output_col_names: Union[str, List[str]],
|
|
46
54
|
transform: Union[str, Callable],
|
|
55
|
+
output_col_names: Optional[Union[str, List[str]]] = None
|
|
47
56
|
) -> "TransformationRecipe":
|
|
48
57
|
"""
|
|
49
58
|
Adds a new transformation step to the recipe.
|
|
50
59
|
|
|
51
60
|
Args:
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
A string for a 1-to-1 mapping
|
|
55
|
-
for a 1-to-many mapping.
|
|
61
|
+
input_col_name: The name of the column from the source DataFrame.
|
|
62
|
+
output_col_names: The desired name(s) for the output column(s).
|
|
63
|
+
- A string for a 1-to-1 mapping.
|
|
64
|
+
- A list of strings for a 1-to-many mapping.
|
|
65
|
+
- A string prefix for 1-to-many mapping.
|
|
66
|
+
- If None, the input name is used for 1-to-1 transforms,
|
|
67
|
+
or the transformer's default names are used for 1-to-many.
|
|
56
68
|
transform: The transformation to apply:
|
|
57
69
|
- Use "rename" for simple column renaming
|
|
58
70
|
- If callable, must accept a `pl.Series` as the only parameter and return either a `pl.Series` or `pl.DataFrame`.
|
|
@@ -72,10 +84,6 @@ class TransformationRecipe:
|
|
|
72
84
|
elif not isinstance(transform, Callable):
|
|
73
85
|
_LOGGER.error(f"'transform' must be a callable function or the string '{_RENAME}'.")
|
|
74
86
|
raise TypeError()
|
|
75
|
-
|
|
76
|
-
if isinstance(output_col_names, list) and transform == _RENAME:
|
|
77
|
-
_LOGGER.error("A RENAME operation cannot have a list of output columns.")
|
|
78
|
-
raise ValueError()
|
|
79
87
|
|
|
80
88
|
# --- Add Step ---
|
|
81
89
|
step = {
|
|
@@ -99,7 +107,7 @@ class DataProcessor:
|
|
|
99
107
|
"""
|
|
100
108
|
Transforms a Polars DataFrame based on a provided `TransformationRecipe` object.
|
|
101
109
|
|
|
102
|
-
Use the
|
|
110
|
+
Use the methods `transform()` or `load_transform_save()`.
|
|
103
111
|
"""
|
|
104
112
|
def __init__(self, recipe: TransformationRecipe):
|
|
105
113
|
"""
|
|
@@ -142,33 +150,53 @@ class DataProcessor:
|
|
|
142
150
|
result = transform_action(input_series)
|
|
143
151
|
|
|
144
152
|
if isinstance(result, pl.Series):
|
|
145
|
-
if
|
|
146
|
-
|
|
153
|
+
# Default to input name if spec is None
|
|
154
|
+
output_name = output_col_spec if output_col_spec is not None else input_col_name
|
|
155
|
+
|
|
156
|
+
if not isinstance(output_name, str):
|
|
157
|
+
_LOGGER.error(f"Function for '{input_col_name}' returned a Series but 'output_col' must be a string or None.")
|
|
147
158
|
raise TypeError()
|
|
148
|
-
processed_columns.append(result.alias(
|
|
159
|
+
processed_columns.append(result.alias(output_name))
|
|
149
160
|
|
|
150
161
|
elif isinstance(result, pl.DataFrame):
|
|
151
|
-
# 1. Handle
|
|
152
|
-
if
|
|
162
|
+
# 1. Handle None in output names
|
|
163
|
+
if output_col_spec is None:
|
|
164
|
+
# Use the column names generated by the transformer directly
|
|
165
|
+
processed_columns.extend(result.get_columns())
|
|
166
|
+
|
|
167
|
+
# 2. Handle list-based renaming
|
|
168
|
+
elif isinstance(output_col_spec, list):
|
|
153
169
|
if len(result.columns) != len(output_col_spec):
|
|
154
170
|
_LOGGER.error(f"Mismatch in '{input_col_name}': function produced {len(result.columns)} columns, but recipe specifies {len(output_col_spec)} output names.")
|
|
155
171
|
raise ValueError()
|
|
156
172
|
|
|
157
173
|
renamed_df = result.rename(dict(zip(result.columns, output_col_spec)))
|
|
158
174
|
processed_columns.extend(renamed_df.get_columns())
|
|
159
|
-
|
|
160
|
-
#
|
|
175
|
+
|
|
176
|
+
# 3. Global logic for adding a single prefix to all columns.
|
|
161
177
|
elif isinstance(output_col_spec, str):
|
|
162
178
|
prefix = output_col_spec
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
179
|
+
new_names = {}
|
|
180
|
+
|
|
181
|
+
for col in result.columns:
|
|
182
|
+
# Case 1: Transformer's output column name contains the input name.
|
|
183
|
+
# Action: Replace the input name with the desired prefix.
|
|
184
|
+
# Example: input='color', output='color_red', prefix='spec' -> 'spec_red'
|
|
185
|
+
if input_col_name in col:
|
|
186
|
+
new_names[col] = col.replace(input_col_name, prefix, 1)
|
|
187
|
+
|
|
188
|
+
# Case 2: Transformer's output is an independent name.
|
|
189
|
+
# Action: Prepend the prefix to the output name.
|
|
190
|
+
# Example: input='ratio', output='A_B', prefix='spec' -> 'spec_A_B'
|
|
191
|
+
else:
|
|
192
|
+
new_names[col] = f"{prefix}_{col}"
|
|
193
|
+
|
|
167
194
|
renamed_df = result.rename(new_names)
|
|
168
|
-
processed_columns.extend(renamed_df.get_columns())
|
|
195
|
+
processed_columns.extend(renamed_df.get_columns())
|
|
196
|
+
|
|
169
197
|
|
|
170
198
|
else:
|
|
171
|
-
_LOGGER.error(f"Function for '{input_col_name}' returned a DataFrame, so 'output_col' must be a list of names
|
|
199
|
+
_LOGGER.error(f"Function for '{input_col_name}' returned a DataFrame, so 'output_col' must be a list of names, a string prefix, or None.")
|
|
172
200
|
raise TypeError()
|
|
173
201
|
|
|
174
202
|
else:
|
|
@@ -182,9 +210,28 @@ class DataProcessor:
|
|
|
182
210
|
if not processed_columns:
|
|
183
211
|
_LOGGER.error("The transformation resulted in an empty DataFrame.")
|
|
184
212
|
return pl.DataFrame()
|
|
185
|
-
|
|
213
|
+
|
|
214
|
+
_LOGGER.info(f"Processed dataframe with {len(processed_columns)} columns.")
|
|
215
|
+
|
|
186
216
|
return pl.DataFrame(processed_columns)
|
|
187
217
|
|
|
218
|
+
def load_transform_save(self, input_path: Union[str,Path], output_path: Union[str,Path]):
|
|
219
|
+
"""
|
|
220
|
+
Convenience wrapper for the transform method that includes automatic dataframe loading and saving.
|
|
221
|
+
"""
|
|
222
|
+
# Validate paths
|
|
223
|
+
in_path = make_fullpath(input_path, enforce="file")
|
|
224
|
+
out_path = make_fullpath(output_path, make=True, enforce="file")
|
|
225
|
+
|
|
226
|
+
# load df
|
|
227
|
+
df, _ = load_dataframe(df_path=in_path, kind="polars", all_strings=True)
|
|
228
|
+
|
|
229
|
+
# Process
|
|
230
|
+
df_processed = self.transform(df)
|
|
231
|
+
|
|
232
|
+
# save processed df
|
|
233
|
+
save_dataframe_filename(df=df_processed, save_dir=out_path.parent, filename=out_path.name)
|
|
234
|
+
|
|
188
235
|
def __str__(self) -> str:
|
|
189
236
|
"""
|
|
190
237
|
Provides a detailed, human-readable string representation of the
|
|
@@ -253,7 +300,7 @@ class BinaryTransformer:
|
|
|
253
300
|
_LOGGER.error("Provide either 'true_keywords' or 'false_keywords', but not both.")
|
|
254
301
|
raise ValueError()
|
|
255
302
|
if true_keywords is None and false_keywords is None:
|
|
256
|
-
_LOGGER.error("
|
|
303
|
+
_LOGGER.error("Provide either 'true_keywords' or 'false_keywords'.")
|
|
257
304
|
raise ValueError()
|
|
258
305
|
|
|
259
306
|
# --- Configuration ---
|
|
@@ -285,16 +332,17 @@ class BinaryTransformer:
|
|
|
285
332
|
Returns:
|
|
286
333
|
pl.Series: A new Series of type UInt8 containing 1s and 0s.
|
|
287
334
|
"""
|
|
335
|
+
column_base_name = column.name
|
|
288
336
|
# Create a boolean Series: True if any keyword is found, else False
|
|
289
337
|
contains_keyword = column.str.contains(self.pattern)
|
|
290
338
|
|
|
291
339
|
# Apply logic and cast directly to integer type
|
|
292
340
|
if self.mode == "true_mode":
|
|
293
341
|
# True -> 1, False -> 0
|
|
294
|
-
return contains_keyword.cast(pl.UInt8)
|
|
342
|
+
return contains_keyword.cast(pl.UInt8).alias(column_base_name)
|
|
295
343
|
else: # false_mode
|
|
296
344
|
# We want the inverse: True -> 0, False -> 1
|
|
297
|
-
return (~contains_keyword).cast(pl.UInt8)
|
|
345
|
+
return (~contains_keyword).cast(pl.UInt8).alias(column_base_name)
|
|
298
346
|
|
|
299
347
|
|
|
300
348
|
class AutoDummifier:
|
|
@@ -302,6 +350,15 @@ class AutoDummifier:
|
|
|
302
350
|
A transformer that performs one-hot encoding on a categorical column,
|
|
303
351
|
automatically detecting the unique categories from the data.
|
|
304
352
|
"""
|
|
353
|
+
def __init__(self, drop_first: bool = False):
|
|
354
|
+
"""
|
|
355
|
+
Initializes the AutoDummifier.
|
|
356
|
+
|
|
357
|
+
Args:
|
|
358
|
+
drop_first (bool): If True, drops the first dummy column.
|
|
359
|
+
"""
|
|
360
|
+
self.drop_first = drop_first
|
|
361
|
+
|
|
305
362
|
def __call__(self, column: pl.Series) -> pl.DataFrame:
|
|
306
363
|
"""
|
|
307
364
|
Executes the one-hot encoding logic.
|
|
@@ -314,8 +371,20 @@ class AutoDummifier:
|
|
|
314
371
|
Column names are auto-generated by Polars as
|
|
315
372
|
'{original_col_name}_{category_value}'.
|
|
316
373
|
"""
|
|
317
|
-
#
|
|
318
|
-
|
|
374
|
+
# Store the original column name to construct the potential null column name
|
|
375
|
+
col_name = column.name
|
|
376
|
+
|
|
377
|
+
# Create the dummy variables from the series
|
|
378
|
+
dummies = column.cast(pl.Utf8).to_dummies(drop_first=self.drop_first)
|
|
379
|
+
|
|
380
|
+
# Define the name of the column that Polars creates for null values
|
|
381
|
+
null_col_name = f"{col_name}_null"
|
|
382
|
+
|
|
383
|
+
# Check if the null column exists and drop it if it does
|
|
384
|
+
if null_col_name in dummies.columns:
|
|
385
|
+
return dummies.drop(null_col_name)
|
|
386
|
+
|
|
387
|
+
return dummies
|
|
319
388
|
|
|
320
389
|
|
|
321
390
|
class MultiBinaryDummifier:
|
|
@@ -332,7 +401,7 @@ class MultiBinaryDummifier:
|
|
|
332
401
|
A list of strings, where each string is a keyword to search for. A separate
|
|
333
402
|
binary column will be created for each keyword.
|
|
334
403
|
case_insensitive (bool):
|
|
335
|
-
If True, keyword matching ignores case.
|
|
404
|
+
If True, keyword matching ignores case.
|
|
336
405
|
"""
|
|
337
406
|
def __init__(self, keywords: List[str], case_insensitive: bool = True):
|
|
338
407
|
if not isinstance(keywords, list) or not all(isinstance(k, str) for k in keywords):
|
|
@@ -355,11 +424,12 @@ class MultiBinaryDummifier:
|
|
|
355
424
|
Returns:
|
|
356
425
|
pl.DataFrame: A DataFrame where each column corresponds to a keyword.
|
|
357
426
|
"""
|
|
427
|
+
column_base_name = column.name
|
|
358
428
|
# Ensure the input is treated as a string, preserving nulls
|
|
359
429
|
str_column = column.cast(pl.Utf8)
|
|
360
430
|
|
|
361
431
|
output_expressions = []
|
|
362
|
-
for
|
|
432
|
+
for keyword in self.keywords:
|
|
363
433
|
# Escape keyword to treat it as a literal, not a regex pattern
|
|
364
434
|
base_pattern = re.escape(keyword)
|
|
365
435
|
|
|
@@ -373,7 +443,7 @@ class MultiBinaryDummifier:
|
|
|
373
443
|
.when(str_column.str.contains(pattern))
|
|
374
444
|
.then(pl.lit(1, dtype=pl.UInt8))
|
|
375
445
|
.otherwise(pl.lit(0, dtype=pl.UInt8))
|
|
376
|
-
.alias(f"
|
|
446
|
+
.alias(f"{column_base_name}_{keyword}") # name for DataProcessor
|
|
377
447
|
)
|
|
378
448
|
output_expressions.append(expr)
|
|
379
449
|
|
|
@@ -417,6 +487,7 @@ class KeywordDummifier:
|
|
|
417
487
|
Returns:
|
|
418
488
|
pl.DataFrame: A DataFrame with one-hot encoded columns.
|
|
419
489
|
"""
|
|
490
|
+
column_base_name = column.name
|
|
420
491
|
column = column.cast(pl.Utf8)
|
|
421
492
|
|
|
422
493
|
categorize_expr = pl.when(pl.lit(False)).then(pl.lit(None, dtype=pl.Utf8))
|
|
@@ -435,22 +506,24 @@ class KeywordDummifier:
|
|
|
435
506
|
column.str.contains(pattern)
|
|
436
507
|
).then(pl.lit(name))
|
|
437
508
|
|
|
438
|
-
|
|
509
|
+
dummy_name = 'dummy_category'
|
|
510
|
+
|
|
511
|
+
categorize_expr = categorize_expr.otherwise(None).alias(dummy_name)
|
|
439
512
|
|
|
440
513
|
temp_df = pl.select(categorize_expr)
|
|
441
|
-
df_with_dummies = temp_df.to_dummies(columns=[
|
|
514
|
+
df_with_dummies = temp_df.to_dummies(columns=[dummy_name])
|
|
442
515
|
|
|
443
516
|
final_columns = []
|
|
444
517
|
for name in self.group_names:
|
|
445
|
-
dummy_col_name = f"
|
|
518
|
+
dummy_col_name = f"{dummy_name}_{name}"
|
|
446
519
|
if dummy_col_name in df_with_dummies.columns:
|
|
447
|
-
# The alias here uses the group name as the
|
|
520
|
+
# The alias here uses the group name as the final column name
|
|
448
521
|
final_columns.append(
|
|
449
|
-
df_with_dummies.get_column(dummy_col_name).alias(name)
|
|
522
|
+
df_with_dummies.get_column(dummy_col_name).alias(f"{column_base_name}_{name}")
|
|
450
523
|
)
|
|
451
524
|
else:
|
|
452
525
|
# If a group had no matches, create a column of zeros
|
|
453
|
-
final_columns.append(pl.lit(0, dtype=pl.UInt8).alias(name))
|
|
526
|
+
final_columns.append(pl.lit(0, dtype=pl.UInt8).alias(f"{column_base_name}_{name}"))
|
|
454
527
|
|
|
455
528
|
return pl.select(final_columns)
|
|
456
529
|
|
|
@@ -471,7 +544,7 @@ class NumberExtractor:
|
|
|
471
544
|
round_digits (int | None):
|
|
472
545
|
If the dtype is 'float', you can specify the number of decimal
|
|
473
546
|
places to round the result to. This parameter is ignored if
|
|
474
|
-
dtype is 'int'.
|
|
547
|
+
dtype is 'int'.
|
|
475
548
|
"""
|
|
476
549
|
def __init__(
|
|
477
550
|
self,
|
|
@@ -519,6 +592,7 @@ class NumberExtractor:
|
|
|
519
592
|
Returns:
|
|
520
593
|
pl.Series: A new Series containing the extracted numbers.
|
|
521
594
|
"""
|
|
595
|
+
column_base_name = column.name
|
|
522
596
|
# Extract the first (and only) capturing group
|
|
523
597
|
extracted = column.str.extract(self.regex_pattern, 1)
|
|
524
598
|
|
|
@@ -529,7 +603,7 @@ class NumberExtractor:
|
|
|
529
603
|
if self.dtype == "float" and self.round_digits is not None:
|
|
530
604
|
return casted.round(self.round_digits)
|
|
531
605
|
|
|
532
|
-
return casted
|
|
606
|
+
return casted.alias(column_base_name)
|
|
533
607
|
|
|
534
608
|
|
|
535
609
|
class MultiNumberExtractor:
|
|
@@ -590,12 +664,13 @@ class MultiNumberExtractor:
|
|
|
590
664
|
"""
|
|
591
665
|
Executes the multi-number extraction logic. Preserves nulls from the input column.
|
|
592
666
|
"""
|
|
667
|
+
column_base_name = column.name
|
|
593
668
|
output_expressions = []
|
|
594
669
|
for i in range(self.num_outputs):
|
|
595
670
|
# Define the core extraction logic for the i-th number
|
|
596
671
|
extraction_expr = (
|
|
597
672
|
column.str.extract_all(self.regex_pattern)
|
|
598
|
-
.list.get(i)
|
|
673
|
+
.list.get(i, null_on_oob=True)
|
|
599
674
|
.cast(self.polars_dtype, strict=False)
|
|
600
675
|
)
|
|
601
676
|
|
|
@@ -609,24 +684,214 @@ class MultiNumberExtractor:
|
|
|
609
684
|
pl.when(column.is_not_null())
|
|
610
685
|
.then(extraction_expr)
|
|
611
686
|
.otherwise(None)
|
|
612
|
-
.alias(f"
|
|
687
|
+
.alias(f"{column_base_name}_{i}") # Name the final output expression
|
|
688
|
+
)
|
|
689
|
+
|
|
690
|
+
output_expressions.append(final_expr)
|
|
691
|
+
|
|
692
|
+
return pl.select(output_expressions)
|
|
693
|
+
|
|
694
|
+
|
|
695
|
+
class TemperatureExtractor:
|
|
696
|
+
"""
|
|
697
|
+
Extracts temperature values from a string column.
|
|
698
|
+
|
|
699
|
+
This transformer assumes that the source temperature values are in Celsius.
|
|
700
|
+
It can extract a single value using a specific regex or find all numbers in
|
|
701
|
+
a string and calculate their average. It also supports converting the final
|
|
702
|
+
Celsius value to Kelvin or Rankine.
|
|
703
|
+
|
|
704
|
+
Args:
|
|
705
|
+
regex_pattern (str):
|
|
706
|
+
The regex to find a single temperature. MUST contain exactly one
|
|
707
|
+
capturing group `(...)`. This is ignored if `average_mode` is True.
|
|
708
|
+
average_mode (bool):
|
|
709
|
+
If True, extracts all numbers from the string and returns their average.
|
|
710
|
+
This overrides the `regex_pattern` with a generic number-finding regex.
|
|
711
|
+
convert (str | None):
|
|
712
|
+
If "K", converts the final Celsius value to Kelvin.
|
|
713
|
+
If "R", converts the final Celsius value to Rankine.
|
|
714
|
+
If None (default), the value remains in Celsius.
|
|
715
|
+
"""
|
|
716
|
+
def __init__(
|
|
717
|
+
self,
|
|
718
|
+
regex_pattern: str = r"(\d+\.?\d*)",
|
|
719
|
+
average_mode: bool = False,
|
|
720
|
+
convert: Optional[Literal["K", "R"]] = None,
|
|
721
|
+
):
|
|
722
|
+
# --- Store configuration ---
|
|
723
|
+
self.average_mode = average_mode
|
|
724
|
+
self.convert = convert
|
|
725
|
+
self.regex_pattern = regex_pattern
|
|
726
|
+
|
|
727
|
+
# Generic pattern for average mode, defined once for efficiency.
|
|
728
|
+
self._avg_mode_pattern = r"(\d+\.?\d*)"
|
|
729
|
+
|
|
730
|
+
# --- Validation ---
|
|
731
|
+
if not self.average_mode:
|
|
732
|
+
try:
|
|
733
|
+
if re.compile(self.regex_pattern).groups != 1:
|
|
734
|
+
_LOGGER.error("'regex_pattern' must contain exactly one capturing group '(...)' for single extraction mode.")
|
|
735
|
+
raise ValueError()
|
|
736
|
+
except re.error as e:
|
|
737
|
+
_LOGGER.error(f"Invalid regex pattern provided: {e}")
|
|
738
|
+
raise ValueError()
|
|
739
|
+
|
|
740
|
+
if self.convert is not None and self.convert not in ["K", "R"]:
|
|
741
|
+
_LOGGER.error("'convert' must be either 'K' (Kelvin) or 'R' (Rankine).")
|
|
742
|
+
raise ValueError()
|
|
743
|
+
|
|
744
|
+
def __call__(self, column: pl.Series) -> pl.Series:
|
|
745
|
+
"""
|
|
746
|
+
Applies the temperature extraction and conversion logic.
|
|
747
|
+
|
|
748
|
+
Args:
|
|
749
|
+
column (pl.Series): The input Polars Series with string data.
|
|
750
|
+
|
|
751
|
+
Returns:
|
|
752
|
+
pl.Series: A new Series containing the final temperature values as floats.
|
|
753
|
+
"""
|
|
754
|
+
column_base_name = column.name
|
|
755
|
+
# --- Step 1: Extract number(s) to get a Celsius value expression ---
|
|
756
|
+
if self.average_mode:
|
|
757
|
+
# Extract all numbers and compute their mean. Polars' list.mean()
|
|
758
|
+
# handles the casting to float automatically.
|
|
759
|
+
celsius_expr = (
|
|
760
|
+
column.str.extract_all(self._avg_mode_pattern)
|
|
761
|
+
.list.eval(pl.element().cast(pl.Float64, strict=False))
|
|
762
|
+
.list.mean()
|
|
763
|
+
)
|
|
764
|
+
else:
|
|
765
|
+
# Extract a single number using the specified pattern.
|
|
766
|
+
# Cast to Float64, with non-matches becoming null.
|
|
767
|
+
celsius_expr = column.str.extract(self.regex_pattern, 1).cast(pl.Float64, strict=False)
|
|
768
|
+
|
|
769
|
+
# --- Step 2: Apply conversion if specified ---
|
|
770
|
+
if self.convert == "K":
|
|
771
|
+
# Celsius to Kelvin: C + 273.15
|
|
772
|
+
final_expr = celsius_expr + 273.15
|
|
773
|
+
elif self.convert == "R":
|
|
774
|
+
# Celsius to Rankine: (C * 9/5) + 491.67
|
|
775
|
+
final_expr = (celsius_expr * 1.8) + 491.67
|
|
776
|
+
else:
|
|
777
|
+
# No conversion needed
|
|
778
|
+
final_expr = celsius_expr
|
|
779
|
+
|
|
780
|
+
# --- Step 3: Round the result and return as a Series ---
|
|
781
|
+
# The select().to_series() pattern is a robust way to execute an
|
|
782
|
+
# expression and guarantee a Series is returned.
|
|
783
|
+
return pl.select(final_expr.round(2)).to_series().alias(column_base_name)
|
|
784
|
+
|
|
785
|
+
|
|
786
|
+
class MultiTemperatureExtractor:
|
|
787
|
+
"""
|
|
788
|
+
Extracts multiple temperature values from a single string column into
|
|
789
|
+
several new columns, assuming the source values are in Celsius.
|
|
790
|
+
|
|
791
|
+
This one-to-many transformer is designed for cases where multiple readings
|
|
792
|
+
are packed into one field, like "Min: 10C, Max: 25C".
|
|
793
|
+
|
|
794
|
+
Args:
|
|
795
|
+
num_outputs (int):
|
|
796
|
+
The number of numeric columns to create.
|
|
797
|
+
regex_pattern (str):
|
|
798
|
+
The regex to find all numbers. Must contain exactly one capturing
|
|
799
|
+
group around the number part (e.g., r"(-?\\d+\\.?\\d*)").
|
|
800
|
+
convert (str | None):
|
|
801
|
+
If "K", converts the final Celsius values to Kelvin.
|
|
802
|
+
If "R", converts the final Celsius values to Rankine.
|
|
803
|
+
If None (default), the values remain in Celsius.
|
|
804
|
+
fill_value (int | float | None):
|
|
805
|
+
A value to use if a temperature is not found at a given position.
|
|
806
|
+
For example, if `num_outputs=3` and only two temperatures are
|
|
807
|
+
found, the third column will be filled with this value. If None,
|
|
808
|
+
it will be filled with null.
|
|
809
|
+
"""
|
|
810
|
+
def __init__(
|
|
811
|
+
self,
|
|
812
|
+
num_outputs: int,
|
|
813
|
+
regex_pattern: str = r"(\d+\.?\d*)",
|
|
814
|
+
convert: Optional[Literal["K", "R"]] = None,
|
|
815
|
+
fill_value: Optional[Union[int, float]] = None
|
|
816
|
+
):
|
|
817
|
+
# --- Validation ---
|
|
818
|
+
if not isinstance(num_outputs, int) or num_outputs <= 0:
|
|
819
|
+
_LOGGER.error("'num_outputs' must be a positive integer.")
|
|
820
|
+
raise ValueError()
|
|
821
|
+
|
|
822
|
+
try:
|
|
823
|
+
if re.compile(regex_pattern).groups != 1:
|
|
824
|
+
_LOGGER.error("'regex_pattern' must contain exactly one capturing group '(...)'.")
|
|
825
|
+
raise ValueError()
|
|
826
|
+
except re.error as e:
|
|
827
|
+
_LOGGER.error(f"Invalid regex pattern provided: {e}")
|
|
828
|
+
raise ValueError()
|
|
829
|
+
|
|
830
|
+
if convert is not None and convert not in ["K", "R"]:
|
|
831
|
+
_LOGGER.error("'convert' must be either 'K' (Kelvin) or 'R' (Rankine).")
|
|
832
|
+
raise ValueError()
|
|
833
|
+
|
|
834
|
+
# --- Store configuration ---
|
|
835
|
+
self.num_outputs = num_outputs
|
|
836
|
+
self.regex_pattern = regex_pattern
|
|
837
|
+
self.convert = convert
|
|
838
|
+
self.fill_value = fill_value
|
|
839
|
+
|
|
840
|
+
def __call__(self, column: pl.Series) -> pl.DataFrame:
|
|
841
|
+
"""
|
|
842
|
+
Applies the multi-temperature extraction and conversion logic.
|
|
843
|
+
"""
|
|
844
|
+
column_base_name = column.name
|
|
845
|
+
output_expressions = []
|
|
846
|
+
for i in range(self.num_outputs):
|
|
847
|
+
# --- Step 1: Extract the i-th number as a Celsius value ---
|
|
848
|
+
celsius_expr = (
|
|
849
|
+
column.str.extract_all(self.regex_pattern)
|
|
850
|
+
.list.get(i, null_on_oob=True)
|
|
851
|
+
.cast(pl.Float64, strict=False)
|
|
852
|
+
)
|
|
853
|
+
|
|
854
|
+
# --- Step 2: Apply conversion if specified ---
|
|
855
|
+
if self.convert == "K":
|
|
856
|
+
# Celsius to Kelvin: C + 273.15
|
|
857
|
+
converted_expr = celsius_expr + 273.15
|
|
858
|
+
elif self.convert == "R":
|
|
859
|
+
# Celsius to Rankine: (C * 9/5) + 491.67
|
|
860
|
+
converted_expr = (celsius_expr * 1.8) + 491.67
|
|
861
|
+
else:
|
|
862
|
+
# No conversion needed
|
|
863
|
+
converted_expr = celsius_expr
|
|
864
|
+
|
|
865
|
+
# --- Step 3: Apply fill value and handle original nulls ---
|
|
866
|
+
final_expr = converted_expr.round(2)
|
|
867
|
+
if self.fill_value is not None:
|
|
868
|
+
final_expr = final_expr.fill_null(self.fill_value)
|
|
869
|
+
|
|
870
|
+
# Ensure that if the original row was null, all outputs are null
|
|
871
|
+
final_expr = (
|
|
872
|
+
pl.when(column.is_not_null())
|
|
873
|
+
.then(final_expr)
|
|
874
|
+
.otherwise(None)
|
|
875
|
+
.alias(f"{column_base_name}_{i}") # Temporary name for DataProcessor
|
|
613
876
|
)
|
|
614
877
|
|
|
615
878
|
output_expressions.append(final_expr)
|
|
616
879
|
|
|
880
|
+
# Execute all expressions at once for performance
|
|
617
881
|
return pl.select(output_expressions)
|
|
618
882
|
|
|
619
883
|
|
|
620
884
|
class RatioCalculator:
|
|
621
885
|
"""
|
|
622
886
|
A transformer that parses a string ratio (e.g., "40:5" or "30/2") and
|
|
623
|
-
computes the result of the division.
|
|
624
|
-
|
|
887
|
+
computes the result of the division. Includes robust handling for
|
|
888
|
+
zeros and single numbers.
|
|
625
889
|
"""
|
|
626
890
|
def __init__(
|
|
627
891
|
self,
|
|
628
|
-
|
|
629
|
-
|
|
892
|
+
regex_pattern: str = r"(\d+\.?\d*)\s*[::/]\s*(\d+\.?\d*)",
|
|
893
|
+
handle_zeros: bool = False,
|
|
894
|
+
handle_single_number: bool = False
|
|
630
895
|
):
|
|
631
896
|
# --- Robust Validation ---
|
|
632
897
|
try:
|
|
@@ -642,24 +907,119 @@ class RatioCalculator:
|
|
|
642
907
|
raise ValueError()
|
|
643
908
|
|
|
644
909
|
self.regex_pattern = regex_pattern
|
|
910
|
+
self.handle_zeros = handle_zeros
|
|
911
|
+
self.handle_single_number = handle_single_number
|
|
645
912
|
|
|
646
913
|
def __call__(self, column: pl.Series) -> pl.Series:
|
|
647
914
|
"""
|
|
648
|
-
Applies the ratio calculation logic to the input column.
|
|
649
|
-
This version uses .str.extract() for maximum stability.
|
|
915
|
+
Applies the ratio calculation logic to the input column. Uses .str.extract() for maximum stability and includes optional handling for zeros and single numbers.
|
|
650
916
|
"""
|
|
917
|
+
column_base_name = column.name
|
|
651
918
|
# Extract numerator (group 1) and denominator (group 2) separately.
|
|
652
919
|
numerator_expr = column.str.extract(self.regex_pattern, 1).cast(pl.Float64, strict=False)
|
|
653
920
|
denominator_expr = column.str.extract(self.regex_pattern, 2).cast(pl.Float64, strict=False)
|
|
654
921
|
|
|
655
|
-
#
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
922
|
+
# --- Logic for Requirement A: Special zero handling ---
|
|
923
|
+
if self.handle_zeros:
|
|
924
|
+
ratio_expr = (
|
|
925
|
+
pl.when(numerator_expr.is_not_null() & denominator_expr.is_not_null())
|
|
926
|
+
.then(
|
|
927
|
+
pl.when((numerator_expr == 0) & (denominator_expr == 0)).then(pl.lit(0.0))
|
|
928
|
+
.when((numerator_expr != 0) & (denominator_expr == 0)).then(numerator_expr)
|
|
929
|
+
.when((numerator_expr == 0) & (denominator_expr != 0)).then(denominator_expr)
|
|
930
|
+
.otherwise(numerator_expr / denominator_expr) # Default: both are non-zero
|
|
931
|
+
)
|
|
932
|
+
)
|
|
933
|
+
else:
|
|
934
|
+
# Original logic
|
|
935
|
+
ratio_expr = pl.when(denominator_expr != 0).then(
|
|
936
|
+
numerator_expr / denominator_expr
|
|
937
|
+
).otherwise(
|
|
938
|
+
None # Handles null denominators and division by zero
|
|
939
|
+
)
|
|
940
|
+
|
|
941
|
+
# --- Logic for Requirement B: Handle single numbers as a fallback ---
|
|
942
|
+
if self.handle_single_number:
|
|
943
|
+
# Regex to match a string that is ONLY a valid float/int
|
|
944
|
+
single_number_regex = r"^\d+\.?\d*$"
|
|
945
|
+
single_number_expr = (
|
|
946
|
+
pl.when(column.str.contains(single_number_regex))
|
|
947
|
+
.then(column.cast(pl.Float64, strict=False))
|
|
948
|
+
.otherwise(None)
|
|
949
|
+
)
|
|
950
|
+
# If ratio_expr is null, try to fill it with single_number_expr
|
|
951
|
+
final_expr = ratio_expr.fill_null(single_number_expr)
|
|
952
|
+
else:
|
|
953
|
+
final_expr = ratio_expr
|
|
954
|
+
|
|
955
|
+
return pl.select(final_expr.round(4)).to_series().alias(column_base_name)
|
|
956
|
+
|
|
957
|
+
|
|
958
|
+
class TriRatioCalculator:
    """
    Transformer that turns three-part ratio strings ("A:B:C") into two numeric ratio columns.

    The output depends on how many numbers are found in each value:
        - Three numbers: both A/B and A/C are computed.
        - Two numbers: the value is read as "A:C"; A/C is computed and A/B is None.
        - Any other count: both outputs are None.
    """
    def __init__(self, handle_zeros: bool = False):
        """
        Initializes the TriRatioCalculator.

        Args:
            handle_zeros (bool): If True, a zero on one side yields the non-zero operand
                and zeros on both sides yield 0.0. If False, a zero denominator yields None.
        """
        self.handle_zeros = handle_zeros

    def _calculate_ratio(self, num: pl.Expr, den: pl.Expr) -> pl.Expr:
        """Builds the division expression for one ratio, rounded to 4 decimals."""
        quotient = num / den

        if not self.handle_zeros:
            # Default behavior: a zero denominator produces null
            return pl.when(den != 0).then(quotient).otherwise(None).round(4)

        # Zero-aware behavior: fall back to whichever operand is non-zero
        num_is_zero = num == 0
        den_is_zero = den == 0
        zero_aware = (
            pl.when(num_is_zero & den_is_zero).then(pl.lit(0.0))
            .when((num != 0) & den_is_zero).then(num)  # only the denominator is zero
            .when(num_is_zero & (den != 0)).then(den)  # only the numerator is zero
            .otherwise(quotient)
        )
        return zero_aware.round(4)

    def __call__(self, column: pl.Series) -> pl.DataFrame:
        """
        Computes the A/B and A/C ratios for every value of the input Series.

        Args:
            column (pl.Series): Series of ratio strings.

        Returns:
            pl.DataFrame: Two columns named "<name>_A_to_B" and "<name>_A_to_C".
        """
        base_name = column.name

        # Every number found in the string, in order of appearance
        numbers = pl.col(base_name).str.extract_all(r"(\d+\.?\d*)")
        count = numbers.list.len()

        def part(index: int) -> pl.Expr:
            # Missing positions become null instead of raising
            return numbers.list.get(index, null_on_oob=True).cast(pl.Float64)

        first, second, third = part(0), part(1), part(2)

        has_three = count == 3
        has_two = count == 2

        a_to_b = pl.when(has_three).then(
            self._calculate_ratio(first, second)
        ).otherwise(None)

        a_to_c = (
            pl.when(has_three).then(self._calculate_ratio(first, third))
            # With only two numbers, the second one plays the role of C
            .when(has_two).then(self._calculate_ratio(first, second))
            .otherwise(None)
        )

        # The expressions need a DataFrame context, so wrap the Series first
        return column.to_frame().select(
            a_to_b.alias(f"{base_name}_A_to_B"),
            a_to_c.alias(f"{base_name}_A_to_C"),
        )
|
|
663
1023
|
|
|
664
1024
|
|
|
665
1025
|
class CategoryMapper:
|
|
@@ -699,6 +1059,7 @@ class CategoryMapper:
|
|
|
699
1059
|
Returns:
|
|
700
1060
|
pl.Series: A new Series with categories mapped to numbers.
|
|
701
1061
|
"""
|
|
1062
|
+
column_base_name = column.name
|
|
702
1063
|
# Ensure the column is treated as a string for matching keys
|
|
703
1064
|
str_column = column.cast(pl.Utf8)
|
|
704
1065
|
|
|
@@ -715,7 +1076,7 @@ class CategoryMapper:
|
|
|
715
1076
|
pl.lit(self.default_value)
|
|
716
1077
|
)
|
|
717
1078
|
|
|
718
|
-
return pl.select(final_expr).to_series()
|
|
1079
|
+
return pl.select(final_expr).to_series().alias(column_base_name)
|
|
719
1080
|
|
|
720
1081
|
|
|
721
1082
|
class RegexMapper:
|
|
@@ -779,6 +1140,7 @@ class RegexMapper:
|
|
|
779
1140
|
pl.Series: A new Series with strings mapped to numbers based on
|
|
780
1141
|
the first matching regex pattern.
|
|
781
1142
|
"""
|
|
1143
|
+
column_base_name = column.name
|
|
782
1144
|
# pl.String is the modern alias for pl.Utf8
|
|
783
1145
|
str_column = column.cast(pl.String)
|
|
784
1146
|
|
|
@@ -793,7 +1155,7 @@ class RegexMapper:
|
|
|
793
1155
|
.otherwise(mapping_expr)
|
|
794
1156
|
)
|
|
795
1157
|
|
|
796
|
-
return pl.select(mapping_expr).to_series()
|
|
1158
|
+
return pl.select(mapping_expr).to_series().alias(column_base_name)
|
|
797
1159
|
|
|
798
1160
|
|
|
799
1161
|
class ValueBinner:
|
|
@@ -843,6 +1205,7 @@ class ValueBinner:
|
|
|
843
1205
|
pl.Series: A new Series of integer labels for the bins. Values
|
|
844
1206
|
outside the specified breaks will become null.
|
|
845
1207
|
"""
|
|
1208
|
+
column_base_name = column.name
|
|
846
1209
|
# `cut` creates a new column of type Categorical
|
|
847
1210
|
binned_column = column.cut(
|
|
848
1211
|
breaks=self.breaks,
|
|
@@ -852,7 +1215,7 @@ class ValueBinner:
|
|
|
852
1215
|
|
|
853
1216
|
# to_physical() converts the Categorical type to its underlying
|
|
854
1217
|
# integer representation (u32), which is perfect for ML.
|
|
855
|
-
return binned_column.to_physical()
|
|
1218
|
+
return binned_column.to_physical().alias(column_base_name)
|
|
856
1219
|
|
|
857
1220
|
|
|
858
1221
|
class DateFeatureExtractor:
|
|
@@ -861,16 +1224,6 @@ class DateFeatureExtractor:
|
|
|
861
1224
|
|
|
862
1225
|
It can handle columns that are already in a Polars Date/Datetime format,
|
|
863
1226
|
or it can parse string columns if a format is provided.
|
|
864
|
-
|
|
865
|
-
Args:
|
|
866
|
-
features (List[str]):
|
|
867
|
-
A list of the date/time features to extract. Supported features are:
|
|
868
|
-
'year', 'month', 'day', 'hour', 'minute', 'second', 'millisecond',
|
|
869
|
-
'microsecond', 'nanosecond', 'ordinal_day' (day of year),
|
|
870
|
-
'weekday' (Mon=1, Sun=7), 'week' (week of year), and 'timestamp'.
|
|
871
|
-
format (str | None):
|
|
872
|
-
The format code used to parse string dates (e.g., "%Y-%m-%d %H:%M:%S").
|
|
873
|
-
Use if the input column is not a Date or Datetime type.
|
|
874
1227
|
"""
|
|
875
1228
|
|
|
876
1229
|
ALLOWED_FEATURES = {
|
|
@@ -883,6 +1236,17 @@ class DateFeatureExtractor:
|
|
|
883
1236
|
features: List[str],
|
|
884
1237
|
format: Optional[str] = None,
|
|
885
1238
|
):
|
|
1239
|
+
"""
|
|
1240
|
+
Args:
|
|
1241
|
+
features (List[str]):
|
|
1242
|
+
A list of the date/time features to extract. Supported features are:
|
|
1243
|
+
'year', 'month', 'day', 'hour', 'minute', 'second', 'millisecond',
|
|
1244
|
+
'microsecond', 'nanosecond', 'ordinal_day' (day of year),
|
|
1245
|
+
'weekday' (Mon=1, Sun=7), 'week' (week of year), 'timestamp'.
|
|
1246
|
+
format (str | None):
|
|
1247
|
+
The format code used to parse string dates (e.g., "%Y-%m-%d %H:%M:%S").
|
|
1248
|
+
Use if the input column is not a Date or Datetime type.
|
|
1249
|
+
"""
|
|
886
1250
|
# --- Validation ---
|
|
887
1251
|
if not isinstance(features, list) or not features:
|
|
888
1252
|
_LOGGER.error("'features' must be a non-empty list of strings.")
|
|
@@ -906,6 +1270,7 @@ class DateFeatureExtractor:
|
|
|
906
1270
|
Returns:
|
|
907
1271
|
pl.DataFrame: A DataFrame with columns for each extracted feature.
|
|
908
1272
|
"""
|
|
1273
|
+
column_base_name = column.name
|
|
909
1274
|
date_col = column
|
|
910
1275
|
# First, parse strings into a datetime object if a format is given
|
|
911
1276
|
if self.format is not None:
|
|
@@ -921,10 +1286,81 @@ class DateFeatureExtractor:
|
|
|
921
1286
|
expr = getattr(date_col.dt, feature)()
|
|
922
1287
|
|
|
923
1288
|
# Alias with a generic name for the processor to handle
|
|
924
|
-
output_expressions.append(expr.alias(f"
|
|
1289
|
+
output_expressions.append(expr.alias(f"{column_base_name}_{feature}"))
|
|
925
1290
|
|
|
926
1291
|
return pl.select(output_expressions)
|
|
927
1292
|
|
|
928
1293
|
|
|
1294
|
+
class MolecularFormulaTransformer:
    """
    Parses a Polars Series of molecular formula strings into a wide DataFrame.

    This one-to-many transformer takes a column of condensed molecular formulas
    (e.g., 'Li0.115Mn0.529Ni0.339O2') and converts it into a DataFrame where
    each chemical element has its own column. The value in each column is the
    stoichiometric quantity of that element; elements absent from a formula get 0.0.

    Formulas are read as a flat sequence of "symbol + optional quantity" pairs.
    A symbol without a quantity counts as 1.0, and repeated symbols are summed.
    Parentheses and other grouping notations are not interpreted.

    It is designed to be used within the DataProcessor pipeline.
    """

    def __init__(self):
        """
        Initializes the transformer and pre-compiles the regex pattern.
        """
        # Sort symbols by length to prevent matching 'C' in 'Co'
        sorted_symbols = sorted(CHEMICAL_ELEMENT_SYMBOLS, key=len, reverse=True)

        # Group 1 is the element symbol, group 2 the optional quantity.
        # The quantity must contain at least one digit, so a stray '.' is never
        # captured and passed to float().
        self.pattern = re.compile(
            rf'({"|".join(sorted_symbols)})(\d+\.?\d*|\.\d+)?'
        )

    def _parse_formula(self, formula) -> dict:
        """Parses a single formula into an {element: quantity} dictionary."""
        # Nulls, non-strings and empty strings contribute no elements
        if not isinstance(formula, str) or not formula:
            return {}

        quantities = {}
        for element, value in self.pattern.findall(formula):
            amount = float(value) if value else 1.0
            # Accumulate so repeated symbols (e.g. 'CH3CH2OH') are not overwritten
            quantities[element] = quantities.get(element, 0.0) + amount
        return quantities

    def __call__(self, column: pl.Series) -> pl.DataFrame:
        """
        Executes the formula parsing logic.

        Args:
            column: A Polars Series containing strings of molecular formulas.

        Returns:
            A Polars DataFrame with one Float64 column per chemical element,
            named "<column name>_<symbol>", and exactly one row per input value.
        """
        column_base_name = column.name
        parsed_rows = [self._parse_formula(formula) for formula in column.to_list()]

        # Build every element column explicitly instead of constructing a frame
        # from the row dictionaries: that avoids depending on row-based schema
        # inference and keeps the row count equal to the input length even when
        # no formula contains any recognised element.
        data = {
            f"{column_base_name}_{symbol}": [row.get(symbol, 0.0) for row in parsed_rows]
            for symbol in CHEMICAL_ELEMENT_SYMBOLS
        }

        # Explicit schema keeps Float64 columns even for an empty input Series
        return pl.DataFrame(data, schema={name: pl.Float64 for name in data})
|
|
1363
|
+
|
|
1364
|
+
|
|
929
1365
|
def info():
    """Passes this module's `__all__` list to `_script_info`."""
    _script_info(__all__)
|