nextrec 0.4.33-py3-none-any.whl → 0.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -18
- nextrec/basic/asserts.py +1 -22
- nextrec/basic/callback.py +2 -2
- nextrec/basic/features.py +6 -37
- nextrec/basic/heads.py +13 -1
- nextrec/basic/layers.py +33 -123
- nextrec/basic/loggers.py +3 -2
- nextrec/basic/metrics.py +85 -4
- nextrec/basic/model.py +518 -7
- nextrec/basic/summary.py +88 -42
- nextrec/cli.py +117 -30
- nextrec/data/data_processing.py +8 -13
- nextrec/data/preprocessor.py +449 -844
- nextrec/loss/grad_norm.py +78 -76
- nextrec/models/multi_task/ple.py +1 -0
- nextrec/models/multi_task/share_bottom.py +1 -0
- nextrec/models/ranking/afm.py +4 -9
- nextrec/models/ranking/dien.py +7 -8
- nextrec/models/ranking/ffm.py +2 -2
- nextrec/models/retrieval/sdm.py +1 -2
- nextrec/models/sequential/hstu.py +0 -2
- nextrec/models/tree_base/base.py +1 -1
- nextrec/utils/__init__.py +2 -1
- nextrec/utils/config.py +1 -1
- nextrec/utils/console.py +1 -1
- nextrec/utils/onnx_utils.py +252 -0
- nextrec/utils/torch_utils.py +63 -56
- nextrec/utils/types.py +43 -0
- {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/METADATA +10 -4
- {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/RECORD +34 -42
- nextrec/models/multi_task/[pre]star.py +0 -192
- nextrec/models/representation/autorec.py +0 -0
- nextrec/models/representation/bpr.py +0 -0
- nextrec/models/representation/cl4srec.py +0 -0
- nextrec/models/representation/lightgcn.py +0 -0
- nextrec/models/representation/mf.py +0 -0
- nextrec/models/representation/s3rec.py +0 -0
- nextrec/models/sequential/sasrec.py +0 -0
- nextrec/utils/feature.py +0 -29
- {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/WHEEL +0 -0
- {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/licenses/LICENSE +0 -0
nextrec/data/preprocessor.py
CHANGED
@@ -2,7 +2,7 @@
 DataProcessor for data preprocessing including numeric, sparse, sequence features and target processing.
 
 Date: create on 13/11/2025
-Checkpoint: edit on
+Checkpoint: edit on 28/01/2026
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -13,14 +13,12 @@ import logging
 import os
 import pickle
 from pathlib import Path
-from typing import Any, Dict, Literal, Optional, Union, overload
+from typing import Any, Dict, Iterable, Literal, Optional, Union, overload
 
 import numpy as np
 import pandas as pd
-import
-import pyarrow.parquet as pq
+import polars as pl
 from sklearn.preprocessing import (
-    LabelEncoder,
     MaxAbsScaler,
     MinMaxScaler,
     RobustScaler,
@@ -37,9 +35,6 @@ from nextrec.utils.data import (
     FILE_FORMAT_CONFIG,
     check_streaming_support,
     default_output_dir,
-    iter_file_chunks,
-    load_dataframes,
-    read_table,
     resolve_file_paths,
 )
 
@@ -63,7 +58,7 @@ class DataProcessor(FeatureSet):
         self.is_fitted = False
 
         self.scalers: Dict[str, Any] = {}
-        self.label_encoders: Dict[str,
+        self.label_encoders: Dict[str, Any] = {}
         self.target_encoders: Dict[str, Dict[str, int]] = {}
         self.set_target_id(target=[], id_columns=[])
 
@@ -186,318 +181,228 @@ class DataProcessor(FeatureSet):
     def hash_string(self, s: str, hash_size: int) -> int:
         return self.hash_fn(str(s), int(hash_size))
 
-    def
-        if
-        logger = logging.getLogger()
-        name = str(data.name)
-        scaler_type = config["scaler"]
-        fill_na_value = config.get("fill_na_value", 0)
-        filled_data = data.fillna(fill_na_value)
-        values = np.array(filled_data.values, dtype=np.float64)
-        if scaler_type == "log":
-            result = np.log1p(np.maximum(values, 0))
-        elif scaler_type == "none":
-            result = values
+    def polars_scan(self, file_paths: list[str], file_type: str):
+        file_type = file_type.lower()
+        if file_type == "csv":
+            return pl.scan_csv(file_paths, ignore_errors=True)
+        if file_type == "parquet":
+            return pl.scan_parquet(file_paths)
+        raise ValueError(
+            f"[Data Processor Error] Polars backend only supports csv/parquet, got: {file_type}"
+        )
+
+    def sequence_expr(
+        self, pl, name: str, config: Dict[str, Any], schema: Dict[str, Any]
+    ):
+        """
+        generate polars expression for sequence feature processing
+
+        Example Input:
+            sequence_str: "1,2,3"
+            sequence_str: " 4, ,5 "
+            sequence_list: ["7", "8", "9"]
+            sequence_list: ["", "10", " 11 "]
+
+        Example Output:
+            sequence_str -> ["1","2","3"]
+            sequence_str -> ["4","5"]
+            sequence_list -> ["7","8","9"]
+            sequence_list -> ["10","11"]
+        """
+        separator = config["separator"]
+        dtype = schema.get(name)
+        col = pl.col(name)
+        if dtype is not None and isinstance(dtype, pl.List):
+            seq_col = col
         else:
+            seq_col = col.cast(pl.Utf8).fill_null("").str.split(separator)
+        elem = pl.element().cast(pl.Utf8).str.strip_chars()
+        seq_col = seq_col.list.eval(
+            pl.when(elem == "").then(None).otherwise(elem)
+        ).list.drop_nulls()
+        return seq_col
+
+    def apply_transforms(self, lf, schema: Dict[str, Any], warn_missing: bool):
+        """
+        Apply all transformations to a Polars LazyFrame.
 
+        """
         logger = logging.getLogger()
+        expressions = []
 
-        if min_freq is not None:
-            counts = filled_data.value_counts()
-            config["_token_counts"] = counts.to_dict()
-            vocab = sorted(counts[counts >= min_freq].index.tolist())
-            low_freq_types = int((counts < min_freq).sum())
-            total_types = int(counts.size)
-            kept_types = total_types - low_freq_types
-            if not config.get("_min_freq_logged"):
-                logger.info(
-                    f"Sparse feature {data.name} min_freq={min_freq}: "
-                    f"{total_types} token types total, "
-                    f"{low_freq_types} low-frequency, "
-                    f"{kept_types} kept."
-                )
-                config["_min_freq_logged"] = True
-        else:
-            vocab = sorted(set(filled_data.tolist()))
-        if "<UNK>" not in vocab:
-            vocab.append("<UNK>")
-        token_to_idx = {token: idx for idx, token in enumerate(vocab)}
-        config["_token_to_idx"] = token_to_idx
-        config["_unk_index"] = token_to_idx["<UNK>"]
-        config["vocab_size"] = len(vocab)
-        elif encode_method == "hash":
-            min_freq = config.get("min_freq")
-            if min_freq is not None:
-                counts = filled_data.value_counts()
-                config["_token_counts"] = counts.to_dict()
-                config["_unk_hash"] = self.hash_string(
-                    "<UNK>", int(config["hash_size"])
-                )
-                low_freq_types = int((counts < min_freq).sum())
-                total_types = int(counts.size)
-                kept_types = total_types - low_freq_types
-                if not config.get("_min_freq_logged"):
-                    logger.info(
-                        f"Sparse feature {data.name} min_freq={min_freq}: "
-                        f"{total_types} token types total, "
-                        f"{low_freq_types} low-frequency, "
-                        f"{kept_types} kept."
-                    )
-                    config["_min_freq_logged"] = True
-            config["vocab_size"] = config["hash_size"]
-
-    def process_sparse_feature_transform(
-        self, data: pd.Series, config: Dict[str, Any]
-    ) -> np.ndarray:
-        name = str(data.name)
-        encode_method = config["encode_method"]
-        fill_na = config["fill_na"]
-
-        sparse_series = (
-            data if isinstance(data, pd.Series) else pd.Series(data, name=name)
-        )
-        sparse_series = sparse_series.fillna(fill_na).astype(str)
-        if encode_method == "label":
-            token_to_idx = config.get("_token_to_idx")
-            if isinstance(token_to_idx, dict):
-                unk_index = int(config.get("_unk_index", 0))
-                return np.fromiter(
-                    (token_to_idx.get(v, unk_index) for v in sparse_series.to_numpy()),
-                    dtype=np.int64,
-                    count=sparse_series.size,
-                )
-            raise ValueError(
-                f"[Data Processor Error] Token index for {name} not fitted"
+        def map_with_default(expr, mapping: Dict[str, int], default: int, dtype):
+            # Compatible with older polars versions without Expr.map_dict
+            return expr.map_elements(
+                lambda x: mapping.get(x, default),
+                return_dtype=dtype,
             )
 
+        def ensure_present(feature_name: str, label: str) -> bool:
+            if feature_name not in schema:
+                if warn_missing:
+                    logger.warning(f"{label} feature {feature_name} not found in data")
+                return False
+            return True
+
+        # Numeric features
+        for name, config in self.numeric_features.items():
+            if not ensure_present(name, "Numeric"):
+                continue
+            scaler_type = config["scaler"]
+            fill_na_value = config.get("fill_na_value", 0)
+            col = pl.col(name).cast(pl.Float64).fill_null(fill_na_value)
+            if scaler_type == "log":
+                col = col.clip(lower_bound=0).log1p()
+            elif scaler_type == "none":
+                pass
+            else:
+                scaler = self.scalers.get(name)
+                if scaler is None:
+                    logger.warning(
+                        f"Scaler for {name} not fitted, returning original values"
                     )
+                else:
+                    if scaler_type == "standard":
+                        mean = float(scaler.mean_[0])
+                        scale = (
+                            float(scaler.scale_[0]) if scaler.scale_[0] != 0 else 1.0
+                        )
+                        col = (col - mean) / scale
+                    elif scaler_type == "minmax":
+                        scale = float(scaler.scale_[0])
+                        min_val = float(scaler.min_[0])
+                        col = col * scale + min_val
+                    elif scaler_type == "maxabs":
+                        max_abs = float(scaler.max_abs_[0]) or 1.0
+                        col = col / max_abs
+                    elif scaler_type == "robust":
+                        center = float(scaler.center_[0])
+                        scale = (
+                            float(scaler.scale_[0]) if scaler.scale_[0] != 0 else 1.0
+                        )
+                        col = (col - center) / scale
+            expressions.append(col.alias(name))
 
-                key = str(token)
-                token_counts[key] = token_counts.get(key, 0) + 1
-            if min_freq is not None:
-                config["_token_counts"] = token_counts
-                vocab = sorted([k for k, v in token_counts.items() if v >= min_freq])
-                low_freq_types = sum(
-                    1 for count in token_counts.values() if count < min_freq
-                )
-                total_types = len(token_counts)
-                kept_types = total_types - low_freq_types
-                if not config.get("_min_freq_logged"):
-                    logger.info(
-                        f"Sequence feature {data.name} min_freq={min_freq}: "
-                        f"{total_types} token types total, "
-                        f"{low_freq_types} low-frequency, "
-                        f"{kept_types} kept."
+        # Sparse features
+        for name, config in self.sparse_features.items():
+            if not ensure_present(name, "Sparse"):
+                continue
+            encode_method = config["encode_method"]
+            fill_na = config["fill_na"]
+            col = pl.col(name).cast(pl.Utf8).fill_null(fill_na)
+            if encode_method == "label":
+                token_to_idx = config.get("_token_to_idx")
+                if not isinstance(token_to_idx, dict):
+                    raise ValueError(
+                        f"[Data Processor Error] Token index for {name} not fitted"
                     )
-                tokens = self.extract_sequence_tokens(seq, separator)
-                for token in tokens:
-                    if str(token).strip():
-                        token_counts[str(token)] = (
-                            token_counts.get(str(token), 0) + 1
-                        )
-                config["_token_counts"] = token_counts
-                config["_unk_hash"] = self.hash_string(
-                    "<UNK>", int(config["hash_size"])
-                )
-                low_freq_types = sum(
-                    1 for count in token_counts.values() if count < min_freq
-                )
-                total_types = len(token_counts)
-                kept_types = total_types - low_freq_types
-                if not config.get("_min_freq_logged"):
-                    logger.info(
-                        f"Sequence feature {data.name} min_freq={min_freq}: "
-                        f"{total_types} token types total, "
-                        f"{low_freq_types} low-frequency, "
-                        f"{kept_types} kept."
+                unk_index = int(config.get("_unk_index", 0))
+                col = map_with_default(col, token_to_idx, unk_index, pl.Int64)
+            elif encode_method == "hash":
+                hash_size = config["hash_size"]
+                hash_expr = col.hash().cast(pl.UInt64) % int(hash_size)
+                min_freq = config.get("min_freq")
+                token_counts = config.get("_token_counts")
+                if min_freq is not None and isinstance(token_counts, dict):
+                    low_freq = [k for k, v in token_counts.items() if v < min_freq]
+                    unk_hash = config.get("_unk_hash")
+                    if unk_hash is None:
+                        unk_hash = self.hash_string("<UNK>", int(hash_size))
+                    hash_expr = (
+                        pl.when(col.is_in(low_freq))
+                        .then(int(unk_hash))
+                        .otherwise(hash_expr)
                     )
-        arr = np.asarray(data, dtype=object)
-        n = arr.shape[0]
-        output = np.full((n, max_len), pad_value, dtype=np.int64)
-        # Shared helpers cached locally for speed and cross-platform consistency
-        split_fn = str.split
-        is_nan = np.isnan
-        if encode_method == "label":
-            class_to_idx = config.get("_token_to_idx")
-            if class_to_idx is None:
-                raise ValueError(
-                    f"[Data Processor Error] Token index for {name} not fitted"
-                )
-            unk_index = int(config.get("_unk_index", class_to_idx.get("<UNK>", 0)))
-        else:
-            class_to_idx = None  # type: ignore
-            unk_index = 0
-            hash_fn = self.hash_string
-            hash_size = config.get("hash_size")
-            min_freq = config.get("min_freq")
-            token_counts = config.get("_token_counts")
-            if min_freq is not None and isinstance(token_counts, dict):
-                unk_hash = config.get("_unk_hash")
-                if unk_hash is None and hash_size is not None:
-                    unk_hash = hash_fn("<UNK>", hash_size)
-        for i, seq in enumerate(arr):
-            # normalize sequence to a list of strings
-            tokens = []
-            if seq is None:
-                tokens = []
-            elif isinstance(seq, (float, np.floating)):
-                tokens = [] if is_nan(seq) else [str(seq)]
-            elif isinstance(seq, str):
-                seq_str = seq.strip()
-                tokens = [] if not seq_str else split_fn(seq_str, separator)
-            elif isinstance(seq, (list, tuple, np.ndarray)):
-                tokens = [str(t) for t in seq]
-            else:
-                tokens = []
+                col = hash_expr.cast(pl.Int64)
+            expressions.append(col.alias(name))
+
+        # Sequence features
+        for name, config in self.sequence_features.items():
+            if not ensure_present(name, "Sequence"):
+                continue
+            encode_method = config["encode_method"]
+            max_len = int(config["max_len"])
+            pad_value = int(config["pad_value"])
+            truncate = config["truncate"]
+            seq_col = self.sequence_expr(pl, name, config, schema)
+
             if encode_method == "label":
+                token_to_idx = config.get("_token_to_idx")
+                if not isinstance(token_to_idx, dict):
+                    raise ValueError(
+                        f"[Data Processor Error] Token index for {name} not fitted"
+                    )
+                unk_index = int(config.get("_unk_index", 0))
+                seq_col = seq_col.list.eval(
+                    map_with_default(pl.element(), token_to_idx, unk_index, pl.Int64)
+                )
             elif encode_method == "hash":
+                hash_size = config.get("hash_size")
                 if hash_size is None:
                     raise ValueError(
                         "[Data Processor Error] hash_size must be set for hash encoding"
                     )
+                elem = pl.element().cast(pl.Utf8)
+                hash_expr = elem.hash().cast(pl.UInt64) % int(hash_size)
+                min_freq = config.get("min_freq")
+                token_counts = config.get("_token_counts")
+                if min_freq is not None and isinstance(token_counts, dict):
+                    low_freq = [k for k, v in token_counts.items() if v < min_freq]
+                    unk_hash = config.get("_unk_hash")
+                    if unk_hash is None:
+                        unk_hash = self.hash_string("<UNK>", int(hash_size))
+                    hash_expr = (
+                        pl.when(elem.is_in(low_freq))
+                        .then(int(unk_hash))
+                        .otherwise(hash_expr)
                     )
+                seq_col = seq_col.list.eval(hash_expr)
+
+            if truncate == "pre":
+                seq_col = seq_col.list.tail(max_len)
             else:
+                seq_col = seq_col.list.head(max_len)
+            pad_list = [pad_value] * max_len
+            seq_col = pl.concat_list([seq_col, pl.lit(pad_list)]).list.head(max_len)
+            expressions.append(seq_col.alias(name))
+
+        # Target features
+        for name, config in self.target_features.items():
+            if not ensure_present(name, "Target"):
                 continue
+            target_type = config.get("target_type")
+            col = pl.col(name)
+            if target_type == "regression":
+                col = col.cast(pl.Float32)
+            elif target_type == "binary":
+                label_map = self.target_encoders.get(name)
+                if label_map is None:
+                    raise ValueError(
+                        f"[Data Processor Error] Target encoder for {name} not fitted"
+                    )
+                col = map_with_default(col.cast(pl.Utf8), label_map, 0, pl.Int64).cast(
+                    pl.Float32
+                )
+            else:
+                raise ValueError(
+                    f"[Data Processor Error] Unsupported target type: {target_type}"
+                )
+            expressions.append(col.alias(name))
+
+        if not expressions:
+            return lf
+        return lf.with_columns(expressions)
 
-    def process_target_fit(
+    def process_target_fit(
+        self, data: Iterable[Any], config: Dict[str, Any], name: str
+    ) -> None:
         target_type = config["target_type"]
         label_map = config.get("label_map")
         if target_type == "binary":
             if label_map is None:
-                unique_values = data
+                unique_values = {v for v in data if v is not None}
+                # Filter out None values before sorting to avoid comparison errors
+                sorted_values = sorted(v for v in unique_values if v is not None)
                 try:
                     int_values = [int(v) for v in sorted_values]
                     if int_values == list(range(len(int_values))):
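The new sequence_expr helper normalizes both delimited strings and native list columns into cleaned token lists before any encoding happens. Below is a minimal standalone sketch of the same polars expression pattern; the column name "hist", the comma separator, and the sample rows are illustrative and not taken from the package.

import polars as pl

df = pl.DataFrame({"hist": ["1,2,3", " 4, ,5 ", None]})
# Split on the separator, strip whitespace from every token, turn empty
# tokens into nulls, then drop the nulls, mirroring sequence_expr above.
elem = pl.element().cast(pl.Utf8).str.strip_chars()
cleaned = (
    pl.col("hist")
    .cast(pl.Utf8)
    .fill_null("")
    .str.split(",")
    .list.eval(pl.when(elem == "").then(None).otherwise(elem))
    .list.drop_nulls()
)
print(df.select(cleaned.alias("hist")))
# rows become ["1", "2", "3"], ["4", "5"], and []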
@@ -511,254 +416,149 @@ class DataProcessor(FeatureSet):
             config["label_map"] = label_map
             self.target_encoders[name] = label_map
 
-    def
-        self, data: pd.Series, config: Dict[str, Any]
-    ) -> np.ndarray:
+    def polars_fit_from_lazy(self, lf, schema: Dict[str, Any]) -> "DataProcessor":
         logger = logging.getLogger()
-        name = str(data.name)
-        target_type = config.get("target_type")
-        if target_type == "regression":
-            values = np.array(data.values, dtype=np.float32)
-            return values
-        if target_type == "binary":
-            label_map = self.target_encoders.get(name)
-            if label_map is None:
-                raise ValueError(
-                    f"[Data Processor Error] Target encoder for {name} not fitted"
-                )
-            result = []
-            for val in data:
-                str_val = str(val)
-                if str_val in label_map:
-                    result.append(label_map[str_val])
-                else:
-                    logger.warning(f"Unknown target value: {val}, mapping to 0")
-                    result.append(0)
-            return np.array(result, dtype=np.float32)
-        raise ValueError(
-            f"[Data Processor Error] Unsupported target type: {target_type}"
-        )
-
-    def load_dataframe_from_path(self, path: str) -> pd.DataFrame:
-        """
-        Load all data from a file or directory path into a single DataFrame.
-
-        Args:
-            path (str): File or directory path.
-
-        Returns:
-            pd.DataFrame: Loaded DataFrame.
-        """
-        file_paths, file_type = resolve_file_paths(path)
-        frames = load_dataframes(file_paths, file_type)
-        return pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]
-
-    def extract_sequence_tokens(self, value: Any, separator: str) -> list[str]:
-        """Extract sequence tokens from a single value."""
-        if value is None:
-            return []
-        if isinstance(value, (float, np.floating)) and np.isnan(value):
-            return []
-        if isinstance(value, str):
-            stripped = value.strip()
-            return [] if not stripped else stripped.split(separator)
-        if isinstance(value, (list, tuple, np.ndarray)):
-            return [str(v) for v in value]
-        return [str(value)]
-
-    def fit_from_file_paths(
-        self, file_paths: list[str], file_type: str, chunk_size: int
-    ) -> "DataProcessor":
-        logger = logging.getLogger()
-        if not file_paths:
-            raise ValueError("[DataProcessor Error] Empty file list for streaming fit")
-        if not check_streaming_support(file_type):
-            raise ValueError(
-                f"[DataProcessor Error] Format '{file_type}' does not support streaming. "
-                "Streaming fit only supports csv, parquet to avoid high memory usage."
-            )
-
-        numeric_acc = {}
-        for name in self.numeric_features.keys():
-            numeric_acc[name] = {
-                "sum": 0.0,
-                "sumsq": 0.0,
-                "count": 0.0,
-                "min": np.inf,
-                "max": -np.inf,
-                "max_abs": 0.0,
-            }
-        sparse_vocab: Dict[str, set[str]] = {
-            name: set() for name in self.sparse_features.keys()
-        }
-        seq_vocab: Dict[str, set[str]] = {
-            name: set() for name in self.sequence_features.keys()
-        }
-        sparse_label_counts: Dict[str, Dict[str, int]] = {
-            name: {}
-            for name, config in self.sparse_features.items()
-            if config.get("encode_method") == "label" and config.get("min_freq")
-        }
-        seq_label_counts: Dict[str, Dict[str, int]] = {
-            name: {}
-            for name, config in self.sequence_features.items()
-            if config.get("encode_method") == "label" and config.get("min_freq")
-        }
-        sparse_hash_counts: Dict[str, Dict[str, int]] = {
-            name: {}
-            for name, config in self.sparse_features.items()
-            if config.get("encode_method") == "hash" and config.get("min_freq")
-        }
-        seq_hash_counts: Dict[str, Dict[str, int]] = {
-            name: {}
-            for name, config in self.sequence_features.items()
-            if config.get("encode_method") == "hash" and config.get("min_freq")
-        }
-        target_values: Dict[str, set[Any]] = {
-            name: set() for name in self.target_features.keys()
-        }
 
         missing_features = set()
-        for
-                series = chunk[name]
-                if group == "numeric":
-                    values = pd.to_numeric(series, errors="coerce").dropna()
-                    if values.empty:
-                        continue
-                    acc = numeric_acc[name]
-                    arr = values.to_numpy(dtype=np.float64, copy=False)
-                    acc["count"] += arr.size
-                    acc["sum"] += float(arr.sum())
-                    acc["sumsq"] += float(np.square(arr).sum())
-                    acc["min"] = min(acc["min"], float(arr.min()))
-                    acc["max"] = max(acc["max"], float(arr.max()))
-                    acc["max_abs"] = max(
-                        acc["max_abs"], float(np.abs(arr).max())
-                    )
-                elif group == "sparse":
-                    fill_na = config["fill_na"]
-                    series = series.fillna(fill_na).astype(str)
-                    sparse_vocab[name].update(series.tolist())
-                    if name in sparse_label_counts:
-                        counts = sparse_label_counts[name]
-                        for token in series.tolist():
-                            counts[token] = counts.get(token, 0) + 1
-                    if name in sparse_hash_counts:
-                        counts = sparse_hash_counts[name]
-                        for token in series.tolist():
-                            counts[token] = counts.get(token, 0) + 1
-                else:
-                    separator = config["separator"]
-                    tokens = []
-                    for val in series:
-                        tokens.extend(
-                            self.extract_sequence_tokens(val, separator)
-                        )
-                    seq_vocab[name].update(tokens)
-                    if name in seq_label_counts:
-                        counts = seq_label_counts[name]
-                        for token in tokens:
-                            if str(token).strip():
-                                key = str(token)
-                                counts[key] = counts.get(key, 0) + 1
-                    if name in seq_hash_counts:
-                        counts = seq_hash_counts[name]
-                        for token in tokens:
-                            if str(token).strip():
-                                key = str(token)
-                                counts[key] = counts.get(key, 0) + 1
-
-            # target features
-            missing_features.update(self.target_features.keys() - columns)
-            for name in self.target_features.keys() & columns:
-                vals = chunk[name].dropna().tolist()
-                target_values[name].update(vals)
-
+        for name in self.numeric_features.keys():
+            if name not in schema:
+                missing_features.add(name)
+        for name in self.sparse_features.keys():
+            if name not in schema:
+                missing_features.add(name)
+        for name in self.sequence_features.keys():
+            if name not in schema:
+                missing_features.add(name)
+        for name in self.target_features.keys():
+            if name not in schema:
+                missing_features.add(name)
         if missing_features:
             logger.warning(
-                f"The following configured features were not found in provided
+                f"The following configured features were not found in provided data: {sorted(missing_features)}"
             )
 
-        #
+        # numeric aggregates in a single pass
+        if self.numeric_features:
+            agg_exprs = []
+            for name in self.numeric_features.keys():
+                if name not in schema:
+                    continue
+                col = pl.col(name).cast(pl.Float64)
+                agg_exprs.extend(
+                    [
+                        col.sum().alias(f"{name}__sum"),
+                        (col * col).sum().alias(f"{name}__sumsq"),
+                        col.count().alias(f"{name}__count"),
+                        col.min().alias(f"{name}__min"),
+                        col.max().alias(f"{name}__max"),
+                        col.abs().max().alias(f"{name}__max_abs"),
+                    ]
+                )
+                if self.numeric_features[name].get("scaler") == "robust":
+                    agg_exprs.extend(
+                        [
+                            col.quantile(0.25).alias(f"{name}__q1"),
+                            col.quantile(0.75).alias(f"{name}__q3"),
+                            col.median().alias(f"{name}__median"),
+                        ]
+                    )
+            stats = lf.select(agg_exprs).collect().to_dicts()[0] if agg_exprs else {}
+        else:
+            stats = {}
+
         for name, config in self.numeric_features.items():
+            if name not in schema:
+                continue
+            count = float(stats.get(f"{name}__count", 0) or 0)
+            if count == 0:
                 logger.warning(
-                    f"Numeric feature {name} has no valid values in provided
+                    f"Numeric feature {name} has no valid values in provided data"
                 )
                 continue
+            sum_val = float(stats.get(f"{name}__sum", 0) or 0)
+            sumsq = float(stats.get(f"{name}__sumsq", 0) or 0)
+            mean_val = sum_val / count
             if config["fill_na"] is not None:
                 config["fill_na_value"] = config["fill_na"]
             else:
                 config["fill_na_value"] = mean_val
             scaler_type = config["scaler"]
             if scaler_type == "standard":
-                var = max(
+                var = max(sumsq / count - mean_val * mean_val, 0.0)
                 scaler = StandardScaler()
                 scaler.mean_ = np.array([mean_val], dtype=np.float64)
                 scaler.var_ = np.array([var], dtype=np.float64)
                 scaler.scale_ = np.array(
                     [np.sqrt(var) if var > 0 else 1.0], dtype=np.float64
                 )
-                scaler.n_samples_seen_ = np.array([int(
+                scaler.n_samples_seen_ = np.array([int(count)], dtype=np.int64)
                 self.scalers[name] = scaler
             elif scaler_type == "minmax":
-                data_min =
-                data_max =
+                data_min = float(stats.get(f"{name}__min", 0) or 0)
+                data_max = float(stats.get(f"{name}__max", data_min) or data_min)
                 scaler = MinMaxScaler()
                 scaler.data_min_ = np.array([data_min], dtype=np.float64)
                 scaler.data_max_ = np.array([data_max], dtype=np.float64)
                 scaler.data_range_ = scaler.data_max_ - scaler.data_min_
                 scaler.data_range_[scaler.data_range_ == 0] = 1.0
-                # Manually set scale_/min_ for streaming fit to mirror sklearn's internal fit logic
                 feature_min, feature_max = scaler.feature_range
                 scale = (feature_max - feature_min) / scaler.data_range_
                 scaler.scale_ = scale
                 scaler.min_ = feature_min - scaler.data_min_ * scale
-                scaler.n_samples_seen_ = np.array([int(
+                scaler.n_samples_seen_ = np.array([int(count)], dtype=np.int64)
                 self.scalers[name] = scaler
             elif scaler_type == "maxabs":
+                max_abs = float(stats.get(f"{name}__max_abs", 1.0) or 1.0)
                 scaler = MaxAbsScaler()
-                scaler.max_abs_ = np.array([
-                scaler.n_samples_seen_ = np.array([int(
+                scaler.max_abs_ = np.array([max_abs], dtype=np.float64)
+                scaler.n_samples_seen_ = np.array([int(count)], dtype=np.int64)
                 self.scalers[name] = scaler
+            elif scaler_type == "robust":
+                q1 = float(stats.get(f"{name}__q1", 0) or 0)
+                q3 = float(stats.get(f"{name}__q3", q1) or q1)
+                median = float(stats.get(f"{name}__median", 0) or 0)
+                scale = q3 - q1
+                if scale == 0:
+                    scale = 1.0
+                scaler = RobustScaler()
+                scaler.center_ = np.array([median], dtype=np.float64)
+                scaler.scale_ = np.array([scale], dtype=np.float64)
+                scaler.n_samples_seen_ = np.array([int(count)], dtype=np.int64)
+                self.scalers[name] = scaler
+            elif scaler_type in ("log", "none"):
                 continue
             else:
                 raise ValueError(f"Unknown scaler type: {scaler_type}")
 
-        #
+        # sparse features
         for name, config in self.sparse_features.items():
-            if
+            if name not in schema:
+                continue
+            encode_method = config["encode_method"]
+            fill_na = config["fill_na"]
+            col = pl.col(name).cast(pl.Utf8).fill_null(fill_na)
+            counts_df = (
+                lf.select(col.alias(name))
+                .group_by(name)
+                .agg(pl.len().alias("count"))
+                .collect()
+            )
+            counts = (
+                dict(zip(counts_df[name].to_list(), counts_df["count"].to_list()))
+                if counts_df.height > 0
+                else {}
+            )
+            if encode_method == "label":
                 min_freq = config.get("min_freq")
                 if min_freq is not None:
-                    config["_token_counts"] = token_counts
+                    config["_token_counts"] = counts
                     vocab = {
-                        token
-                        for token, count in token_counts.items()
-                        if count >= min_freq
+                        token for token, count in counts.items() if count >= min_freq
                     }
                     low_freq_types = sum(
-                        1 for count in
+                        1 for count in counts.values() if count < min_freq
                     )
-                    total_types = len(
+                    total_types = len(counts)
                     kept_types = total_types - low_freq_types
                     if not config.get("_min_freq_logged"):
                         logger.info(
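polars_fit_from_lazy replaces the old chunked accumulators with one lazy aggregation pass per numeric feature and then rebuilds sklearn scaler attributes from those aggregates. A minimal sketch of that idea for a standard scaler follows; the column name "price" and the sample values are illustrative and not taken from the package.

import numpy as np
import polars as pl

lf = pl.LazyFrame({"price": [1.0, 2.0, 3.0, 4.0]})
col = pl.col("price").cast(pl.Float64)
stats = (
    lf.select(
        col.sum().alias("price__sum"),
        (col * col).sum().alias("price__sumsq"),
        col.count().alias("price__count"),
    )
    .collect()
    .to_dicts()[0]
)
count = float(stats["price__count"])
mean = stats["price__sum"] / count
# var = E[x^2] - E[x]^2, clamped at zero to absorb floating-point noise
var = max(stats["price__sumsq"] / count - mean * mean, 0.0)
scale = np.sqrt(var) if var > 0 else 1.0
# mean = 2.5 and var = 1.25 here, the same mean_/scale_ a StandardScaler
# would learn on this column, so transforms stay consistent with sklearn.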
@@ -769,29 +569,29 @@ class DataProcessor(FeatureSet):
                         )
                         config["_min_freq_logged"] = True
                 else:
-                    vocab =
+                    vocab = set(counts.keys())
                 if not vocab:
                     logger.warning(f"Sparse feature {name} has empty vocabulary")
                     continue
-
+                # Filter out None values before sorting to avoid comparison errors
+                vocab_list = sorted(v for v in vocab if v is not None)
                 if "<UNK>" not in vocab_list:
                     vocab_list.append("<UNK>")
                 token_to_idx = {token: idx for idx, token in enumerate(vocab_list)}
                 config["_token_to_idx"] = token_to_idx
                 config["_unk_index"] = token_to_idx["<UNK>"]
                 config["vocab_size"] = len(vocab_list)
-            elif
+            elif encode_method == "hash":
                 min_freq = config.get("min_freq")
                 if min_freq is not None:
-                    config["_token_counts"] = token_counts
+                    config["_token_counts"] = counts
                     config["_unk_hash"] = self.hash_string(
                         "<UNK>", int(config["hash_size"])
                     )
                     low_freq_types = sum(
-                        1 for count in
+                        1 for count in counts.values() if count < min_freq
                     )
-                    total_types = len(
+                    total_types = len(counts)
                     kept_types = total_types - low_freq_types
                     if not config.get("_min_freq_logged"):
                         logger.info(
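Label encoding in the polars pipeline is a plain dictionary lookup with an <UNK> fallback, applied through map_elements (the map_with_default helper defined earlier). A small sketch of that lookup; the vocabulary, fill value, and column name are illustrative.

import polars as pl

token_to_idx = {"beijing": 0, "shanghai": 1, "<UNK>": 2}
df = pl.DataFrame({"city": ["beijing", "tokyo", "shanghai", None]})
encoded = df.select(
    pl.col("city")
    .cast(pl.Utf8)
    .fill_null("<UNK>")
    # Unseen tokens fall back to the <UNK> index instead of raising.
    .map_elements(
        lambda x: token_to_idx.get(x, token_to_idx["<UNK>"]),
        return_dtype=pl.Int64,
    )
    .alias("city")
)
# encoded["city"] -> 0, 2, 1, 2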
@@ -803,22 +603,37 @@ class DataProcessor(FeatureSet):
                         config["_min_freq_logged"] = True
                 config["vocab_size"] = config["hash_size"]
 
-        #
+        # sequence features
         for name, config in self.sequence_features.items():
-            if
+            if name not in schema:
+                continue
+            encode_method = config["encode_method"]
+            seq_col = self.sequence_expr(pl, name, config, schema)
+            tokens_df = (
+                lf.select(seq_col.alias("seq"))
+                .explode("seq")
+                .select(pl.col("seq").cast(pl.Utf8).alias("seq"))
+                .drop_nulls("seq")
+                .group_by("seq")
+                .agg(pl.len().alias("count"))
+                .collect()
+            )
+            counts = (
+                dict(zip(tokens_df["seq"].to_list(), tokens_df["count"].to_list()))
+                if tokens_df.height > 0
+                else {}
+            )
+            if encode_method == "label":
                 min_freq = config.get("min_freq")
                 if min_freq is not None:
-                    config["_token_counts"] = token_counts
+                    config["_token_counts"] = counts
                     vocab_set = {
-                        token
-                        for token, count in token_counts.items()
-                        if count >= min_freq
+                        token for token, count in counts.items() if count >= min_freq
                     }
                     low_freq_types = sum(
-                        1 for count in
+                        1 for count in counts.values() if count < min_freq
                     )
-                    total_types = len(
+                    total_types = len(counts)
                     kept_types = total_types - low_freq_types
                     if not config.get("_min_freq_logged"):
                         logger.info(
@@ -829,26 +644,30 @@ class DataProcessor(FeatureSet):
                         )
                         config["_min_freq_logged"] = True
                 else:
-                    vocab_set =
+                    vocab_set = set(counts.keys())
+                # Filter out None values before sorting to avoid comparison errors
+                vocab_list = (
+                    sorted(v for v in vocab_set if v is not None)
+                    if vocab_set
+                    else ["<PAD>"]
+                )
                 if "<UNK>" not in vocab_list:
                     vocab_list.append("<UNK>")
                 token_to_idx = {token: idx for idx, token in enumerate(vocab_list)}
                 config["_token_to_idx"] = token_to_idx
                 config["_unk_index"] = token_to_idx["<UNK>"]
                 config["vocab_size"] = len(vocab_list)
-            elif
+            elif encode_method == "hash":
                 min_freq = config.get("min_freq")
                 if min_freq is not None:
-                    config["_token_counts"] = token_counts
+                    config["_token_counts"] = counts
                     config["_unk_hash"] = self.hash_string(
                         "<UNK>", int(config["hash_size"])
                     )
                     low_freq_types = sum(
-                        1 for count in
+                        1 for count in counts.values() if count < min_freq
                     )
-                    total_types = len(
+                    total_types = len(counts)
                     kept_types = total_types - low_freq_types
                     if not config.get("_min_freq_logged"):
                         logger.info(
@@ -860,14 +679,18 @@ class DataProcessor(FeatureSet):
                         config["_min_freq_logged"] = True
                 config["vocab_size"] = config["hash_size"]
 
-        #
+        # targets
         for name, config in self.target_features.items():
-            if not
-                logger.warning(f"Target {name} has no valid values in provided files")
+            if name not in schema:
                 continue
-
+            if config.get("target_type") == "binary":
+                unique_vals = (
+                    lf.select(pl.col(name).drop_nulls().unique())
+                    .collect()
+                    .to_series()
+                    .to_list()
+                )
+                self.process_target_fit(unique_vals, config, name)
 
         self.is_fitted = True
         logger.info(
@@ -879,18 +702,11 @@ class DataProcessor(FeatureSet):
         )
         return self
 
-    def fit_from_files(
-        self, file_paths: list[str], file_type: str, chunk_size: int
-    ) -> "DataProcessor":
-        """Fit processor statistics by streaming an explicit list of files.
-
-        This is useful when you want to fit statistics on training files only (exclude
-        validation files) in streaming mode.
-        """
+    def fit_from_files(self, file_paths: list[str], file_type: str) -> "DataProcessor":
         logger = logging.getLogger()
         logger.info(
             colorize(
-                "Fitting DataProcessor
+                "Fitting DataProcessor...",
                 color="cyan",
                 bold=True,
             )
@@ -899,34 +715,15 @@ class DataProcessor(FeatureSet):
             config.pop("_min_freq_logged", None)
         for config in self.sequence_features.values():
             config.pop("_min_freq_logged", None)
-        )
-        if uses_robust:
-            logger.warning(
-                "Robust scaler requires full data; loading provided files into memory. "
-                "Consider smaller chunk_size or different scaler if memory is limited."
-            )
-            frames = [read_table(p, file_type) for p in file_paths]
-            df = pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]
-            return self.fit(df)
-        return self.fit_from_file_paths(file_paths=file_paths, file_type=file_type, chunk_size=chunk_size)
+        lf = self.polars_scan(file_paths, file_type)
+        schema = lf.collect_schema()
+        return self.polars_fit_from_lazy(lf, schema)
-    def fit_from_path(self, path: str, chunk_size: int) -> "DataProcessor":
-        """
-        Fit processor statistics by streaming files to reduce memory usage.
 
-        path (str): File or directory path.
-        chunk_size (int): Number of rows per chunk.
-
-        Returns:
-            DataProcessor: Fitted DataProcessor instance.
-        """
+    def fit_from_path(self, path: str) -> "DataProcessor":
         logger = logging.getLogger()
         logger.info(
             colorize(
-                "Fitting DataProcessor
+                "Fitting DataProcessor...",
                 color="cyan",
                 bold=True,
             )
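With the chunked readers gone, the streaming fit entry points take only paths and file types and push all scanning into polars. A usage sketch follows, assuming "processor" is an already configured DataProcessor instance; the directory layout and file names are illustrative.

# Fit statistics by lazily scanning every csv/parquet file under a directory.
processor.fit_from_path("data/train")

# Or restrict the fit to an explicit list of training files; note that the
# old chunk_size argument no longer exists on either method.
processor.fit_from_files(
    file_paths=["data/train/part-0.parquet"],
    file_type="parquet",
)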
@@ -936,118 +733,35 @@ class DataProcessor(FeatureSet):
         for config in self.sequence_features.values():
             config.pop("_min_freq_logged", None)
         file_paths, file_type = resolve_file_paths(path)
-        return self.
-            file_paths=file_paths,
-            file_type=file_type,
-            chunk_size=chunk_size,
-        )
-
-    @overload
-    def transform_in_memory(
-        self,
-        data: Union[pd.DataFrame, Dict[str, Any]],
-        return_dict: Literal[True],
-        persist: bool,
-        save_format: Optional[str],
-        output_path: Optional[str],
-        warn_missing: bool = True,
-    ) -> Dict[str, np.ndarray]: ...
+        return self.fit_from_files(file_paths=file_paths, file_type=file_type)
 
-    @overload
     def transform_in_memory(
         self,
-        data: Union[pd.DataFrame, Dict[str, Any]],
-        return_dict: Literal[False],
-        persist: bool,
-        save_format: Optional[str],
-        output_path: Optional[str],
-        warn_missing: bool = True,
-    ) -> pd.DataFrame: ...
-
-    def transform_in_memory(
-        self,
-        data: Union[pd.DataFrame, Dict[str, Any]],
+        data: Union[pl.DataFrame, pd.DataFrame, Dict[str, Any]],
         return_dict: bool,
         persist: bool,
         save_format: Optional[str],
         output_path: Optional[str],
         warn_missing: bool = True,
     ):
-        """
-        Transform in-memory data and optionally persist the transformed data.
-
-        Args:
-            data (Union[pd.DataFrame, Dict[str, Any]]): Input data.
-            return_dict (bool): Whether to return a dictionary of numpy arrays.
-            persist (bool): Whether to persist the transformed data to disk.
-            save_format (Optional[str]): Format to save the data if persisting.
-            output_path (Optional[str]): Output path to save the data if persisting.
-            warn_missing (bool): Whether to warn about missing features in the data.
-
-        Returns:
-            Union[pd.DataFrame, Dict[str, np.ndarray]]: Transformed data.
-        """
-
         logger = logging.getLogger()
-        data_dict = data if isinstance(data, dict) else None
 
-            result_dict[col] = df[col].to_numpy(copy=False)
+        if isinstance(data, dict):
+            df = pl.DataFrame(data)
+        elif isinstance(data, pd.DataFrame):
+            df = pl.from_pandas(data)
         else:
-                result_dict[key] = value.to_numpy(copy=False)
-            else:
-                result_dict[key] = np.asarray(value)
-
-        data_columns = data.columns if isinstance(data, pd.DataFrame) else data_dict
-        feature_groups = [
-            ("Numeric", self.numeric_features, self.process_numeric_feature_transform),
-            ("Sparse", self.sparse_features, self.process_sparse_feature_transform),
-            (
-                "Sequence",
-                self.sequence_features,
-                self.process_sequence_feature_transform,
-            ),
-            ("Target", self.target_features, self.process_target_transform),
-        ]
-        for label, features, transform_fn in feature_groups:
-            for name, config in features.items():
-                present = name in data_columns  # type: ignore[operator]
-                if not present:
-                    if warn_missing:
-                        logger.warning(f"{label} feature {name} not found in data")
-                    continue
-                series_data = (
-                    data[name]
-                    if isinstance(data, pd.DataFrame)
-                    else pd.Series(result_dict[name], name=name)
-                )
-                result_dict[name] = transform_fn(series_data, config)
-
-        def dict_to_dataframe(result: Dict[str, np.ndarray]) -> pd.DataFrame:
-            # Convert all arrays to Series/lists at once to avoid fragmentation
-            columns_dict = {}
-            for key, value in result.items():
-                if key in self.sequence_features:
-                    columns_dict[key] = np.asarray(value).tolist()
-                else:
-                    columns_dict[key] = value
-            return pd.DataFrame(columns_dict)
+            df = data
+
+        schema = df.schema
+        lf = df.lazy()
+        lf = self.apply_transforms(lf, schema, warn_missing=warn_missing)
+        out_df = lf.collect()
 
         effective_format = save_format
         if persist:
             effective_format = save_format or "parquet"
-        if (not return_dict) or persist:
-            result_df = dict_to_dataframe(result_dict)
+
         if persist:
             if effective_format not in FILE_FORMAT_CONFIG:
                 raise ValueError(f"Unsupported save format: {effective_format}")
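transform_in_memory now routes dicts, pandas frames, and polars frames through the same lazy plan and returns either a polars DataFrame or a dict of numpy arrays. A usage sketch with a fitted "processor"; the column names and values are made up for illustration.

import pandas as pd

batch = pd.DataFrame({"price": [3.2, 1.8], "city": ["beijing", "tokyo"]})
arrays = processor.transform_in_memory(
    batch,
    return_dict=True,
    persist=False,
    save_format=None,
    output_path=None,
)
# arrays["price"] and arrays["city"] come back as numpy arrays; sequence
# features would come back as int64 matrices padded to their configured max_len.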
@@ -1059,68 +773,63 @@ class DataProcessor(FeatureSet):
             if output_dir.suffix:
                 output_dir = output_dir.parent
             output_dir.mkdir(parents=True, exist_ok=True)
             suffix = FILE_FORMAT_CONFIG[effective_format]["extension"][0]
             save_path = output_dir / f"transformed_data{suffix}"
-            assert result_df is not None, "DataFrame conversion failed"
-
-            # Save based on format
             if effective_format == "csv":
+                out_df.write_csv(save_path)
             elif effective_format == "parquet":
+                out_df.write_parquet(save_path)
             elif effective_format == "feather":
-            elif effective_format == "excel":
-                result_df.to_excel(save_path, index=False)
-            elif effective_format == "hdf5":
-                result_df.to_hdf(save_path, key="data", mode="w")
+                out_df.write_ipc(save_path)
             else:
-                raise ValueError(
+                raise ValueError(
+                    f"Format '{effective_format}' is not supported by the polars-only pipeline."
+                )
             logger.info(
                 colorize(
-                    f"Transformed data saved to: {save_path.resolve()}",
+                    f"Transformed data saved to: {save_path.resolve()}",
+                    color="green",
                 )
             )
+
         if return_dict:
+            result_dict = {}
+            for col in out_df.columns:
+                series = out_df.get_column(col)
+                if col in self.sequence_features:
+                    result_dict[col] = np.asarray(series.to_list(), dtype=np.int64)
+                else:
+                    result_dict[col] = series.to_numpy()
             return result_dict
-        return
+
+        return out_df
 
     def transform_path(
         self,
         input_path: str,
         output_path: Optional[str],
         save_format: Optional[str],
-        chunk_size: int = 200000,
     ):
-        """Transform data from files under a path and save them
-
-        Uses chunked reading/writing to keep peak memory bounded for large files.
-
-        Args:
-            input_path (str): Input file or directory path.
-            output_path (Optional[str]): Output directory path. If None, defaults to input_path/transformed_data.
-            save_format (Optional[str]): Format to save transformed files. If None, uses input file format.
-            chunk_size (int): Number of rows per chunk.
-        """
+        """Transform data from files under a path and save them using polars lazy pipeline."""
         logger = logging.getLogger()
         file_paths, file_type = resolve_file_paths(input_path)
         target_format = save_format or file_type
         if target_format not in FILE_FORMAT_CONFIG:
             raise ValueError(f"Unsupported format: {target_format}")
-        if
+        if target_format not in {"csv", "parquet", "feather"}:
+            raise ValueError(
+                f"Format '{target_format}' is not supported by the polars-only pipeline."
+            )
+        if not check_streaming_support(file_type):
             raise ValueError(
                 f"Input format '{file_type}' does not support streaming reads. "
-                "
+                "Polars backend supports csv/parquet only."
             )
 
-        # Warn about streaming support
         if not check_streaming_support(target_format):
             logger.warning(
                 f"[Data Processor Warning] Format '{target_format}' does not support streaming writes. "
-                "
+                "Data will be collected in memory before saving."
            )
 
        base_output_dir = (
@@ -1131,122 +840,48 @@ class DataProcessor(FeatureSet):
         output_root = base_output_dir / "transformed_data"
         output_root.mkdir(parents=True, exist_ok=True)
         saved_paths = []
+
         for file_path in progress(file_paths, description="Transforming files"):
             source_path = Path(file_path)
             suffix = FILE_FORMAT_CONFIG[target_format]["extension"][0]
             target_file = output_root / f"{source_path.stem}{suffix}"
 
-
-
-
-                df = read_table(file_path, file_type)
-                transformed_df = self.transform_in_memory(
-                    df,
-                    return_dict=False,
-                    persist=False,
-                    save_format=None,
-                    output_path=None,
-                    warn_missing=True,
-                )
-                assert isinstance(
-                    transformed_df, pd.DataFrame
-                ), "[Data Processor Error] Expected DataFrame when return_dict=False"
-
-                # Save based on format
-                if target_format == "csv":
-                    transformed_df.to_csv(target_file, index=False)
-                elif target_format == "parquet":
-                    transformed_df.to_parquet(target_file, index=False)
-                elif target_format == "feather":
-                    transformed_df.to_feather(target_file)
-                elif target_format == "excel":
-                    transformed_df.to_excel(target_file, index=False)
-                elif target_format == "hdf5":
-                    transformed_df.to_hdf(target_file, key="data", mode="w")
-                else:
-                    raise ValueError(f"Unsupported format: {target_format}")
-
-                saved_paths.append(str(target_file.resolve()))
-                continue
+            lf = self.polars_scan([file_path], file_type)
+            schema = lf.collect_schema()
+            lf = self.apply_transforms(lf, schema, warn_missing=True)
 
-            first_chunk = True
-            # Streaming write for supported formats
             if target_format == "parquet":
-
-                try:
-                    for chunk in iter_file_chunks(file_path, file_type, chunk_size):
-                        transformed_df = self.transform_in_memory(
-                            chunk,
-                            return_dict=False,
-                            persist=False,
-                            save_format=None,
-                            output_path=None,
-                            warn_missing=first_chunk,
-                        )
-                        assert isinstance(
-                            transformed_df, pd.DataFrame
-                        ), "[Data Processor Error] Expected DataFrame when return_dict=False"
-                        table = pa.Table.from_pandas(
-                            transformed_df, preserve_index=False
-                        )
-                        if parquet_writer is None:
-                            parquet_writer = pq.ParquetWriter(target_file, table.schema)
-                        parquet_writer.write_table(table)
-                        first_chunk = False
-                finally:
-                    if parquet_writer is not None:
-                        parquet_writer.close()
+                lf.sink_parquet(target_file)
             elif target_format == "csv":
-                # CSV
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                # CSV doesn't support nested data (lists), so convert list columns to string
+                transformed_schema = lf.collect_schema()
+                list_cols = [
+                    name
+                    for name, dtype in transformed_schema.items()
+                    if isinstance(dtype, pl.List)
+                ]
+                if list_cols:
+                    # Convert list columns to string representation for CSV
+                    # Format as [1, 2, 3] by casting elements to string, joining with ", ", and adding brackets
+                    list_exprs = []
+                    for name in list_cols:
+                        # Convert list to string representation
+                        list_exprs.append(
+                            (
+                                pl.lit("[")
+                                + pl.col(name)
+                                .list.eval(pl.element().cast(pl.String))
+                                .list.join(", ")
+                                + pl.lit("]")
+                            ).alias(name)
+                        )
+                    lf = lf.with_columns(list_exprs)
+                lf.sink_csv(target_file)
             else:
-
-
-                    f"Format '{target_format}' doesn't support streaming writes. "
-                    f"Collecting all chunks in memory before saving."
-                )
-                all_chunks = []
-                for chunk in iter_file_chunks(file_path, file_type, chunk_size):
-                    transformed_df = self.transform_in_memory(
-                        chunk,
-                        return_dict=False,
-                        persist=False,
-                        save_format=None,
-                        output_path=None,
-                        warn_missing=first_chunk,
-                    )
-                    assert isinstance(transformed_df, pd.DataFrame)
-                    all_chunks.append(transformed_df)
-                    first_chunk = False
-
-                if all_chunks:
-                    combined_df = pd.concat(all_chunks, ignore_index=True)
-                    if target_format == "feather":
-                        combined_df.to_feather(target_file)
-                    elif target_format == "excel":
-                        combined_df.to_excel(target_file, index=False)
-                    elif target_format == "hdf5":
-                        combined_df.to_hdf(target_file, key="data", mode="w")
-
+                df = lf.collect()
+                df.write_ipc(target_file)
             saved_paths.append(str(target_file.resolve()))
+
         logger.info(
             colorize(
                 f"Transformed {len(saved_paths)} file(s) saved to: {output_root.resolve()}",
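The CSV branch above is the only place the new lazy pipeline has to rewrite data before sinking, because flat CSV cannot hold nested list columns. Below is a minimal, self-contained sketch of that list-to-string step; the column names and rows are made up for illustration, only the expression chain mirrors the diff:

```python
# Sketch of the list-to-string conversion applied before sink_csv.
# The frame contents and column names here are illustrative only.
import polars as pl

lf = pl.LazyFrame({"user_id": [1, 2], "hist_items": [[10, 20, 30], [40]]})

schema = lf.collect_schema()
list_cols = [name for name, dtype in schema.items() if isinstance(dtype, pl.List)]

# Render each list column as the string "[a, b, c]" so flat CSV can store it.
lf = lf.with_columns(
    [
        (
            pl.lit("[")
            + pl.col(name).list.eval(pl.element().cast(pl.String)).list.join(", ")
            + pl.lit("]")
        ).alias(name)
        for name in list_cols
    ]
)

print(lf.collect())  # hist_items -> "[10, 20, 30]" and "[40]"
```

Parquet and IPC keep the nested dtype natively, which is why only the CSV path needs this rewrite before writing.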
@@ -1258,74 +893,51 @@ class DataProcessor(FeatureSet):
     # fit is nothing but registering the statistics from data so that we can transform the data later
     def fit(
         self,
-        data: Union[pd.DataFrame, Dict[str, Any], str, os.PathLike],
-        chunk_size: int = 200000,
+        data: Union[pl.DataFrame, pd.DataFrame, Dict[str, Any], str, os.PathLike],
     ):
         """
         Fit the DataProcessor to the provided data.
 
         Args:
-            data (Union[pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data for fitting.
-            chunk_size (int): Number of rows per chunk when streaming from path.
+            data (Union[pl.DataFrame, pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data for fitting.
 
         Returns:
             DataProcessor: Fitted DataProcessor instance.
         """
 
-        logger = logging.getLogger()
         for config in self.sparse_features.values():
             config.pop("_min_freq_logged", None)
         for config in self.sequence_features.values():
             config.pop("_min_freq_logged", None)
         if isinstance(data, (str, os.PathLike)):
-
-            uses_robust = any(
-                cfg.get("scaler") == "robust" for cfg in self.numeric_features.values()
-            )
-            if uses_robust:
-                logger.warning(
-                    "Robust scaler requires full data; loading all files into memory. Consider smaller chunk_size or different scaler if memory is limited."
-                )
-                data = self.load_dataframe_from_path(path_str)
-            else:
-                return self.fit_from_path(path_str, chunk_size)
+            return self.fit_from_path(str(data))
         if isinstance(data, dict):
-
-
-
-
-
-
-
-
-        for label, features, fit_fn in feature_groups:
-            for name, config in features.items():
-                if name not in data.columns:
-                    logger.warning(f"{label} feature {name} not found in data")
-                    continue
-                fit_fn(data[name], config)
-        self.is_fitted = True
-        return self
+            df = pl.DataFrame(data)
+        elif isinstance(data, pd.DataFrame):
+            df = pl.from_pandas(data)
+        else:
+            df = data
+        lf = df.lazy()
+        schema = df.schema
+        return self.polars_fit_from_lazy(lf, schema)
 
     @overload
     def transform(
         self,
-        data: Union[pd.DataFrame, Dict[str, Any]],
+        data: Union[pl.DataFrame, pd.DataFrame, Dict[str, Any]],
         return_dict: Literal[True] = True,
         save_format: Optional[str] = None,
         output_path: Optional[str] = None,
-        chunk_size: int = 200000,
     ) -> Dict[str, np.ndarray]: ...
 
     @overload
     def transform(
         self,
-        data: Union[pd.DataFrame, Dict[str, Any]],
+        data: Union[pl.DataFrame, pd.DataFrame, Dict[str, Any]],
         return_dict: Literal[False] = False,
         save_format: Optional[str] = None,
         output_path: Optional[str] = None,
-
-    ) -> pd.DataFrame: ...
+    ) -> pl.DataFrame: ...
 
     @overload
     def transform(
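For reference, the reworked fit() normalizes every in-memory input to a polars LazyFrame plus its schema before delegating to polars_fit_from_lazy: dicts become a pl.DataFrame, pandas frames go through pl.from_pandas, and polars frames pass through. A standalone sketch of that dispatch follows; the helper name to_lazy and the sample data are illustrative, not part of nextrec:

```python
# Sketch of the in-memory dispatch performed by the new fit().
from typing import Any, Dict, Union

import pandas as pd
import polars as pl


def to_lazy(data: Union[pl.DataFrame, pd.DataFrame, Dict[str, Any]]):
    # Normalize dict / pandas / polars input to (LazyFrame, schema).
    if isinstance(data, dict):
        df = pl.DataFrame(data)
    elif isinstance(data, pd.DataFrame):
        df = pl.from_pandas(data)
    else:
        df = data
    return df.lazy(), df.schema


lf, schema = to_lazy(pd.DataFrame({"user_id": [1, 2], "age": [23, 31]}))
print(schema)        # user_id: Int64, age: Int64
print(lf.collect())  # the lazy frame materializes the same two rows
```

Path inputs skip this entirely and are routed straight to fit_from_path, so fitting over files never loads them eagerly here.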
@@ -1334,28 +946,25 @@ class DataProcessor(FeatureSet):
         return_dict: Literal[False] = False,
         save_format: Optional[str] = None,
         output_path: Optional[str] = None,
-        chunk_size: int = 200000,
     ) -> list[str]: ...
 
     def transform(
         self,
-        data: Union[pd.DataFrame, Dict[str, Any], str, os.PathLike],
+        data: Union[pl.DataFrame, pd.DataFrame, Dict[str, Any], str, os.PathLike],
         return_dict: bool = True,
         save_format: Optional[str] = None,
         output_path: Optional[str] = None,
-        chunk_size: int = 200000,
     ):
         """
         Transform the provided data using the fitted DataProcessor.
 
         Args:
-            data (Union[pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data to transform.
+            data (Union[pl.DataFrame, pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data to transform.
             return_dict (bool): Whether to return a dictionary of numpy arrays.
             save_format (Optional[str]): Format to save the data if output_path is provided.
             output_path (Optional[str]): Output path to save the transformed data.
-            chunk_size (int): Number of rows per chunk when streaming from path.
         Returns:
-            Union[
+            Union[pl.DataFrame, Dict[str, np.ndarray], List[str]]: Transformed data or list of saved file paths.
         """
 
         if not self.is_fitted:
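The overloads above tie the static return type to the return_dict flag: True yields a dict of numpy arrays, False yields a polars DataFrame for in-memory input, and a list of saved paths when a file path is passed. The snippet below is a standalone illustration of the same Literal-based overload pattern, not the package's implementation; the function body and sample frame are assumptions made for the example:

```python
# Illustration of typing.overload + Literal driving the return type.
from typing import Dict, Literal, Union, overload

import numpy as np
import polars as pl


@overload
def transform(data: pl.DataFrame, return_dict: Literal[True] = True) -> Dict[str, np.ndarray]: ...
@overload
def transform(data: pl.DataFrame, return_dict: Literal[False] = False) -> pl.DataFrame: ...


def transform(data: pl.DataFrame, return_dict: bool = True) -> Union[Dict[str, np.ndarray], pl.DataFrame]:
    # Single runtime implementation; the overloads only refine what type
    # checkers infer for each value of return_dict.
    if return_dict:
        return {name: data[name].to_numpy() for name in data.columns}
    return data


df = pl.DataFrame({"user_id": [1, 2], "age": [23, 31]})
arrays = transform(df, return_dict=True)   # checkers see Dict[str, np.ndarray]
frame = transform(df, return_dict=False)   # checkers see pl.DataFrame
print(type(arrays["age"]), type(frame))
```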
@@ -1367,9 +976,7 @@ class DataProcessor(FeatureSet):
                 raise ValueError(
                     "[Data Processor Error] Path transform writes files only; set return_dict=False when passing a path."
                 )
-            return self.transform_path(
-                str(data), output_path, save_format, chunk_size=chunk_size
-            )
+            return self.transform_path(str(data), output_path, save_format)
         return self.transform_in_memory(
             data=data,
             return_dict=return_dict,
@@ -1380,26 +987,24 @@ class DataProcessor(FeatureSet):
 
     def fit_transform(
         self,
-        data: Union[pd.DataFrame, Dict[str, Any], str, os.PathLike],
+        data: Union[pl.DataFrame, pd.DataFrame, Dict[str, Any], str, os.PathLike],
         return_dict: bool = True,
         save_format: Optional[str] = None,
         output_path: Optional[str] = None,
-        chunk_size: int = 200000,
     ):
         """
         Fit the DataProcessor to the provided data and then transform it.
 
         Args:
-            data (Union[pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data for fitting and transforming.
+            data (Union[pl.DataFrame, pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data for fitting and transforming.
             return_dict (bool): Whether to return a dictionary of numpy arrays.
             save_format (Optional[str]): Format to save the data if output_path is provided.
-            output_path (Optional[str]): Output path to save the
-            chunk_size (int): Number of rows per chunk when streaming from path.
+            output_path (Optional[str]): Output path to save the data.
         Returns:
-            Union[
+            Union[pl.DataFrame, Dict[str, np.ndarray], List[str]]: Transformed data or list of saved file paths.
         """
 
-        self.fit(data
+        self.fit(data)
         return self.transform(
             data,
             return_dict=return_dict,