nextrec 0.4.34__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +7 -13
- nextrec/basic/layers.py +28 -94
- nextrec/basic/model.py +512 -4
- nextrec/cli.py +101 -18
- nextrec/data/data_processing.py +8 -13
- nextrec/data/preprocessor.py +449 -846
- nextrec/models/ranking/afm.py +4 -9
- nextrec/models/ranking/dien.py +7 -8
- nextrec/models/ranking/ffm.py +2 -2
- nextrec/models/retrieval/sdm.py +1 -2
- nextrec/models/sequential/hstu.py +0 -2
- nextrec/utils/onnx_utils.py +252 -0
- {nextrec-0.4.34.dist-info → nextrec-0.5.0.dist-info}/METADATA +10 -4
- {nextrec-0.4.34.dist-info → nextrec-0.5.0.dist-info}/RECORD +18 -18
- nextrec/models/multi_task/[pre]star.py +0 -192
- {nextrec-0.4.34.dist-info → nextrec-0.5.0.dist-info}/WHEEL +0 -0
- {nextrec-0.4.34.dist-info → nextrec-0.5.0.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.34.dist-info → nextrec-0.5.0.dist-info}/licenses/LICENSE +0 -0
nextrec/data/preprocessor.py
CHANGED
@@ -2,7 +2,7 @@
 DataProcessor for data preprocessing including numeric, sparse, sequence features and target processing.

 Date: create on 13/11/2025
-Checkpoint: edit on
+Checkpoint: edit on 28/01/2026
 Author: Yang Zhou, zyaztec@gmail.com
 """

@@ -13,14 +13,12 @@ import logging
 import os
 import pickle
 from pathlib import Path
-from typing import Any, Dict, Literal, Optional, Union, overload
+from typing import Any, Dict, Iterable, Literal, Optional, Union, overload

 import numpy as np
 import pandas as pd
-import
-import pyarrow.parquet as pq
+import polars as pl
 from sklearn.preprocessing import (
-    LabelEncoder,
     MaxAbsScaler,
     MinMaxScaler,
     RobustScaler,
@@ -37,9 +35,6 @@ from nextrec.utils.data import (
     FILE_FORMAT_CONFIG,
     check_streaming_support,
     default_output_dir,
-    iter_file_chunks,
-    load_dataframes,
-    read_table,
     resolve_file_paths,
 )

@@ -63,7 +58,7 @@ class DataProcessor(FeatureSet):
         self.is_fitted = False

         self.scalers: Dict[str, Any] = {}
-        self.label_encoders: Dict[str,
+        self.label_encoders: Dict[str, Any] = {}
         self.target_encoders: Dict[str, Dict[str, int]] = {}
         self.set_target_id(target=[], id_columns=[])

@@ -186,318 +181,228 @@ class DataProcessor(FeatureSet):
     def hash_string(self, s: str, hash_size: int) -> int:
         return self.hash_fn(str(s), int(hash_size))

-    def
-        if
-        logger = logging.getLogger()
-        name = str(data.name)
-        scaler_type = config["scaler"]
-        fill_na_value = config.get("fill_na_value", 0)
-        filled_data = data.fillna(fill_na_value)
-        values = np.array(filled_data.values, dtype=np.float64)
-        if scaler_type == "log":
-            result = np.log1p(np.maximum(values, 0))
-        elif scaler_type == "none":
-            result = values
+    def polars_scan(self, file_paths: list[str], file_type: str):
+        file_type = file_type.lower()
+        if file_type == "csv":
+            return pl.scan_csv(file_paths, ignore_errors=True)
+        if file_type == "parquet":
+            return pl.scan_parquet(file_paths)
+        raise ValueError(
+            f"[Data Processor Error] Polars backend only supports csv/parquet, got: {file_type}"
+        )
+
+    def sequence_expr(
+        self, pl, name: str, config: Dict[str, Any], schema: Dict[str, Any]
+    ):
+        """
+        generate polars expression for sequence feature processing
+
+        Example Input:
+            sequence_str: "1,2,3"
+            sequence_str: " 4, ,5 "
+            sequence_list: ["7", "8", "9"]
+            sequence_list: ["", "10", " 11 "]
+
+        Example Output:
+            sequence_str -> ["1","2","3"]
+            sequence_str -> ["4","5"]
+            sequence_list -> ["7","8","9"]
+            sequence_list -> ["10","11"]
+        """
+        separator = config["separator"]
+        dtype = schema.get(name)
+        col = pl.col(name)
+        if dtype is not None and isinstance(dtype, pl.List):
+            seq_col = col
         else:
+            seq_col = col.cast(pl.Utf8).fill_null("").str.split(separator)
+        elem = pl.element().cast(pl.Utf8).str.strip_chars()
+        seq_col = seq_col.list.eval(
+            pl.when(elem == "").then(None).otherwise(elem)
+        ).list.drop_nulls()
+        return seq_col
+
+    def apply_transforms(self, lf, schema: Dict[str, Any], warn_missing: bool):
+        """
+        Apply all transformations to a Polars LazyFrame.

+        """
         logger = logging.getLogger()
+        expressions = []

-        if min_freq is not None:
-            counts = filled_data.value_counts()
-            config["_token_counts"] = counts.to_dict()
-            vocab = sorted(counts[counts >= min_freq].index.tolist())
-            low_freq_types = int((counts < min_freq).sum())
-            total_types = int(counts.size)
-            kept_types = total_types - low_freq_types
-            if not config.get("_min_freq_logged"):
-                logger.info(
-                    f"Sparse feature {data.name} min_freq={min_freq}: "
-                    f"{total_types} token types total, "
-                    f"{low_freq_types} low-frequency, "
-                    f"{kept_types} kept."
-                )
-                config["_min_freq_logged"] = True
-        else:
-            vocab = sorted(set(filled_data.tolist()))
-        if "<UNK>" not in vocab:
-            vocab.append("<UNK>")
-        token_to_idx = {token: idx for idx, token in enumerate(vocab)}
-        config["_token_to_idx"] = token_to_idx
-        config["_unk_index"] = token_to_idx["<UNK>"]
-        config["vocab_size"] = len(vocab)
-        elif encode_method == "hash":
-            min_freq = config.get("min_freq")
-            if min_freq is not None:
-                counts = filled_data.value_counts()
-                config["_token_counts"] = counts.to_dict()
-                config["_unk_hash"] = self.hash_string(
-                    "<UNK>", int(config["hash_size"])
-                )
-                low_freq_types = int((counts < min_freq).sum())
-                total_types = int(counts.size)
-                kept_types = total_types - low_freq_types
-                if not config.get("_min_freq_logged"):
-                    logger.info(
-                        f"Sparse feature {data.name} min_freq={min_freq}: "
-                        f"{total_types} token types total, "
-                        f"{low_freq_types} low-frequency, "
-                        f"{kept_types} kept."
-                    )
-                    config["_min_freq_logged"] = True
-            config["vocab_size"] = config["hash_size"]
-
-    def process_sparse_feature_transform(
-        self, data: pd.Series, config: Dict[str, Any]
-    ) -> np.ndarray:
-        name = str(data.name)
-        encode_method = config["encode_method"]
-        fill_na = config["fill_na"]
-
-        sparse_series = (
-            data if isinstance(data, pd.Series) else pd.Series(data, name=name)
-        )
-        sparse_series = sparse_series.fillna(fill_na).astype(str)
-        if encode_method == "label":
-            token_to_idx = config.get("_token_to_idx")
-            if isinstance(token_to_idx, dict):
-                unk_index = int(config.get("_unk_index", 0))
-                return np.fromiter(
-                    (token_to_idx.get(v, unk_index) for v in sparse_series.to_numpy()),
-                    dtype=np.int64,
-                    count=sparse_series.size,
-                )
-            raise ValueError(
-                f"[Data Processor Error] Token index for {name} not fitted"
+        def map_with_default(expr, mapping: Dict[str, int], default: int, dtype):
+            # Compatible with older polars versions without Expr.map_dict
+            return expr.map_elements(
+                lambda x: mapping.get(x, default),
+                return_dtype=dtype,
             )

+        def ensure_present(feature_name: str, label: str) -> bool:
+            if feature_name not in schema:
+                if warn_missing:
+                    logger.warning(f"{label} feature {feature_name} not found in data")
+                return False
+            return True
+
+        # Numeric features
+        for name, config in self.numeric_features.items():
+            if not ensure_present(name, "Numeric"):
+                continue
+            scaler_type = config["scaler"]
+            fill_na_value = config.get("fill_na_value", 0)
+            col = pl.col(name).cast(pl.Float64).fill_null(fill_na_value)
+            if scaler_type == "log":
+                col = col.clip(lower_bound=0).log1p()
+            elif scaler_type == "none":
+                pass
+            else:
+                scaler = self.scalers.get(name)
+                if scaler is None:
+                    logger.warning(
+                        f"Scaler for {name} not fitted, returning original values"
                    )
+                else:
+                    if scaler_type == "standard":
+                        mean = float(scaler.mean_[0])
+                        scale = (
+                            float(scaler.scale_[0]) if scaler.scale_[0] != 0 else 1.0
+                        )
+                        col = (col - mean) / scale
+                    elif scaler_type == "minmax":
+                        scale = float(scaler.scale_[0])
+                        min_val = float(scaler.min_[0])
+                        col = col * scale + min_val
+                    elif scaler_type == "maxabs":
+                        max_abs = float(scaler.max_abs_[0]) or 1.0
+                        col = col / max_abs
+                    elif scaler_type == "robust":
+                        center = float(scaler.center_[0])
+                        scale = (
+                            float(scaler.scale_[0]) if scaler.scale_[0] != 0 else 1.0
+                        )
+                        col = (col - center) / scale
+            expressions.append(col.alias(name))

-            key = str(token)
-            token_counts[key] = token_counts.get(key, 0) + 1
-        if min_freq is not None:
-            config["_token_counts"] = token_counts
-            vocab = sorted([k for k, v in token_counts.items() if v >= min_freq])
-            low_freq_types = sum(
-                1 for count in token_counts.values() if count < min_freq
-            )
-            total_types = len(token_counts)
-            kept_types = total_types - low_freq_types
-            if not config.get("_min_freq_logged"):
-                logger.info(
-                    f"Sequence feature {data.name} min_freq={min_freq}: "
-                    f"{total_types} token types total, "
-                    f"{low_freq_types} low-frequency, "
-                    f"{kept_types} kept."
+        # Sparse features
+        for name, config in self.sparse_features.items():
+            if not ensure_present(name, "Sparse"):
+                continue
+            encode_method = config["encode_method"]
+            fill_na = config["fill_na"]
+            col = pl.col(name).cast(pl.Utf8).fill_null(fill_na)
+            if encode_method == "label":
+                token_to_idx = config.get("_token_to_idx")
+                if not isinstance(token_to_idx, dict):
+                    raise ValueError(
+                        f"[Data Processor Error] Token index for {name} not fitted"
                )
-                tokens = self.extract_sequence_tokens(seq, separator)
-                for token in tokens:
-                    if str(token).strip():
-                        token_counts[str(token)] = (
-                            token_counts.get(str(token), 0) + 1
-                        )
-            config["_token_counts"] = token_counts
-            config["_unk_hash"] = self.hash_string(
-                "<UNK>", int(config["hash_size"])
-            )
-            low_freq_types = sum(
-                1 for count in token_counts.values() if count < min_freq
-            )
-            total_types = len(token_counts)
-            kept_types = total_types - low_freq_types
-            if not config.get("_min_freq_logged"):
-                logger.info(
-                    f"Sequence feature {data.name} min_freq={min_freq}: "
-                    f"{total_types} token types total, "
-                    f"{low_freq_types} low-frequency, "
-                    f"{kept_types} kept."
+                unk_index = int(config.get("_unk_index", 0))
+                col = map_with_default(col, token_to_idx, unk_index, pl.Int64)
+            elif encode_method == "hash":
+                hash_size = config["hash_size"]
+                hash_expr = col.hash().cast(pl.UInt64) % int(hash_size)
+                min_freq = config.get("min_freq")
+                token_counts = config.get("_token_counts")
+                if min_freq is not None and isinstance(token_counts, dict):
+                    low_freq = [k for k, v in token_counts.items() if v < min_freq]
+                    unk_hash = config.get("_unk_hash")
+                    if unk_hash is None:
+                        unk_hash = self.hash_string("<UNK>", int(hash_size))
+                    hash_expr = (
+                        pl.when(col.is_in(low_freq))
+                        .then(int(unk_hash))
+                        .otherwise(hash_expr)
                    )
-
-        arr = np.asarray(data, dtype=object)
-        n = arr.shape[0]
-        output = np.full((n, max_len), pad_value, dtype=np.int64)
-        # Shared helpers cached locally for speed and cross-platform consistency
-        split_fn = str.split
-        is_nan = np.isnan
-        if encode_method == "label":
-            class_to_idx = config.get("_token_to_idx")
-            if class_to_idx is None:
-                raise ValueError(
-                    f"[Data Processor Error] Token index for {name} not fitted"
-                )
-            unk_index = int(config.get("_unk_index", class_to_idx.get("<UNK>", 0)))
-        else:
-            class_to_idx = None  # type: ignore
-            unk_index = 0
-        hash_fn = self.hash_string
-        hash_size = config.get("hash_size")
-        min_freq = config.get("min_freq")
-        token_counts = config.get("_token_counts")
-        if min_freq is not None and isinstance(token_counts, dict):
-            unk_hash = config.get("_unk_hash")
-            if unk_hash is None and hash_size is not None:
-                unk_hash = hash_fn("<UNK>", hash_size)
-        for i, seq in enumerate(arr):
-            # normalize sequence to a list of strings
-            tokens = []
-            if seq is None:
-                tokens = []
-            elif isinstance(seq, (float, np.floating)):
-                tokens = [] if is_nan(seq) else [str(seq)]
-            elif isinstance(seq, str):
-                seq_str = seq.strip()
-                tokens = [] if not seq_str else split_fn(seq_str, separator)
-            elif isinstance(seq, (list, tuple, np.ndarray)):
-                tokens = [str(t) for t in seq]
-            else:
-                tokens = []
+                col = hash_expr.cast(pl.Int64)
+            expressions.append(col.alias(name))
+
+        # Sequence features
+        for name, config in self.sequence_features.items():
+            if not ensure_present(name, "Sequence"):
+                continue
+            encode_method = config["encode_method"]
+            max_len = int(config["max_len"])
+            pad_value = int(config["pad_value"])
+            truncate = config["truncate"]
+            seq_col = self.sequence_expr(pl, name, config, schema)
+
             if encode_method == "label":
+                token_to_idx = config.get("_token_to_idx")
+                if not isinstance(token_to_idx, dict):
+                    raise ValueError(
+                        f"[Data Processor Error] Token index for {name} not fitted"
+                    )
+                unk_index = int(config.get("_unk_index", 0))
+                seq_col = seq_col.list.eval(
+                    map_with_default(pl.element(), token_to_idx, unk_index, pl.Int64)
+                )
             elif encode_method == "hash":
+                hash_size = config.get("hash_size")
                 if hash_size is None:
                     raise ValueError(
                         "[Data Processor Error] hash_size must be set for hash encoding"
                     )
+                elem = pl.element().cast(pl.Utf8)
+                hash_expr = elem.hash().cast(pl.UInt64) % int(hash_size)
+                min_freq = config.get("min_freq")
+                token_counts = config.get("_token_counts")
+                if min_freq is not None and isinstance(token_counts, dict):
+                    low_freq = [k for k, v in token_counts.items() if v < min_freq]
+                    unk_hash = config.get("_unk_hash")
+                    if unk_hash is None:
+                        unk_hash = self.hash_string("<UNK>", int(hash_size))
+                    hash_expr = (
+                        pl.when(elem.is_in(low_freq))
+                        .then(int(unk_hash))
+                        .otherwise(hash_expr)
                    )
+                seq_col = seq_col.list.eval(hash_expr)
+
+            if truncate == "pre":
+                seq_col = seq_col.list.tail(max_len)
             else:
+                seq_col = seq_col.list.head(max_len)
+            pad_list = [pad_value] * max_len
+            seq_col = pl.concat_list([seq_col, pl.lit(pad_list)]).list.head(max_len)
+            expressions.append(seq_col.alias(name))
+
+        # Target features
+        for name, config in self.target_features.items():
+            if not ensure_present(name, "Target"):
                continue
+            target_type = config.get("target_type")
+            col = pl.col(name)
+            if target_type == "regression":
+                col = col.cast(pl.Float32)
+            elif target_type == "binary":
+                label_map = self.target_encoders.get(name)
+                if label_map is None:
+                    raise ValueError(
+                        f"[Data Processor Error] Target encoder for {name} not fitted"
+                    )
+                col = map_with_default(col.cast(pl.Utf8), label_map, 0, pl.Int64).cast(
+                    pl.Float32
+                )
+            else:
+                raise ValueError(
+                    f"[Data Processor Error] Unsupported target type: {target_type}"
+                )
+            expressions.append(col.alias(name))
+
+        if not expressions:
+            return lf
+        return lf.with_columns(expressions)

-    def process_target_fit(
+    def process_target_fit(
+        self, data: Iterable[Any], config: Dict[str, Any], name: str
+    ) -> None:
         target_type = config["target_type"]
         label_map = config.get("label_map")
         if target_type == "binary":
             if label_map is None:
-                unique_values = data
+                unique_values = {v for v in data if v is not None}
+                # Filter out None values before sorting to avoid comparison errors
+                sorted_values = sorted(v for v in unique_values if v is not None)
                 try:
                     int_values = [int(v) for v in sorted_values]
                     if int_values == list(range(len(int_values))):
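For readers following the new expression-based path: below is a minimal sketch, outside the diff itself, of how the polars expression assembled by sequence_expr behaves on string-encoded sequences. The column name "seq" and the "," separator are illustrative assumptions.

import polars as pl

df = pl.DataFrame({"seq": ["1,2,3", " 4, ,5 ", None]})
# Same shape as sequence_expr: split, strip each element, drop empty tokens
elem = pl.element().cast(pl.Utf8).str.strip_chars()
expr = (
    pl.col("seq")
    .cast(pl.Utf8)
    .fill_null("")
    .str.split(",")
    .list.eval(pl.when(elem == "").then(None).otherwise(elem))
    .list.drop_nulls()
)
print(df.select(expr))  # ["1","2","3"], ["4","5"], []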
@@ -511,254 +416,149 @@ class DataProcessor(FeatureSet):
             config["label_map"] = label_map
             self.target_encoders[name] = label_map

-    def
-        self, data: pd.Series, config: Dict[str, Any]
-    ) -> np.ndarray:
+    def polars_fit_from_lazy(self, lf, schema: Dict[str, Any]) -> "DataProcessor":
         logger = logging.getLogger()
-        name = str(data.name)
-        target_type = config.get("target_type")
-        if target_type == "regression":
-            values = np.array(data.values, dtype=np.float32)
-            return values
-        if target_type == "binary":
-            label_map = self.target_encoders.get(name)
-            if label_map is None:
-                raise ValueError(
-                    f"[Data Processor Error] Target encoder for {name} not fitted"
-                )
-            result = []
-            for val in data:
-                str_val = str(val)
-                if str_val in label_map:
-                    result.append(label_map[str_val])
-                else:
-                    logger.warning(f"Unknown target value: {val}, mapping to 0")
-                    result.append(0)
-            return np.array(result, dtype=np.float32)
-        raise ValueError(
-            f"[Data Processor Error] Unsupported target type: {target_type}"
-        )
-
-    def load_dataframe_from_path(self, path: str) -> pd.DataFrame:
-        """
-        Load all data from a file or directory path into a single DataFrame.
-
-        Args:
-            path (str): File or directory path.
-
-        Returns:
-            pd.DataFrame: Loaded DataFrame.
-        """
-        file_paths, file_type = resolve_file_paths(path)
-        frames = load_dataframes(file_paths, file_type)
-        return pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]
-
-    def extract_sequence_tokens(self, value: Any, separator: str) -> list[str]:
-        """Extract sequence tokens from a single value."""
-        if value is None:
-            return []
-        if isinstance(value, (float, np.floating)) and np.isnan(value):
-            return []
-        if isinstance(value, str):
-            stripped = value.strip()
-            return [] if not stripped else stripped.split(separator)
-        if isinstance(value, (list, tuple, np.ndarray)):
-            return [str(v) for v in value]
-        return [str(value)]
-
-    def fit_from_file_paths(
-        self, file_paths: list[str], file_type: str, chunk_size: int
-    ) -> "DataProcessor":
-        logger = logging.getLogger()
-        if not file_paths:
-            raise ValueError("[DataProcessor Error] Empty file list for streaming fit")
-        if not check_streaming_support(file_type):
-            raise ValueError(
-                f"[DataProcessor Error] Format '{file_type}' does not support streaming. "
-                "Streaming fit only supports csv, parquet to avoid high memory usage."
-            )
-
-        numeric_acc = {}
-        for name in self.numeric_features.keys():
-            numeric_acc[name] = {
-                "sum": 0.0,
-                "sumsq": 0.0,
-                "count": 0.0,
-                "min": np.inf,
-                "max": -np.inf,
-                "max_abs": 0.0,
-            }
-        sparse_vocab: Dict[str, set[str]] = {
-            name: set() for name in self.sparse_features.keys()
-        }
-        seq_vocab: Dict[str, set[str]] = {
-            name: set() for name in self.sequence_features.keys()
-        }
-        sparse_label_counts: Dict[str, Dict[str, int]] = {
-            name: {}
-            for name, config in self.sparse_features.items()
-            if config.get("encode_method") == "label" and config.get("min_freq")
-        }
-        seq_label_counts: Dict[str, Dict[str, int]] = {
-            name: {}
-            for name, config in self.sequence_features.items()
-            if config.get("encode_method") == "label" and config.get("min_freq")
-        }
-        sparse_hash_counts: Dict[str, Dict[str, int]] = {
-            name: {}
-            for name, config in self.sparse_features.items()
-            if config.get("encode_method") == "hash" and config.get("min_freq")
-        }
-        seq_hash_counts: Dict[str, Dict[str, int]] = {
-            name: {}
-            for name, config in self.sequence_features.items()
-            if config.get("encode_method") == "hash" and config.get("min_freq")
-        }
-        target_values: Dict[str, set[Any]] = {
-            name: set() for name in self.target_features.keys()
-        }

         missing_features = set()
-        for
-            series = chunk[name]
-            if group == "numeric":
-                values = pd.to_numeric(series, errors="coerce").dropna()
-                if values.empty:
-                    continue
-                acc = numeric_acc[name]
-                arr = values.to_numpy(dtype=np.float64, copy=False)
-                acc["count"] += arr.size
-                acc["sum"] += float(arr.sum())
-                acc["sumsq"] += float(np.square(arr).sum())
-                acc["min"] = min(acc["min"], float(arr.min()))
-                acc["max"] = max(acc["max"], float(arr.max()))
-                acc["max_abs"] = max(
-                    acc["max_abs"], float(np.abs(arr).max())
-                )
-            elif group == "sparse":
-                fill_na = config["fill_na"]
-                series = series.fillna(fill_na).astype(str)
-                sparse_vocab[name].update(series.tolist())
-                if name in sparse_label_counts:
-                    counts = sparse_label_counts[name]
-                    for token in series.tolist():
-                        counts[token] = counts.get(token, 0) + 1
-                if name in sparse_hash_counts:
-                    counts = sparse_hash_counts[name]
-                    for token in series.tolist():
-                        counts[token] = counts.get(token, 0) + 1
-            else:
-                separator = config["separator"]
-                tokens = []
-                for val in series:
-                    tokens.extend(
-                        self.extract_sequence_tokens(val, separator)
-                    )
-                seq_vocab[name].update(tokens)
-                if name in seq_label_counts:
-                    counts = seq_label_counts[name]
-                    for token in tokens:
-                        if str(token).strip():
-                            key = str(token)
-                            counts[key] = counts.get(key, 0) + 1
-                if name in seq_hash_counts:
-                    counts = seq_hash_counts[name]
-                    for token in tokens:
-                        if str(token).strip():
-                            key = str(token)
-                            counts[key] = counts.get(key, 0) + 1
-
-        # target features
-        missing_features.update(self.target_features.keys() - columns)
-        for name in self.target_features.keys() & columns:
-            vals = chunk[name].dropna().tolist()
-            target_values[name].update(vals)
-
+        for name in self.numeric_features.keys():
+            if name not in schema:
+                missing_features.add(name)
+        for name in self.sparse_features.keys():
+            if name not in schema:
+                missing_features.add(name)
+        for name in self.sequence_features.keys():
+            if name not in schema:
+                missing_features.add(name)
+        for name in self.target_features.keys():
+            if name not in schema:
+                missing_features.add(name)
         if missing_features:
             logger.warning(
-                f"The following configured features were not found in provided
+                f"The following configured features were not found in provided data: {sorted(missing_features)}"
             )

-        #
+        # numeric aggregates in a single pass
+        if self.numeric_features:
+            agg_exprs = []
+            for name in self.numeric_features.keys():
+                if name not in schema:
+                    continue
+                col = pl.col(name).cast(pl.Float64)
+                agg_exprs.extend(
+                    [
+                        col.sum().alias(f"{name}__sum"),
+                        (col * col).sum().alias(f"{name}__sumsq"),
+                        col.count().alias(f"{name}__count"),
+                        col.min().alias(f"{name}__min"),
+                        col.max().alias(f"{name}__max"),
+                        col.abs().max().alias(f"{name}__max_abs"),
+                    ]
+                )
+                if self.numeric_features[name].get("scaler") == "robust":
+                    agg_exprs.extend(
+                        [
+                            col.quantile(0.25).alias(f"{name}__q1"),
+                            col.quantile(0.75).alias(f"{name}__q3"),
+                            col.median().alias(f"{name}__median"),
+                        ]
+                    )
+            stats = lf.select(agg_exprs).collect().to_dicts()[0] if agg_exprs else {}
+        else:
+            stats = {}
+
         for name, config in self.numeric_features.items():
+            if name not in schema:
+                continue
+            count = float(stats.get(f"{name}__count", 0) or 0)
+            if count == 0:
                 logger.warning(
-                    f"Numeric feature {name} has no valid values in provided
+                    f"Numeric feature {name} has no valid values in provided data"
                 )
                 continue
+            sum_val = float(stats.get(f"{name}__sum", 0) or 0)
+            sumsq = float(stats.get(f"{name}__sumsq", 0) or 0)
+            mean_val = sum_val / count
             if config["fill_na"] is not None:
                 config["fill_na_value"] = config["fill_na"]
             else:
                 config["fill_na_value"] = mean_val
             scaler_type = config["scaler"]
             if scaler_type == "standard":
-                var = max(
+                var = max(sumsq / count - mean_val * mean_val, 0.0)
                 scaler = StandardScaler()
                 scaler.mean_ = np.array([mean_val], dtype=np.float64)
                 scaler.var_ = np.array([var], dtype=np.float64)
                 scaler.scale_ = np.array(
                     [np.sqrt(var) if var > 0 else 1.0], dtype=np.float64
                 )
-                scaler.n_samples_seen_ = np.array([int(
+                scaler.n_samples_seen_ = np.array([int(count)], dtype=np.int64)
                 self.scalers[name] = scaler
-
             elif scaler_type == "minmax":
-                data_min =
-                data_max =
+                data_min = float(stats.get(f"{name}__min", 0) or 0)
+                data_max = float(stats.get(f"{name}__max", data_min) or data_min)
                 scaler = MinMaxScaler()
                 scaler.data_min_ = np.array([data_min], dtype=np.float64)
                 scaler.data_max_ = np.array([data_max], dtype=np.float64)
                 scaler.data_range_ = scaler.data_max_ - scaler.data_min_
                 scaler.data_range_[scaler.data_range_ == 0] = 1.0
-                # Manually set scale_/min_ for streaming fit to mirror sklearn's internal fit logic
                 feature_min, feature_max = scaler.feature_range
                 scale = (feature_max - feature_min) / scaler.data_range_
                 scaler.scale_ = scale
                 scaler.min_ = feature_min - scaler.data_min_ * scale
-                scaler.n_samples_seen_ = np.array([int(
+                scaler.n_samples_seen_ = np.array([int(count)], dtype=np.int64)
                 self.scalers[name] = scaler
-
             elif scaler_type == "maxabs":
+                max_abs = float(stats.get(f"{name}__max_abs", 1.0) or 1.0)
                 scaler = MaxAbsScaler()
-                scaler.max_abs_ = np.array([
-                scaler.n_samples_seen_ = np.array([int(
+                scaler.max_abs_ = np.array([max_abs], dtype=np.float64)
+                scaler.n_samples_seen_ = np.array([int(count)], dtype=np.int64)
                 self.scalers[name] = scaler
-
+            elif scaler_type == "robust":
+                q1 = float(stats.get(f"{name}__q1", 0) or 0)
+                q3 = float(stats.get(f"{name}__q3", q1) or q1)
+                median = float(stats.get(f"{name}__median", 0) or 0)
+                scale = q3 - q1
+                if scale == 0:
+                    scale = 1.0
+                scaler = RobustScaler()
+                scaler.center_ = np.array([median], dtype=np.float64)
+                scaler.scale_ = np.array([scale], dtype=np.float64)
+                scaler.n_samples_seen_ = np.array([int(count)], dtype=np.int64)
+                self.scalers[name] = scaler
+            elif scaler_type in ("log", "none"):
                 continue
             else:
                 raise ValueError(f"Unknown scaler type: {scaler_type}")

-        #
+        # sparse features
         for name, config in self.sparse_features.items():
-            if
+            if name not in schema:
+                continue
+            encode_method = config["encode_method"]
+            fill_na = config["fill_na"]
+            col = pl.col(name).cast(pl.Utf8).fill_null(fill_na)
+            counts_df = (
+                lf.select(col.alias(name))
+                .group_by(name)
+                .agg(pl.len().alias("count"))
+                .collect()
+            )
+            counts = (
+                dict(zip(counts_df[name].to_list(), counts_df["count"].to_list()))
+                if counts_df.height > 0
+                else {}
+            )
+            if encode_method == "label":
                 min_freq = config.get("min_freq")
                 if min_freq is not None:
-                    config["_token_counts"] = token_counts
+                    config["_token_counts"] = counts
                     vocab = {
-                        token
-                        for token, count in token_counts.items()
-                        if count >= min_freq
+                        token for token, count in counts.items() if count >= min_freq
                     }
                     low_freq_types = sum(
-                        1 for count in
+                        1 for count in counts.values() if count < min_freq
                     )
-                    total_types = len(
+                    total_types = len(counts)
                     kept_types = total_types - low_freq_types
                     if not config.get("_min_freq_logged"):
                         logger.info(
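A note on the scaler handling above: polars_fit_from_lazy never calls scaler.fit; it assigns the learned attributes directly from the aggregated statistics. A small hedged sketch with made-up numbers (sums for the data 1..4, not taken from the diff) showing that this is enough for transform to work:

import numpy as np
from sklearn.preprocessing import StandardScaler

count, total, total_sq = 4.0, 10.0, 30.0                 # n, sum(x), sum(x^2) for x = 1..4
mean_val = total / count                                  # 2.5
var = max(total_sq / count - mean_val * mean_val, 0.0)    # 1.25

scaler = StandardScaler()
scaler.mean_ = np.array([mean_val], dtype=np.float64)
scaler.var_ = np.array([var], dtype=np.float64)
scaler.scale_ = np.array([np.sqrt(var) if var > 0 else 1.0], dtype=np.float64)
scaler.n_samples_seen_ = np.array([int(count)], dtype=np.int64)

print(scaler.transform(np.array([[1.0], [2.0], [3.0], [4.0]])))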
@@ -769,29 +569,29 @@ class DataProcessor(FeatureSet):
                         )
                         config["_min_freq_logged"] = True
                 else:
-                    vocab =
+                    vocab = set(counts.keys())
                 if not vocab:
                     logger.warning(f"Sparse feature {name} has empty vocabulary")
                     continue
+                # Filter out None values before sorting to avoid comparison errors
+                vocab_list = sorted(v for v in vocab if v is not None)
                 if "<UNK>" not in vocab_list:
                     vocab_list.append("<UNK>")
                 token_to_idx = {token: idx for idx, token in enumerate(vocab_list)}
                 config["_token_to_idx"] = token_to_idx
                 config["_unk_index"] = token_to_idx["<UNK>"]
                 config["vocab_size"] = len(vocab_list)
-            elif
+            elif encode_method == "hash":
                 min_freq = config.get("min_freq")
                 if min_freq is not None:
-                    config["_token_counts"] = token_counts
+                    config["_token_counts"] = counts
                     config["_unk_hash"] = self.hash_string(
                         "<UNK>", int(config["hash_size"])
                     )
                     low_freq_types = sum(
-                        1 for count in
+                        1 for count in counts.values() if count < min_freq
                     )
-                    total_types = len(
+                    total_types = len(counts)
                     kept_types = total_types - low_freq_types
                     if not config.get("_min_freq_logged"):
                         logger.info(
@@ -803,22 +603,37 @@ class DataProcessor(FeatureSet):
                         config["_min_freq_logged"] = True
                 config["vocab_size"] = config["hash_size"]

-        #
+        # sequence features
         for name, config in self.sequence_features.items():
-            if
+            if name not in schema:
+                continue
+            encode_method = config["encode_method"]
+            seq_col = self.sequence_expr(pl, name, config, schema)
+            tokens_df = (
+                lf.select(seq_col.alias("seq"))
+                .explode("seq")
+                .select(pl.col("seq").cast(pl.Utf8).alias("seq"))
+                .drop_nulls("seq")
+                .group_by("seq")
+                .agg(pl.len().alias("count"))
+                .collect()
+            )
+            counts = (
+                dict(zip(tokens_df["seq"].to_list(), tokens_df["count"].to_list()))
+                if tokens_df.height > 0
+                else {}
+            )
+            if encode_method == "label":
                 min_freq = config.get("min_freq")
                 if min_freq is not None:
-                    config["_token_counts"] = token_counts
+                    config["_token_counts"] = counts
                     vocab_set = {
-                        token
-                        for token, count in token_counts.items()
-                        if count >= min_freq
+                        token for token, count in counts.items() if count >= min_freq
                     }
                     low_freq_types = sum(
-                        1 for count in
+                        1 for count in counts.values() if count < min_freq
                     )
-                    total_types = len(
+                    total_types = len(counts)
                     kept_types = total_types - low_freq_types
                     if not config.get("_min_freq_logged"):
                         logger.info(
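The sequence-vocabulary fit above relies on an explode/group_by counting pattern. A minimal sketch on a toy list column (the column name "tags" is an assumption, not from the package):

import polars as pl

lf = pl.LazyFrame({"tags": [["a", "b"], ["b", "c"], None]})
tokens_df = (
    lf.select(pl.col("tags").alias("seq"))
    .explode("seq")            # one row per token; the null list becomes a null row
    .drop_nulls("seq")
    .group_by("seq")
    .agg(pl.len().alias("count"))
    .collect()
)
counts = dict(zip(tokens_df["seq"].to_list(), tokens_df["count"].to_list()))
print(counts)  # {'a': 1, 'b': 2, 'c': 1} (row order may vary)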
@@ -829,26 +644,30 @@ class DataProcessor(FeatureSet):
                         )
                         config["_min_freq_logged"] = True
                 else:
-                    vocab_set =
+                    vocab_set = set(counts.keys())
+                # Filter out None values before sorting to avoid comparison errors
+                vocab_list = (
+                    sorted(v for v in vocab_set if v is not None)
+                    if vocab_set
+                    else ["<PAD>"]
+                )
                 if "<UNK>" not in vocab_list:
                     vocab_list.append("<UNK>")
                 token_to_idx = {token: idx for idx, token in enumerate(vocab_list)}
                 config["_token_to_idx"] = token_to_idx
                 config["_unk_index"] = token_to_idx["<UNK>"]
                 config["vocab_size"] = len(vocab_list)
-            elif
+            elif encode_method == "hash":
                 min_freq = config.get("min_freq")
                 if min_freq is not None:
-                    config["_token_counts"] = token_counts
+                    config["_token_counts"] = counts
                     config["_unk_hash"] = self.hash_string(
                         "<UNK>", int(config["hash_size"])
                     )
                     low_freq_types = sum(
-                        1 for count in
+                        1 for count in counts.values() if count < min_freq
                     )
-                    total_types = len(
+                    total_types = len(counts)
                     kept_types = total_types - low_freq_types
                     if not config.get("_min_freq_logged"):
                         logger.info(
@@ -860,14 +679,18 @@ class DataProcessor(FeatureSet):
                         config["_min_freq_logged"] = True
                 config["vocab_size"] = config["hash_size"]

-        #
+        # targets
         for name, config in self.target_features.items():
-            if not
-                logger.warning(f"Target {name} has no valid values in provided files")
+            if name not in schema:
                 continue
+            if config.get("target_type") == "binary":
+                unique_vals = (
+                    lf.select(pl.col(name).drop_nulls().unique())
+                    .collect()
+                    .to_series()
+                    .to_list()
+                )
+                self.process_target_fit(unique_vals, config, name)

         self.is_fitted = True
         logger.info(
@@ -879,18 +702,11 @@ class DataProcessor(FeatureSet):
         )
         return self

-    def fit_from_files(
-        self, file_paths: list[str], file_type: str, chunk_size: int
-    ) -> "DataProcessor":
-        """Fit processor statistics by streaming an explicit list of files.
-
-        This is useful when you want to fit statistics on training files only (exclude
-        validation files) in streaming mode.
-        """
+    def fit_from_files(self, file_paths: list[str], file_type: str) -> "DataProcessor":
         logger = logging.getLogger()
         logger.info(
             colorize(
-                "Fitting DataProcessor
+                "Fitting DataProcessor...",
                 color="cyan",
                 bold=True,
             )
@@ -899,36 +715,15 @@ class DataProcessor(FeatureSet):
             config.pop("_min_freq_logged", None)
         for config in self.sequence_features.values():
             config.pop("_min_freq_logged", None)
-        )
-        if uses_robust:
-            logger.warning(
-                "Robust scaler requires full data; loading provided files into memory. "
-                "Consider smaller chunk_size or different scaler if memory is limited."
-            )
-            frames = [read_table(p, file_type) for p in file_paths]
-            df = pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]
-            return self.fit(df)
-        return self.fit_from_file_paths(
-            file_paths=file_paths, file_type=file_type, chunk_size=chunk_size
-        )
-
-    def fit_from_path(self, path: str, chunk_size: int) -> "DataProcessor":
-        """
-        Fit processor statistics by streaming files to reduce memory usage.
-
-        Args:
-            path (str): File or directory path.
-            chunk_size (int): Number of rows per chunk.
+        lf = self.polars_scan(file_paths, file_type)
+        schema = lf.collect_schema()
+        return self.polars_fit_from_lazy(lf, schema)

-            DataProcessor: Fitted DataProcessor instance.
-        """
+    def fit_from_path(self, path: str) -> "DataProcessor":
         logger = logging.getLogger()
         logger.info(
             colorize(
-                "Fitting DataProcessor
+                "Fitting DataProcessor...",
                 color="cyan",
                 bold=True,
             )
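fit_from_files now follows a schema-first pattern: scan lazily, inspect the schema without materialising rows, then fit from the LazyFrame. A hedged sketch of that pattern with an illustrative file name and feature list:

import polars as pl

pl.DataFrame({"user_id": [1], "price": [9.9]}).write_csv("train.csv")   # toy input
lf = pl.scan_csv("train.csv", ignore_errors=True)
schema = lf.collect_schema()                    # column name -> dtype, no rows read
configured = ["user_id", "item_id", "price"]    # illustrative feature names
print([c for c in configured if c not in schema])   # ['item_id']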
@@ -938,118 +733,35 @@ class DataProcessor(FeatureSet):
         for config in self.sequence_features.values():
             config.pop("_min_freq_logged", None)
         file_paths, file_type = resolve_file_paths(path)
-        return self.
-            file_paths=file_paths,
-            file_type=file_type,
-            chunk_size=chunk_size,
-        )
-
-    @overload
-    def transform_in_memory(
-        self,
-        data: Union[pd.DataFrame, Dict[str, Any]],
-        return_dict: Literal[True],
-        persist: bool,
-        save_format: Optional[str],
-        output_path: Optional[str],
-        warn_missing: bool = True,
-    ) -> Dict[str, np.ndarray]: ...
+        return self.fit_from_files(file_paths=file_paths, file_type=file_type)

-    @overload
     def transform_in_memory(
         self,
-        data: Union[pd.DataFrame, Dict[str, Any]],
-        return_dict: Literal[False],
-        persist: bool,
-        save_format: Optional[str],
-        output_path: Optional[str],
-        warn_missing: bool = True,
-    ) -> pd.DataFrame: ...
-
-    def transform_in_memory(
-        self,
-        data: Union[pd.DataFrame, Dict[str, Any]],
+        data: Union[pl.DataFrame, pd.DataFrame, Dict[str, Any]],
         return_dict: bool,
         persist: bool,
         save_format: Optional[str],
         output_path: Optional[str],
         warn_missing: bool = True,
     ):
-        """
-        Transform in-memory data and optionally persist the transformed data.
-
-        Args:
-            data (Union[pd.DataFrame, Dict[str, Any]]): Input data.
-            return_dict (bool): Whether to return a dictionary of numpy arrays.
-            persist (bool): Whether to persist the transformed data to disk.
-            save_format (Optional[str]): Format to save the data if persisting.
-            output_path (Optional[str]): Output path to save the data if persisting.
-            warn_missing (bool): Whether to warn about missing features in the data.
-
-        Returns:
-            Union[pd.DataFrame, Dict[str, np.ndarray]]: Transformed data.
-        """
-
         logger = logging.getLogger()
-        data_dict = data if isinstance(data, dict) else None

-                    result_dict[col] = df[col].to_numpy(copy=False)
+        if isinstance(data, dict):
+            df = pl.DataFrame(data)
+        elif isinstance(data, pd.DataFrame):
+            df = pl.from_pandas(data)
         else:
-                    result_dict[key] = value.to_numpy(copy=False)
-                else:
-                    result_dict[key] = np.asarray(value)
-
-        data_columns = data.columns if isinstance(data, pd.DataFrame) else data_dict
-        feature_groups = [
-            ("Numeric", self.numeric_features, self.process_numeric_feature_transform),
-            ("Sparse", self.sparse_features, self.process_sparse_feature_transform),
-            (
-                "Sequence",
-                self.sequence_features,
-                self.process_sequence_feature_transform,
-            ),
-            ("Target", self.target_features, self.process_target_transform),
-        ]
-        for label, features, transform_fn in feature_groups:
-            for name, config in features.items():
-                present = name in data_columns  # type: ignore[operator]
-                if not present:
-                    if warn_missing:
-                        logger.warning(f"{label} feature {name} not found in data")
-                    continue
-                series_data = (
-                    data[name]
-                    if isinstance(data, pd.DataFrame)
-                    else pd.Series(result_dict[name], name=name)
-                )
-                result_dict[name] = transform_fn(series_data, config)
-
-        def dict_to_dataframe(result: Dict[str, np.ndarray]) -> pd.DataFrame:
-            # Convert all arrays to Series/lists at once to avoid fragmentation
-            columns_dict = {}
-            for key, value in result.items():
-                if key in self.sequence_features:
-                    columns_dict[key] = np.asarray(value).tolist()
-                else:
-                    columns_dict[key] = value
-            return pd.DataFrame(columns_dict)
+            df = data
+
+        schema = df.schema
+        lf = df.lazy()
+        lf = self.apply_transforms(lf, schema, warn_missing=warn_missing)
+        out_df = lf.collect()

         effective_format = save_format
         if persist:
             effective_format = save_format or "parquet"
-
-        if (not return_dict) or persist:
-            result_df = dict_to_dataframe(result_dict)
+
         if persist:
             if effective_format not in FILE_FORMAT_CONFIG:
                 raise ValueError(f"Unsupported save format: {effective_format}")
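transform_in_memory now reduces to: normalise the input to a polars DataFrame, go lazy, apply the expression list built by apply_transforms once, and collect. A minimal sketch of that shape with illustrative column names, a toy hash size, and a hard-coded expression list standing in for apply_transforms:

import polars as pl

df = pl.DataFrame({"price": [10.0, None, 2.5], "city": ["SH", None, "BJ"]})
expressions = [
    pl.col("price").cast(pl.Float64).fill_null(0).log1p().alias("price"),
    (pl.col("city").cast(pl.Utf8).fill_null("UNK").hash().cast(pl.UInt64) % 1000)
    .cast(pl.Int64)
    .alias("city"),
]
out_df = df.lazy().with_columns(expressions).collect()
print(out_df)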
@@ -1061,68 +773,63 @@ class DataProcessor(FeatureSet):
             if output_dir.suffix:
                 output_dir = output_dir.parent
             output_dir.mkdir(parents=True, exist_ok=True)
             suffix = FILE_FORMAT_CONFIG[effective_format]["extension"][0]
             save_path = output_dir / f"transformed_data{suffix}"
-            assert result_df is not None, "DataFrame conversion failed"
-
-            # Save based on format
             if effective_format == "csv":
+                out_df.write_csv(save_path)
             elif effective_format == "parquet":
+                out_df.write_parquet(save_path)
             elif effective_format == "feather":
-            elif effective_format == "excel":
-                result_df.to_excel(save_path, index=False)
-            elif effective_format == "hdf5":
-                result_df.to_hdf(save_path, key="data", mode="w")
+                out_df.write_ipc(save_path)
             else:
-                raise ValueError(
+                raise ValueError(
+                    f"Format '{effective_format}' is not supported by the polars-only pipeline."
+                )
             logger.info(
                 colorize(
-                    f"Transformed data saved to: {save_path.resolve()}",
+                    f"Transformed data saved to: {save_path.resolve()}",
+                    color="green",
                 )
             )
+
         if return_dict:
+            result_dict = {}
+            for col in out_df.columns:
+                series = out_df.get_column(col)
+                if col in self.sequence_features:
+                    result_dict[col] = np.asarray(series.to_list(), dtype=np.int64)
+                else:
+                    result_dict[col] = series.to_numpy()
             return result_dict
-
-        return
+
+        return out_df

     def transform_path(
         self,
         input_path: str,
         output_path: Optional[str],
         save_format: Optional[str],
-        chunk_size: int = 200000,
     ):
-        """Transform data from files under a path and save them
-
-        Uses chunked reading/writing to keep peak memory bounded for large files.
-
-        Args:
-            input_path (str): Input file or directory path.
-            output_path (Optional[str]): Output directory path. If None, defaults to input_path/transformed_data.
-            save_format (Optional[str]): Format to save transformed files. If None, uses input file format.
-            chunk_size (int): Number of rows per chunk.
-        """
+        """Transform data from files under a path and save them using polars lazy pipeline."""
         logger = logging.getLogger()
         file_paths, file_type = resolve_file_paths(input_path)
         target_format = save_format or file_type
         if target_format not in FILE_FORMAT_CONFIG:
             raise ValueError(f"Unsupported format: {target_format}")
-        if
+        if target_format not in {"csv", "parquet", "feather"}:
+            raise ValueError(
+                f"Format '{target_format}' is not supported by the polars-only pipeline."
+            )
+        if not check_streaming_support(file_type):
             raise ValueError(
                 f"Input format '{file_type}' does not support streaming reads. "
-                "
+                "Polars backend supports csv/parquet only."
             )

-        # Warn about streaming support
         if not check_streaming_support(target_format):
             logger.warning(
                 f"[Data Processor Warning] Format '{target_format}' does not support streaming writes. "
-                "
+                "Data will be collected in memory before saving."
             )

         base_output_dir = (
@@ -1133,122 +840,48 @@ class DataProcessor(FeatureSet):
|
|
|
1133
840
|
output_root = base_output_dir / "transformed_data"
|
|
1134
841
|
output_root.mkdir(parents=True, exist_ok=True)
|
|
1135
842
|
saved_paths = []
|
|
843
|
+
|
|
1136
844
|
for file_path in progress(file_paths, description="Transforming files"):
|
|
1137
845
|
source_path = Path(file_path)
|
|
1138
846
|
suffix = FILE_FORMAT_CONFIG[target_format]["extension"][0]
|
|
1139
847
|
target_file = output_root / f"{source_path.stem}{suffix}"
|
|
1140
848
|
|
|
1141
|
-
|
|
1142
|
-
|
|
1143
|
-
|
|
1144
|
-
df = read_table(file_path, file_type)
|
|
1145
|
-
transformed_df = self.transform_in_memory(
|
|
1146
|
-
df,
|
|
1147
|
-
return_dict=False,
|
|
1148
|
-
persist=False,
|
|
1149
|
-
save_format=None,
|
|
1150
|
-
output_path=None,
|
|
1151
|
-
warn_missing=True,
|
|
1152
|
-
)
|
|
1153
|
-
assert isinstance(
|
|
1154
|
-
transformed_df, pd.DataFrame
|
|
1155
|
-
), "[Data Processor Error] Expected DataFrame when return_dict=False"
|
|
1156
|
-
|
|
1157
|
-
# Save based on format
|
|
1158
|
-
if target_format == "csv":
|
|
1159
|
-
transformed_df.to_csv(target_file, index=False)
|
|
1160
|
-
elif target_format == "parquet":
|
|
1161
|
-
transformed_df.to_parquet(target_file, index=False)
|
|
1162
|
-
elif target_format == "feather":
|
|
1163
|
-
transformed_df.to_feather(target_file)
|
|
1164
|
-
elif target_format == "excel":
|
|
1165
|
-
transformed_df.to_excel(target_file, index=False)
|
|
1166
|
-
elif target_format == "hdf5":
|
|
1167
|
-
transformed_df.to_hdf(target_file, key="data", mode="w")
|
|
1168
|
-
else:
|
|
1169
|
-
raise ValueError(f"Unsupported format: {target_format}")
|
|
1170
|
-
|
|
1171
|
-
saved_paths.append(str(target_file.resolve()))
|
|
1172
|
-
continue
|
|
849
|
+
lf = self.polars_scan([file_path], file_type)
|
|
850
|
+
schema = lf.collect_schema()
|
|
851
|
+
lf = self.apply_transforms(lf, schema, warn_missing=True)
|
|
1173
852
|
|
|
1174
|
-
first_chunk = True
|
|
1175
|
-
# Streaming write for supported formats
|
|
1176
853
|
if target_format == "parquet":
|
|
1177
|
-
|
|
1178
|
-
try:
|
|
1179
|
-
for chunk in iter_file_chunks(file_path, file_type, chunk_size):
|
|
1180
|
-
transformed_df = self.transform_in_memory(
|
|
1181
|
-
chunk,
|
|
1182
|
-
return_dict=False,
|
|
1183
|
-
persist=False,
|
|
1184
|
-
save_format=None,
|
|
1185
|
-
output_path=None,
|
|
1186
|
-
warn_missing=first_chunk,
|
|
1187
|
-
)
|
|
1188
|
-
assert isinstance(
|
|
1189
|
-
transformed_df, pd.DataFrame
|
|
1190
|
-
), "[Data Processor Error] Expected DataFrame when return_dict=False"
|
|
1191
|
-
table = pa.Table.from_pandas(
|
|
1192
|
-
transformed_df, preserve_index=False
|
|
1193
|
-
)
|
|
1194
|
-
if parquet_writer is None:
|
|
1195
|
-
parquet_writer = pq.ParquetWriter(target_file, table.schema)
|
|
1196
|
-
parquet_writer.write_table(table)
|
|
1197
|
-
first_chunk = False
|
|
1198
|
-
finally:
|
|
1199
|
-
if parquet_writer is not None:
|
|
1200
|
-
parquet_writer.close()
|
|
854
|
+
lf.sink_parquet(target_file)
|
|
1201
855
|
elif target_format == "csv":
|
|
1202
|
-
# CSV
|
|
1203
|
-
|
|
1204
|
-
|
|
1205
|
-
|
|
1206
|
-
|
|
1207
|
-
|
|
1208
|
-
|
|
1209
|
-
|
|
1210
|
-
|
|
1211
|
-
|
|
1212
|
-
|
|
1213
|
-
|
|
1214
|
-
|
|
1215
|
-
|
|
1216
|
-
|
|
1217
|
-
|
|
1218
|
-
|
|
1219
|
-
|
|
1220
|
-
|
|
1221
|
-
|
|
856
|
+
# CSV doesn't support nested data (lists), so convert list columns to string
|
|
857
|
+
transformed_schema = lf.collect_schema()
|
|
858
|
+
list_cols = [
|
|
859
|
+
name
|
|
860
|
+
for name, dtype in transformed_schema.items()
|
|
861
|
+
if isinstance(dtype, pl.List)
|
|
862
|
+
]
|
|
863
|
+
if list_cols:
|
|
864
|
+
# Convert list columns to string representation for CSV
|
|
865
|
+
# Format as [1, 2, 3] by casting elements to string, joining with ", ", and adding brackets
|
|
866
|
+
list_exprs = []
|
|
867
|
+
for name in list_cols:
|
|
868
|
+
# Convert list to string representation
|
|
869
|
+
list_exprs.append(
|
|
870
|
+
(
|
|
871
|
+
pl.lit("[")
|
|
872
|
+
+ pl.col(name)
|
|
873
|
+
.list.eval(pl.element().cast(pl.String))
|
|
874
|
+
.list.join(", ")
|
|
875
|
+
+ pl.lit("]")
|
|
876
|
+
).alias(name)
|
|
877
|
+
)
|
|
878
|
+
lf = lf.with_columns(list_exprs)
|
|
879
|
+
lf.sink_csv(target_file)
|
|
1222
880
|
else:
|
|
1223
|
-
|
|
1224
|
-
|
|
1225
|
-
f"Format '{target_format}' doesn't support streaming writes. "
|
|
1226
|
-
f"Collecting all chunks in memory before saving."
|
|
1227
|
-
)
|
|
1228
|
-
all_chunks = []
|
|
1229
|
-
for chunk in iter_file_chunks(file_path, file_type, chunk_size):
|
|
1230
|
-
transformed_df = self.transform_in_memory(
|
|
1231
|
-
chunk,
|
|
1232
|
-
return_dict=False,
|
|
1233
|
-
persist=False,
|
|
1234
|
-
save_format=None,
|
|
1235
|
-
output_path=None,
|
|
1236
|
-
warn_missing=first_chunk,
|
|
1237
|
-
)
|
|
1238
|
-
assert isinstance(transformed_df, pd.DataFrame)
|
|
1239
|
-
all_chunks.append(transformed_df)
|
|
1240
|
-
first_chunk = False
|
|
1241
|
-
|
|
1242
|
-
if all_chunks:
|
|
1243
|
-
combined_df = pd.concat(all_chunks, ignore_index=True)
|
|
1244
|
-
if target_format == "feather":
|
|
1245
|
-
combined_df.to_feather(target_file)
|
|
1246
|
-
elif target_format == "excel":
|
|
1247
|
-
combined_df.to_excel(target_file, index=False)
|
|
1248
|
-
elif target_format == "hdf5":
|
|
1249
|
-
combined_df.to_hdf(target_file, key="data", mode="w")
|
|
1250
|
-
|
|
881
|
+
df = lf.collect()
|
|
882
|
+
df.write_ipc(target_file)
|
|
1251
883
|
saved_paths.append(str(target_file.resolve()))
|
|
884
|
+
|
|
1252
885
|
logger.info(
|
|
1253
886
|
colorize(
|
|
1254
887
|
f"Transformed {len(saved_paths)} file(s) saved to: {output_root.resolve()}",
|
|
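For reference, a minimal, self-contained sketch (not the package's code) of the polars pattern this new save path relies on: build or scan a lazy frame, render any `pl.List` columns as `[v1, v2, ...]` strings so CSV can hold them, then stream the result out with `sink_csv`. Column names and the output path below are illustrative.

```python
import polars as pl

# Toy lazy frame with a nested (list) column; a real input would come from
# pl.scan_parquet / pl.scan_csv over the source files.
lf = pl.LazyFrame({"user_id": [1, 2], "hist_items": [[10, 11], [12]]})

# CSV has no nested types, so convert list columns to "[v1, v2, ...]" strings.
schema = lf.collect_schema()
list_cols = [name for name, dtype in schema.items() if isinstance(dtype, pl.List)]
lf = lf.with_columns(
    [
        (
            pl.lit("[")
            + pl.col(name).list.eval(pl.element().cast(pl.String)).list.join(", ")
            + pl.lit("]")
        ).alias(name)
        for name in list_cols
    ]
)

lf.sink_csv("transformed.csv")  # streaming write; sink_parquet works the same way
```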
@@ -1260,74 +893,51 @@ class DataProcessor(FeatureSet):
|
|
|
1260
893
|
# fit is nothing but registering the statistics from data so that we can transform the data later
|
|
1261
894
|
def fit(
|
|
1262
895
|
self,
|
|
1263
|
-
data: Union[pd.DataFrame, Dict[str, Any], str, os.PathLike],
|
|
1264
|
-
chunk_size: int = 200000,
|
|
896
|
+
data: Union[pl.DataFrame, pd.DataFrame, Dict[str, Any], str, os.PathLike],
|
|
1265
897
|
):
|
|
1266
898
|
"""
|
|
1267
899
|
Fit the DataProcessor to the provided data.
|
|
1268
900
|
|
|
1269
901
|
Args:
|
|
1270
|
-
data (Union[pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data for fitting.
|
|
1271
|
-
chunk_size (int): Number of rows per chunk when streaming from path.
|
|
902
|
+
data (Union[pl.DataFrame, pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data for fitting.
|
|
1272
903
|
|
|
1273
904
|
Returns:
|
|
1274
905
|
DataProcessor: Fitted DataProcessor instance.
|
|
1275
906
|
"""
|
|
1276
907
|
|
|
1277
|
-
logger = logging.getLogger()
|
|
1278
908
|
for config in self.sparse_features.values():
|
|
1279
909
|
config.pop("_min_freq_logged", None)
|
|
1280
910
|
for config in self.sequence_features.values():
|
|
1281
911
|
config.pop("_min_freq_logged", None)
|
|
1282
912
|
if isinstance(data, (str, os.PathLike)):
|
|
1283
|
-
|
|
1284
|
-
uses_robust = any(
|
|
1285
|
-
cfg.get("scaler") == "robust" for cfg in self.numeric_features.values()
|
|
1286
|
-
)
|
|
1287
|
-
if uses_robust:
|
|
1288
|
-
logger.warning(
|
|
1289
|
-
"Robust scaler requires full data; loading all files into memory. Consider smaller chunk_size or different scaler if memory is limited."
|
|
1290
|
-
)
|
|
1291
|
-
data = self.load_dataframe_from_path(path_str)
|
|
1292
|
-
else:
|
|
1293
|
-
return self.fit_from_path(path_str, chunk_size)
|
|
913
|
+
return self.fit_from_path(str(data))
|
|
1294
914
|
if isinstance(data, dict):
|
|
1295
|
-
|
|
1296
|
-
|
|
1297
|
-
|
|
1298
|
-
|
|
1299
|
-
|
|
1300
|
-
|
|
1301
|
-
|
|
1302
|
-
|
|
1303
|
-
for label, features, fit_fn in feature_groups:
|
|
1304
|
-
for name, config in features.items():
|
|
1305
|
-
if name not in data.columns:
|
|
1306
|
-
logger.warning(f"{label} feature {name} not found in data")
|
|
1307
|
-
continue
|
|
1308
|
-
fit_fn(data[name], config)
|
|
1309
|
-
self.is_fitted = True
|
|
1310
|
-
return self
|
|
915
|
+
df = pl.DataFrame(data)
|
|
916
|
+
elif isinstance(data, pd.DataFrame):
|
|
917
|
+
df = pl.from_pandas(data)
|
|
918
|
+
else:
|
|
919
|
+
df = data
|
|
920
|
+
lf = df.lazy()
|
|
921
|
+
schema = df.schema
|
|
922
|
+
return self.polars_fit_from_lazy(lf, schema)
|
|
1311
923
|
|
|
1312
924
|
@overload
|
|
1313
925
|
def transform(
|
|
1314
926
|
self,
|
|
1315
|
-
data: Union[pd.DataFrame, Dict[str, Any]],
|
|
927
|
+
data: Union[pl.DataFrame, pd.DataFrame, Dict[str, Any]],
|
|
1316
928
|
return_dict: Literal[True] = True,
|
|
1317
929
|
save_format: Optional[str] = None,
|
|
1318
930
|
output_path: Optional[str] = None,
|
|
1319
|
-
chunk_size: int = 200000,
|
|
1320
931
|
) -> Dict[str, np.ndarray]: ...
|
|
1321
932
|
|
|
1322
933
|
@overload
|
|
1323
934
|
def transform(
|
|
1324
935
|
self,
|
|
1325
|
-
data: Union[pd.DataFrame, Dict[str, Any]],
|
|
936
|
+
data: Union[pl.DataFrame, pd.DataFrame, Dict[str, Any]],
|
|
1326
937
|
return_dict: Literal[False] = False,
|
|
1327
938
|
save_format: Optional[str] = None,
|
|
1328
939
|
output_path: Optional[str] = None,
|
|
1329
|
-
|
|
1330
|
-
) -> pd.DataFrame: ...
|
|
940
|
+
) -> pl.DataFrame: ...
|
|
1331
941
|
|
|
1332
942
|
@overload
|
|
1333
943
|
def transform(
|
|
@@ -1336,28 +946,25 @@ class DataProcessor(FeatureSet):
|
|
|
1336
946
|
return_dict: Literal[False] = False,
|
|
1337
947
|
save_format: Optional[str] = None,
|
|
1338
948
|
output_path: Optional[str] = None,
|
|
1339
|
-
chunk_size: int = 200000,
|
|
1340
949
|
) -> list[str]: ...
|
|
1341
950
|
|
|
1342
951
|
def transform(
|
|
1343
952
|
self,
|
|
1344
|
-
data: Union[pd.DataFrame, Dict[str, Any], str, os.PathLike],
|
|
953
|
+
data: Union[pl.DataFrame, pd.DataFrame, Dict[str, Any], str, os.PathLike],
|
|
1345
954
|
return_dict: bool = True,
|
|
1346
955
|
save_format: Optional[str] = None,
|
|
1347
956
|
output_path: Optional[str] = None,
|
|
1348
|
-
chunk_size: int = 200000,
|
|
1349
957
|
):
|
|
1350
958
|
"""
|
|
1351
959
|
Transform the provided data using the fitted DataProcessor.
|
|
1352
960
|
|
|
1353
961
|
Args:
|
|
1354
|
-
data (Union[pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data to transform.
|
|
962
|
+
data (Union[pl.DataFrame, pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data to transform.
|
|
1355
963
|
return_dict (bool): Whether to return a dictionary of numpy arrays.
|
|
1356
964
|
save_format (Optional[str]): Format to save the data if output_path is provided.
|
|
1357
965
|
output_path (Optional[str]): Output path to save the transformed data.
|
|
1358
|
-
chunk_size (int): Number of rows per chunk when streaming from path.
|
|
1359
966
|
Returns:
|
|
1360
|
-
Union[
|
|
967
|
+
Union[pl.DataFrame, Dict[str, np.ndarray], List[str]]: Transformed data or list of saved file paths.
|
|
1361
968
|
"""
|
|
1362
969
|
|
|
1363
970
|
if not self.is_fitted:
|
|
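A small, self-contained illustration of the `Literal`-keyed overload contract kept here: `return_dict=True` yields a dict of numpy arrays, `return_dict=False` a `pl.DataFrame`. The toy function below only mimics that typing pattern and is not the package's implementation.

```python
from typing import Dict, Literal, Union, overload

import numpy as np
import polars as pl

@overload
def transform(df: pl.DataFrame, return_dict: Literal[True] = True) -> Dict[str, np.ndarray]: ...
@overload
def transform(df: pl.DataFrame, return_dict: Literal[False] = False) -> pl.DataFrame: ...
def transform(
    df: pl.DataFrame, return_dict: bool = True
) -> Union[Dict[str, np.ndarray], pl.DataFrame]:
    # Dict output feeds models directly; DataFrame output keeps the table shape.
    if return_dict:
        return {name: df[name].to_numpy() for name in df.columns}
    return df

arrays = transform(pl.DataFrame({"x": [1, 2, 3]}))         # dict of np.ndarray
frame = transform(pl.DataFrame({"x": [1, 2, 3]}), False)   # pl.DataFrame
```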
@@ -1369,9 +976,7 @@ class DataProcessor(FeatureSet):
|
|
|
1369
976
|
raise ValueError(
|
|
1370
977
|
"[Data Processor Error] Path transform writes files only; set return_dict=False when passing a path."
|
|
1371
978
|
)
|
|
1372
|
-
return self.transform_path(
|
|
1373
|
-
str(data), output_path, save_format, chunk_size=chunk_size
|
|
1374
|
-
)
|
|
979
|
+
return self.transform_path(str(data), output_path, save_format)
|
|
1375
980
|
return self.transform_in_memory(
|
|
1376
981
|
data=data,
|
|
1377
982
|
return_dict=return_dict,
|
|
@@ -1382,26 +987,24 @@ class DataProcessor(FeatureSet):
|
|
|
1382
987
|
|
|
1383
988
|
def fit_transform(
|
|
1384
989
|
self,
|
|
1385
|
-
data: Union[pd.DataFrame, Dict[str, Any], str, os.PathLike],
|
|
990
|
+
data: Union[pl.DataFrame, pd.DataFrame, Dict[str, Any], str, os.PathLike],
|
|
1386
991
|
return_dict: bool = True,
|
|
1387
992
|
save_format: Optional[str] = None,
|
|
1388
993
|
output_path: Optional[str] = None,
|
|
1389
|
-
chunk_size: int = 200000,
|
|
1390
994
|
):
|
|
1391
995
|
"""
|
|
1392
996
|
Fit the DataProcessor to the provided data and then transform it.
|
|
1393
997
|
|
|
1394
998
|
Args:
|
|
1395
|
-
data (Union[pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data for fitting and transforming.
|
|
999
|
+
data (Union[pl.DataFrame, pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data for fitting and transforming.
|
|
1396
1000
|
return_dict (bool): Whether to return a dictionary of numpy arrays.
|
|
1397
1001
|
save_format (Optional[str]): Format to save the data if output_path is provided.
|
|
1398
|
-
output_path (Optional[str]): Output path to save the
|
|
1399
|
-
chunk_size (int): Number of rows per chunk when streaming from path.
|
|
1002
|
+
output_path (Optional[str]): Output path to save the data.
|
|
1400
1003
|
Returns:
|
|
1401
|
-
Union[
|
|
1004
|
+
Union[pl.DataFrame, Dict[str, np.ndarray], List[str]]: Transformed data or list of saved file paths.
|
|
1402
1005
|
"""
|
|
1403
1006
|
|
|
1404
|
-
self.fit(data, chunk_size=chunk_size)
|
|
1007
|
+
self.fit(data)
|
|
1405
1008
|
return self.transform(
|
|
1406
1009
|
data,
|
|
1407
1010
|
return_dict=return_dict,
|