nextrec 0.4.33-py3-none-any.whl → 0.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/activation.py +10 -18
  3. nextrec/basic/asserts.py +1 -22
  4. nextrec/basic/callback.py +2 -2
  5. nextrec/basic/features.py +6 -37
  6. nextrec/basic/heads.py +13 -1
  7. nextrec/basic/layers.py +33 -123
  8. nextrec/basic/loggers.py +3 -2
  9. nextrec/basic/metrics.py +85 -4
  10. nextrec/basic/model.py +518 -7
  11. nextrec/basic/summary.py +88 -42
  12. nextrec/cli.py +117 -30
  13. nextrec/data/data_processing.py +8 -13
  14. nextrec/data/preprocessor.py +449 -844
  15. nextrec/loss/grad_norm.py +78 -76
  16. nextrec/models/multi_task/ple.py +1 -0
  17. nextrec/models/multi_task/share_bottom.py +1 -0
  18. nextrec/models/ranking/afm.py +4 -9
  19. nextrec/models/ranking/dien.py +7 -8
  20. nextrec/models/ranking/ffm.py +2 -2
  21. nextrec/models/retrieval/sdm.py +1 -2
  22. nextrec/models/sequential/hstu.py +0 -2
  23. nextrec/models/tree_base/base.py +1 -1
  24. nextrec/utils/__init__.py +2 -1
  25. nextrec/utils/config.py +1 -1
  26. nextrec/utils/console.py +1 -1
  27. nextrec/utils/onnx_utils.py +252 -0
  28. nextrec/utils/torch_utils.py +63 -56
  29. nextrec/utils/types.py +43 -0
  30. {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/METADATA +10 -4
  31. {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/RECORD +34 -42
  32. nextrec/models/multi_task/[pre]star.py +0 -192
  33. nextrec/models/representation/autorec.py +0 -0
  34. nextrec/models/representation/bpr.py +0 -0
  35. nextrec/models/representation/cl4srec.py +0 -0
  36. nextrec/models/representation/lightgcn.py +0 -0
  37. nextrec/models/representation/mf.py +0 -0
  38. nextrec/models/representation/s3rec.py +0 -0
  39. nextrec/models/sequential/sasrec.py +0 -0
  40. nextrec/utils/feature.py +0 -29
  41. {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/WHEEL +0 -0
  42. {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/entry_points.txt +0 -0
  43. {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/licenses/LICENSE +0 -0
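
The bulk of this release is the rewrite of nextrec/data/preprocessor.py (+449 -844) from a pandas/pyarrow chunked pipeline to a Polars lazy pipeline; the hunks below show the DataProcessor changes. For orientation, here is a minimal usage sketch of the 0.5.0 flow. It assumes the module path shown in this diff, that DataProcessor can be constructed with defaults, and that feature registration (not visible in these hunks) has already happened; treat it as illustrative, not as the package's documented API.

from nextrec.data.preprocessor import DataProcessor  # module path taken from this diff

# Sketch only: constructor arguments and feature registration are assumptions.
processor = DataProcessor()
processor.fit("data/train")          # fit() no longer takes chunk_size in 0.5.0
processor.transform_path(
    input_path="data/train",
    output_path="data/processed",
    save_format="parquet",           # the polars-only pipeline writes csv/parquet/feather
)
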
@@ -2,7 +2,7 @@
 DataProcessor for data preprocessing including numeric, sparse, sequence features and target processing.
 
 Date: create on 13/11/2025
-Checkpoint: edit on 29/12/2025
+Checkpoint: edit on 28/01/2026
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -13,14 +13,12 @@ import logging
 import os
 import pickle
 from pathlib import Path
-from typing import Any, Dict, Literal, Optional, Union, overload
+from typing import Any, Dict, Iterable, Literal, Optional, Union, overload
 
 import numpy as np
 import pandas as pd
-import pyarrow as pa
-import pyarrow.parquet as pq
+import polars as pl
 from sklearn.preprocessing import (
-    LabelEncoder,
     MaxAbsScaler,
     MinMaxScaler,
     RobustScaler,
@@ -37,9 +35,6 @@ from nextrec.utils.data import (
     FILE_FORMAT_CONFIG,
     check_streaming_support,
     default_output_dir,
-    iter_file_chunks,
-    load_dataframes,
-    read_table,
     resolve_file_paths,
 )
 
@@ -63,7 +58,7 @@ class DataProcessor(FeatureSet):
         self.is_fitted = False
 
         self.scalers: Dict[str, Any] = {}
-        self.label_encoders: Dict[str, LabelEncoder] = {}
+        self.label_encoders: Dict[str, Any] = {}
         self.target_encoders: Dict[str, Dict[str, int]] = {}
         self.set_target_id(target=[], id_columns=[])
 
@@ -186,318 +181,228 @@ class DataProcessor(FeatureSet):
     def hash_string(self, s: str, hash_size: int) -> int:
         return self.hash_fn(str(s), int(hash_size))
 
189
- def process_numeric_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
190
- name = str(data.name)
191
- scaler_type = config["scaler"]
192
- fill_na = config["fill_na"]
193
- if data.isna().any():
194
- if fill_na is None:
195
- # Default use mean value to fill missing values for numeric features
196
- fill_na = data.mean()
197
- config["fill_na_value"] = fill_na
198
- scaler_map = {
199
- "standard": StandardScaler,
200
- "minmax": MinMaxScaler,
201
- "robust": RobustScaler,
202
- "maxabs": MaxAbsScaler,
203
- }
204
- if scaler_type in ("log", "none"):
205
- scaler = None
206
- else:
207
- scaler_cls = scaler_map.get(scaler_type)
208
- if scaler_cls is None:
209
- raise ValueError(
210
- f"[Data Processor Error] Unknown scaler type: {scaler_type}"
211
- )
212
- scaler = scaler_cls()
213
- if scaler is not None:
214
- filled_data = data.fillna(config.get("fill_na_value", 0))
215
- values = np.array(filled_data.values, dtype=np.float64).reshape(-1, 1)
216
- scaler.fit(values)
217
- self.scalers[name] = scaler
218
-
219
- def process_numeric_feature_transform(
220
- self, data: pd.Series, config: Dict[str, Any]
221
- ) -> np.ndarray:
222
- logger = logging.getLogger()
223
- name = str(data.name)
224
- scaler_type = config["scaler"]
225
- fill_na_value = config.get("fill_na_value", 0)
226
- filled_data = data.fillna(fill_na_value)
227
- values = np.array(filled_data.values, dtype=np.float64)
228
- if scaler_type == "log":
229
- result = np.log1p(np.maximum(values, 0))
230
- elif scaler_type == "none":
231
- result = values
184
+ def polars_scan(self, file_paths: list[str], file_type: str):
185
+ file_type = file_type.lower()
186
+ if file_type == "csv":
187
+ return pl.scan_csv(file_paths, ignore_errors=True)
188
+ if file_type == "parquet":
189
+ return pl.scan_parquet(file_paths)
190
+ raise ValueError(
191
+ f"[Data Processor Error] Polars backend only supports csv/parquet, got: {file_type}"
192
+ )
193
+
194
+ def sequence_expr(
195
+ self, pl, name: str, config: Dict[str, Any], schema: Dict[str, Any]
196
+ ):
197
+ """
198
+ generate polars expression for sequence feature processing
199
+
200
+ Example Input:
201
+ sequence_str: "1,2,3"
202
+ sequence_str: " 4, ,5 "
203
+ sequence_list: ["7", "8", "9"]
204
+ sequence_list: ["", "10", " 11 "]
205
+
206
+ Example Output:
207
+ sequence_str -> ["1","2","3"]
208
+ sequence_str -> ["4","5"]
209
+ sequence_list -> ["7","8","9"]
210
+ sequence_list -> ["10","11"]
211
+ """
212
+ separator = config["separator"]
213
+ dtype = schema.get(name)
214
+ col = pl.col(name)
215
+ if dtype is not None and isinstance(dtype, pl.List):
216
+ seq_col = col
232
217
  else:
233
- scaler = self.scalers.get(name)
234
- if scaler is None:
235
- logger.warning(
236
- f"Scaler for {name} not fitted, returning original values"
237
- )
238
- result = values
239
- else:
240
- result = scaler.transform(values.reshape(-1, 1)).ravel()
241
- return result
218
+ seq_col = col.cast(pl.Utf8).fill_null("").str.split(separator)
219
+ elem = pl.element().cast(pl.Utf8).str.strip_chars()
220
+ seq_col = seq_col.list.eval(
221
+ pl.when(elem == "").then(None).otherwise(elem)
222
+ ).list.drop_nulls()
223
+ return seq_col
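
For reference, a standalone sketch of the splitting/cleanup expression built by sequence_expr, run eagerly on a toy frame. The snippet is illustrative only and assumes a recent Polars release with str.strip_chars; it is not part of the package.

import polars as pl

df = pl.DataFrame({"hist": ["1,2,3", " 4, ,5 ", None]})
elem = pl.element().cast(pl.Utf8).str.strip_chars()
cleaned = (
    pl.col("hist")
    .cast(pl.Utf8)
    .fill_null("")
    .str.split(",")
    .list.eval(pl.when(elem == "").then(None).otherwise(elem))
    .list.drop_nulls()
)
print(df.select(cleaned))  # [["1","2","3"], ["4","5"], []]
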
224
+
225
+ def apply_transforms(self, lf, schema: Dict[str, Any], warn_missing: bool):
226
+ """
227
+ Apply all transformations to a Polars LazyFrame.
242
228
 
243
- def process_sparse_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
229
+ """
244
230
  logger = logging.getLogger()
231
+ expressions = []
245
232
 
246
- encode_method = config["encode_method"]
247
- fill_na = config["fill_na"] # <UNK>
248
- filled_data = data.fillna(fill_na).astype(str)
249
- if encode_method == "label":
250
- min_freq = config.get("min_freq")
251
- if min_freq is not None:
252
- counts = filled_data.value_counts()
253
- config["_token_counts"] = counts.to_dict()
254
- vocab = sorted(counts[counts >= min_freq].index.tolist())
255
- low_freq_types = int((counts < min_freq).sum())
256
- total_types = int(counts.size)
257
- kept_types = total_types - low_freq_types
258
- if not config.get("_min_freq_logged"):
259
- logger.info(
260
- f"Sparse feature {data.name} min_freq={min_freq}: "
261
- f"{total_types} token types total, "
262
- f"{low_freq_types} low-frequency, "
263
- f"{kept_types} kept."
264
- )
265
- config["_min_freq_logged"] = True
266
- else:
267
- vocab = sorted(set(filled_data.tolist()))
268
- if "<UNK>" not in vocab:
269
- vocab.append("<UNK>")
270
- token_to_idx = {token: idx for idx, token in enumerate(vocab)}
271
- config["_token_to_idx"] = token_to_idx
272
- config["_unk_index"] = token_to_idx["<UNK>"]
273
- config["vocab_size"] = len(vocab)
274
- elif encode_method == "hash":
275
- min_freq = config.get("min_freq")
276
- if min_freq is not None:
277
- counts = filled_data.value_counts()
278
- config["_token_counts"] = counts.to_dict()
279
- config["_unk_hash"] = self.hash_string(
280
- "<UNK>", int(config["hash_size"])
281
- )
282
- low_freq_types = int((counts < min_freq).sum())
283
- total_types = int(counts.size)
284
- kept_types = total_types - low_freq_types
285
- if not config.get("_min_freq_logged"):
286
- logger.info(
287
- f"Sparse feature {data.name} min_freq={min_freq}: "
288
- f"{total_types} token types total, "
289
- f"{low_freq_types} low-frequency, "
290
- f"{kept_types} kept."
291
- )
292
- config["_min_freq_logged"] = True
293
- config["vocab_size"] = config["hash_size"]
294
-
295
- def process_sparse_feature_transform(
296
- self, data: pd.Series, config: Dict[str, Any]
297
- ) -> np.ndarray:
298
- name = str(data.name)
299
- encode_method = config["encode_method"]
300
- fill_na = config["fill_na"]
301
-
302
- sparse_series = (
303
- data if isinstance(data, pd.Series) else pd.Series(data, name=name)
304
- )
305
- sparse_series = sparse_series.fillna(fill_na).astype(str)
306
- if encode_method == "label":
307
- token_to_idx = config.get("_token_to_idx")
308
- if isinstance(token_to_idx, dict):
309
- unk_index = int(config.get("_unk_index", 0))
310
- return np.fromiter(
311
- (token_to_idx.get(v, unk_index) for v in sparse_series.to_numpy()),
312
- dtype=np.int64,
313
- count=sparse_series.size,
314
- )
315
- raise ValueError(
316
- f"[Data Processor Error] Token index for {name} not fitted"
233
+ def map_with_default(expr, mapping: Dict[str, int], default: int, dtype):
234
+ # Compatible with older polars versions without Expr.map_dict
235
+ return expr.map_elements(
236
+ lambda x: mapping.get(x, default),
237
+ return_dtype=dtype,
317
238
  )
318
239
 
319
- if encode_method == "hash":
320
- hash_size = config["hash_size"]
321
- hash_fn = self.hash_string
322
- min_freq = config.get("min_freq")
323
- token_counts = config.get("_token_counts")
324
- if min_freq is not None and isinstance(token_counts, dict):
325
- unk_hash = config.get("_unk_hash")
326
- if unk_hash is None:
327
- unk_hash = hash_fn("<UNK>", hash_size)
328
- return np.fromiter(
329
- (
330
- (
331
- unk_hash
332
- if min_freq is not None
333
- and isinstance(token_counts, dict)
334
- and token_counts.get(v, 0) < min_freq
335
- else hash_fn(v, hash_size)
240
+ def ensure_present(feature_name: str, label: str) -> bool:
241
+ if feature_name not in schema:
242
+ if warn_missing:
243
+ logger.warning(f"{label} feature {feature_name} not found in data")
244
+ return False
245
+ return True
246
+
247
+ # Numeric features
248
+ for name, config in self.numeric_features.items():
249
+ if not ensure_present(name, "Numeric"):
250
+ continue
251
+ scaler_type = config["scaler"]
252
+ fill_na_value = config.get("fill_na_value", 0)
253
+ col = pl.col(name).cast(pl.Float64).fill_null(fill_na_value)
254
+ if scaler_type == "log":
255
+ col = col.clip(lower_bound=0).log1p()
256
+ elif scaler_type == "none":
257
+ pass
258
+ else:
259
+ scaler = self.scalers.get(name)
260
+ if scaler is None:
261
+ logger.warning(
262
+ f"Scaler for {name} not fitted, returning original values"
336
263
  )
337
- for v in sparse_series.to_numpy()
338
- ),
339
- dtype=np.int64,
340
- count=sparse_series.size,
341
- )
342
- return np.array([], dtype=np.int64)
264
+ else:
265
+ if scaler_type == "standard":
266
+ mean = float(scaler.mean_[0])
267
+ scale = (
268
+ float(scaler.scale_[0]) if scaler.scale_[0] != 0 else 1.0
269
+ )
270
+ col = (col - mean) / scale
271
+ elif scaler_type == "minmax":
272
+ scale = float(scaler.scale_[0])
273
+ min_val = float(scaler.min_[0])
274
+ col = col * scale + min_val
275
+ elif scaler_type == "maxabs":
276
+ max_abs = float(scaler.max_abs_[0]) or 1.0
277
+ col = col / max_abs
278
+ elif scaler_type == "robust":
279
+ center = float(scaler.center_[0])
280
+ scale = (
281
+ float(scaler.scale_[0]) if scaler.scale_[0] != 0 else 1.0
282
+ )
283
+ col = (col - center) / scale
284
+ expressions.append(col.alias(name))
343
285
 
344
- def process_sequence_feature_fit(self, data: pd.Series, config: Dict[str, Any]):
345
- logger = logging.getLogger()
346
- _ = str(data.name)
347
- encode_method = config["encode_method"]
348
- separator = config["separator"]
349
- if encode_method == "label":
350
- min_freq = config.get("min_freq")
351
- token_counts: Dict[str, int] = {}
352
- for seq in data:
353
- tokens = self.extract_sequence_tokens(seq, separator)
354
- for token in tokens:
355
- if str(token).strip():
356
- key = str(token)
357
- token_counts[key] = token_counts.get(key, 0) + 1
358
- if min_freq is not None:
359
- config["_token_counts"] = token_counts
360
- vocab = sorted([k for k, v in token_counts.items() if v >= min_freq])
361
- low_freq_types = sum(
362
- 1 for count in token_counts.values() if count < min_freq
363
- )
364
- total_types = len(token_counts)
365
- kept_types = total_types - low_freq_types
366
- if not config.get("_min_freq_logged"):
367
- logger.info(
368
- f"Sequence feature {data.name} min_freq={min_freq}: "
369
- f"{total_types} token types total, "
370
- f"{low_freq_types} low-frequency, "
371
- f"{kept_types} kept."
286
+ # Sparse features
287
+ for name, config in self.sparse_features.items():
288
+ if not ensure_present(name, "Sparse"):
289
+ continue
290
+ encode_method = config["encode_method"]
291
+ fill_na = config["fill_na"]
292
+ col = pl.col(name).cast(pl.Utf8).fill_null(fill_na)
293
+ if encode_method == "label":
294
+ token_to_idx = config.get("_token_to_idx")
295
+ if not isinstance(token_to_idx, dict):
296
+ raise ValueError(
297
+ f"[Data Processor Error] Token index for {name} not fitted"
372
298
  )
373
- config["_min_freq_logged"] = True
374
- else:
375
- vocab = sorted(token_counts.keys())
376
- if not vocab:
377
- vocab = ["<PAD>"]
378
- if "<UNK>" not in vocab:
379
- vocab.append("<UNK>")
380
- token_to_idx = {token: idx for idx, token in enumerate(vocab)}
381
- config["_token_to_idx"] = token_to_idx
382
- config["_unk_index"] = token_to_idx["<UNK>"]
383
- config["vocab_size"] = len(vocab)
384
- elif encode_method == "hash":
385
- min_freq = config.get("min_freq")
386
- if min_freq is not None:
387
- token_counts: Dict[str, int] = {}
388
- for seq in data:
389
- tokens = self.extract_sequence_tokens(seq, separator)
390
- for token in tokens:
391
- if str(token).strip():
392
- token_counts[str(token)] = (
393
- token_counts.get(str(token), 0) + 1
394
- )
395
- config["_token_counts"] = token_counts
396
- config["_unk_hash"] = self.hash_string(
397
- "<UNK>", int(config["hash_size"])
398
- )
399
- low_freq_types = sum(
400
- 1 for count in token_counts.values() if count < min_freq
401
- )
402
- total_types = len(token_counts)
403
- kept_types = total_types - low_freq_types
404
- if not config.get("_min_freq_logged"):
405
- logger.info(
406
- f"Sequence feature {data.name} min_freq={min_freq}: "
407
- f"{total_types} token types total, "
408
- f"{low_freq_types} low-frequency, "
409
- f"{kept_types} kept."
299
+ unk_index = int(config.get("_unk_index", 0))
300
+ col = map_with_default(col, token_to_idx, unk_index, pl.Int64)
301
+ elif encode_method == "hash":
302
+ hash_size = config["hash_size"]
303
+ hash_expr = col.hash().cast(pl.UInt64) % int(hash_size)
304
+ min_freq = config.get("min_freq")
305
+ token_counts = config.get("_token_counts")
306
+ if min_freq is not None and isinstance(token_counts, dict):
307
+ low_freq = [k for k, v in token_counts.items() if v < min_freq]
308
+ unk_hash = config.get("_unk_hash")
309
+ if unk_hash is None:
310
+ unk_hash = self.hash_string("<UNK>", int(hash_size))
311
+ hash_expr = (
312
+ pl.when(col.is_in(low_freq))
313
+ .then(int(unk_hash))
314
+ .otherwise(hash_expr)
410
315
  )
411
- config["_min_freq_logged"] = True
412
- config["vocab_size"] = config["hash_size"]
413
-
414
- def process_sequence_feature_transform(
415
- self, data: pd.Series, config: Dict[str, Any]
416
- ) -> np.ndarray:
417
- """Optimized sequence transform with preallocation and cached vocab map."""
418
- name = str(data.name)
419
- encode_method = config["encode_method"]
420
- max_len = config["max_len"]
421
- pad_value = config["pad_value"]
422
- truncate = config["truncate"]
423
- separator = config["separator"]
424
- arr = np.asarray(data, dtype=object)
425
- n = arr.shape[0]
426
- output = np.full((n, max_len), pad_value, dtype=np.int64)
427
- # Shared helpers cached locally for speed and cross-platform consistency
428
- split_fn = str.split
429
- is_nan = np.isnan
430
- if encode_method == "label":
431
- class_to_idx = config.get("_token_to_idx")
432
- if class_to_idx is None:
433
- raise ValueError(
434
- f"[Data Processor Error] Token index for {name} not fitted"
435
- )
436
- unk_index = int(config.get("_unk_index", class_to_idx.get("<UNK>", 0)))
437
- else:
438
- class_to_idx = None # type: ignore
439
- unk_index = 0
440
- hash_fn = self.hash_string
441
- hash_size = config.get("hash_size")
442
- min_freq = config.get("min_freq")
443
- token_counts = config.get("_token_counts")
444
- if min_freq is not None and isinstance(token_counts, dict):
445
- unk_hash = config.get("_unk_hash")
446
- if unk_hash is None and hash_size is not None:
447
- unk_hash = hash_fn("<UNK>", hash_size)
448
- for i, seq in enumerate(arr):
449
- # normalize sequence to a list of strings
450
- tokens = []
451
- if seq is None:
452
- tokens = []
453
- elif isinstance(seq, (float, np.floating)):
454
- tokens = [] if is_nan(seq) else [str(seq)]
455
- elif isinstance(seq, str):
456
- seq_str = seq.strip()
457
- tokens = [] if not seq_str else split_fn(seq_str, separator)
458
- elif isinstance(seq, (list, tuple, np.ndarray)):
459
- tokens = [str(t) for t in seq]
460
- else:
461
- tokens = []
316
+ col = hash_expr.cast(pl.Int64)
317
+ expressions.append(col.alias(name))
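
A minimal, self-contained illustration of the hash-encoding expression used above; bucket values depend on Polars' internal hash, so only the bucketing behaviour is meaningful here.

import polars as pl

hash_size = 1000
df = pl.DataFrame({"city": ["tokyo", "paris", "tokyo"]})
print(
    df.select(
        (pl.col("city").cast(pl.Utf8).hash().cast(pl.UInt64) % hash_size)
        .cast(pl.Int64)
        .alias("city_bucket")
    )
)  # both "tokyo" rows land in the same bucket
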
318
+
319
+ # Sequence features
320
+ for name, config in self.sequence_features.items():
321
+ if not ensure_present(name, "Sequence"):
322
+ continue
323
+ encode_method = config["encode_method"]
324
+ max_len = int(config["max_len"])
325
+ pad_value = int(config["pad_value"])
326
+ truncate = config["truncate"]
327
+ seq_col = self.sequence_expr(pl, name, config, schema)
328
+
462
329
  if encode_method == "label":
463
- encoded = [
464
- class_to_idx.get(token.strip(), unk_index) # type: ignore[union-attr]
465
- for token in tokens
466
- if token is not None and token != ""
467
- ]
330
+ token_to_idx = config.get("_token_to_idx")
331
+ if not isinstance(token_to_idx, dict):
332
+ raise ValueError(
333
+ f"[Data Processor Error] Token index for {name} not fitted"
334
+ )
335
+ unk_index = int(config.get("_unk_index", 0))
336
+ seq_col = seq_col.list.eval(
337
+ map_with_default(pl.element(), token_to_idx, unk_index, pl.Int64)
338
+ )
468
339
  elif encode_method == "hash":
340
+ hash_size = config.get("hash_size")
469
341
  if hash_size is None:
470
342
  raise ValueError(
471
343
  "[Data Processor Error] hash_size must be set for hash encoding"
472
344
  )
473
- encoded = [
474
- (
475
- unk_hash
476
- if min_freq is not None
477
- and isinstance(token_counts, dict)
478
- and token_counts.get(str(token), 0) < min_freq
479
- else hash_fn(str(token), hash_size)
345
+ elem = pl.element().cast(pl.Utf8)
346
+ hash_expr = elem.hash().cast(pl.UInt64) % int(hash_size)
347
+ min_freq = config.get("min_freq")
348
+ token_counts = config.get("_token_counts")
349
+ if min_freq is not None and isinstance(token_counts, dict):
350
+ low_freq = [k for k, v in token_counts.items() if v < min_freq]
351
+ unk_hash = config.get("_unk_hash")
352
+ if unk_hash is None:
353
+ unk_hash = self.hash_string("<UNK>", int(hash_size))
354
+ hash_expr = (
355
+ pl.when(elem.is_in(low_freq))
356
+ .then(int(unk_hash))
357
+ .otherwise(hash_expr)
480
358
  )
481
- for token in tokens
482
- if str(token).strip()
483
- ]
359
+ seq_col = seq_col.list.eval(hash_expr)
360
+
361
+ if truncate == "pre":
362
+ seq_col = seq_col.list.tail(max_len)
484
363
  else:
485
- encoded = []
486
- if not encoded:
364
+ seq_col = seq_col.list.head(max_len)
365
+ pad_list = [pad_value] * max_len
366
+ seq_col = pl.concat_list([seq_col, pl.lit(pad_list)]).list.head(max_len)
367
+ expressions.append(seq_col.alias(name))
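
The truncation rule above mirrors the old pandas path: truncate="pre" keeps the most recent max_len items via list.tail, anything else keeps the head. A toy check (illustrative only):

import polars as pl

max_len = 3
df = pl.DataFrame({"seq": [[1, 2, 3, 4, 5]]})
print(df.select(pl.col("seq").list.tail(max_len)))  # [3, 4, 5]
print(df.select(pl.col("seq").list.head(max_len)))  # [1, 2, 3]
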
368
+
369
+ # Target features
370
+ for name, config in self.target_features.items():
371
+ if not ensure_present(name, "Target"):
487
372
  continue
488
- if len(encoded) > max_len:
489
- encoded = encoded[-max_len:] if truncate == "pre" else encoded[:max_len]
490
- output[i, : len(encoded)] = encoded
491
- return output
373
+ target_type = config.get("target_type")
374
+ col = pl.col(name)
375
+ if target_type == "regression":
376
+ col = col.cast(pl.Float32)
377
+ elif target_type == "binary":
378
+ label_map = self.target_encoders.get(name)
379
+ if label_map is None:
380
+ raise ValueError(
381
+ f"[Data Processor Error] Target encoder for {name} not fitted"
382
+ )
383
+ col = map_with_default(col.cast(pl.Utf8), label_map, 0, pl.Int64).cast(
384
+ pl.Float32
385
+ )
386
+ else:
387
+ raise ValueError(
388
+ f"[Data Processor Error] Unsupported target type: {target_type}"
389
+ )
390
+ expressions.append(col.alias(name))
391
+
392
+ if not expressions:
393
+ return lf
394
+ return lf.with_columns(expressions)
492
395
 
493
- def process_target_fit(self, data: pd.Series, config: Dict[str, Any]):
494
- name = str(data.name)
396
+ def process_target_fit(
397
+ self, data: Iterable[Any], config: Dict[str, Any], name: str
398
+ ) -> None:
495
399
  target_type = config["target_type"]
496
400
  label_map = config.get("label_map")
497
401
  if target_type == "binary":
498
402
  if label_map is None:
499
- unique_values = data.dropna().unique()
500
- sorted_values = sorted(unique_values)
403
+ unique_values = {v for v in data if v is not None}
404
+ # Filter out None values before sorting to avoid comparison errors
405
+ sorted_values = sorted(v for v in unique_values if v is not None)
501
406
  try:
502
407
  int_values = [int(v) for v in sorted_values]
503
408
  if int_values == list(range(len(int_values))):
@@ -511,254 +416,149 @@ class DataProcessor(FeatureSet):
511
416
  config["label_map"] = label_map
512
417
  self.target_encoders[name] = label_map
513
418
 
514
- def process_target_transform(
515
- self, data: pd.Series, config: Dict[str, Any]
516
- ) -> np.ndarray:
419
+ def polars_fit_from_lazy(self, lf, schema: Dict[str, Any]) -> "DataProcessor":
517
420
  logger = logging.getLogger()
518
- name = str(data.name)
519
- target_type = config.get("target_type")
520
- if target_type == "regression":
521
- values = np.array(data.values, dtype=np.float32)
522
- return values
523
- if target_type == "binary":
524
- label_map = self.target_encoders.get(name)
525
- if label_map is None:
526
- raise ValueError(
527
- f"[Data Processor Error] Target encoder for {name} not fitted"
528
- )
529
- result = []
530
- for val in data:
531
- str_val = str(val)
532
- if str_val in label_map:
533
- result.append(label_map[str_val])
534
- else:
535
- logger.warning(f"Unknown target value: {val}, mapping to 0")
536
- result.append(0)
537
- return np.array(result, dtype=np.float32)
538
- raise ValueError(
539
- f"[Data Processor Error] Unsupported target type: {target_type}"
540
- )
541
-
542
- def load_dataframe_from_path(self, path: str) -> pd.DataFrame:
543
- """
544
- Load all data from a file or directory path into a single DataFrame.
545
-
546
- Args:
547
- path (str): File or directory path.
548
-
549
- Returns:
550
- pd.DataFrame: Loaded DataFrame.
551
- """
552
- file_paths, file_type = resolve_file_paths(path)
553
- frames = load_dataframes(file_paths, file_type)
554
- return pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]
555
-
556
- def extract_sequence_tokens(self, value: Any, separator: str) -> list[str]:
557
- """Extract sequence tokens from a single value."""
558
- if value is None:
559
- return []
560
- if isinstance(value, (float, np.floating)) and np.isnan(value):
561
- return []
562
- if isinstance(value, str):
563
- stripped = value.strip()
564
- return [] if not stripped else stripped.split(separator)
565
- if isinstance(value, (list, tuple, np.ndarray)):
566
- return [str(v) for v in value]
567
- return [str(value)]
568
-
569
- def fit_from_file_paths(
570
- self, file_paths: list[str], file_type: str, chunk_size: int
571
- ) -> "DataProcessor":
572
- logger = logging.getLogger()
573
- if not file_paths:
574
- raise ValueError("[DataProcessor Error] Empty file list for streaming fit")
575
- if not check_streaming_support(file_type):
576
- raise ValueError(
577
- f"[DataProcessor Error] Format '{file_type}' does not support streaming. "
578
- "Streaming fit only supports csv, parquet to avoid high memory usage."
579
- )
580
-
581
- numeric_acc = {}
582
- for name in self.numeric_features.keys():
583
- numeric_acc[name] = {
584
- "sum": 0.0,
585
- "sumsq": 0.0,
586
- "count": 0.0,
587
- "min": np.inf,
588
- "max": -np.inf,
589
- "max_abs": 0.0,
590
- }
591
- sparse_vocab: Dict[str, set[str]] = {
592
- name: set() for name in self.sparse_features.keys()
593
- }
594
- seq_vocab: Dict[str, set[str]] = {
595
- name: set() for name in self.sequence_features.keys()
596
- }
597
- sparse_label_counts: Dict[str, Dict[str, int]] = {
598
- name: {}
599
- for name, config in self.sparse_features.items()
600
- if config.get("encode_method") == "label" and config.get("min_freq")
601
- }
602
- seq_label_counts: Dict[str, Dict[str, int]] = {
603
- name: {}
604
- for name, config in self.sequence_features.items()
605
- if config.get("encode_method") == "label" and config.get("min_freq")
606
- }
607
- sparse_hash_counts: Dict[str, Dict[str, int]] = {
608
- name: {}
609
- for name, config in self.sparse_features.items()
610
- if config.get("encode_method") == "hash" and config.get("min_freq")
611
- }
612
- seq_hash_counts: Dict[str, Dict[str, int]] = {
613
- name: {}
614
- for name, config in self.sequence_features.items()
615
- if config.get("encode_method") == "hash" and config.get("min_freq")
616
- }
617
- target_values: Dict[str, set[Any]] = {
618
- name: set() for name in self.target_features.keys()
619
- }
620
421
 
621
422
  missing_features = set()
622
- for file_path in file_paths:
623
- for chunk in iter_file_chunks(file_path, file_type, chunk_size):
624
- columns = set(chunk.columns)
625
- feature_groups = [
626
- ("numeric", self.numeric_features),
627
- ("sparse", self.sparse_features),
628
- ("sequence", self.sequence_features),
629
- ]
630
- for group, features in feature_groups:
631
- missing_features.update(features.keys() - columns)
632
- for name in features.keys() & columns:
633
- config = features[name]
634
- series = chunk[name]
635
- if group == "numeric":
636
- values = pd.to_numeric(series, errors="coerce").dropna()
637
- if values.empty:
638
- continue
639
- acc = numeric_acc[name]
640
- arr = values.to_numpy(dtype=np.float64, copy=False)
641
- acc["count"] += arr.size
642
- acc["sum"] += float(arr.sum())
643
- acc["sumsq"] += float(np.square(arr).sum())
644
- acc["min"] = min(acc["min"], float(arr.min()))
645
- acc["max"] = max(acc["max"], float(arr.max()))
646
- acc["max_abs"] = max(
647
- acc["max_abs"], float(np.abs(arr).max())
648
- )
649
- elif group == "sparse":
650
- fill_na = config["fill_na"]
651
- series = series.fillna(fill_na).astype(str)
652
- sparse_vocab[name].update(series.tolist())
653
- if name in sparse_label_counts:
654
- counts = sparse_label_counts[name]
655
- for token in series.tolist():
656
- counts[token] = counts.get(token, 0) + 1
657
- if name in sparse_hash_counts:
658
- counts = sparse_hash_counts[name]
659
- for token in series.tolist():
660
- counts[token] = counts.get(token, 0) + 1
661
- else:
662
- separator = config["separator"]
663
- tokens = []
664
- for val in series:
665
- tokens.extend(
666
- self.extract_sequence_tokens(val, separator)
667
- )
668
- seq_vocab[name].update(tokens)
669
- if name in seq_label_counts:
670
- counts = seq_label_counts[name]
671
- for token in tokens:
672
- if str(token).strip():
673
- key = str(token)
674
- counts[key] = counts.get(key, 0) + 1
675
- if name in seq_hash_counts:
676
- counts = seq_hash_counts[name]
677
- for token in tokens:
678
- if str(token).strip():
679
- key = str(token)
680
- counts[key] = counts.get(key, 0) + 1
681
-
682
- # target features
683
- missing_features.update(self.target_features.keys() - columns)
684
- for name in self.target_features.keys() & columns:
685
- vals = chunk[name].dropna().tolist()
686
- target_values[name].update(vals)
687
-
423
+ for name in self.numeric_features.keys():
424
+ if name not in schema:
425
+ missing_features.add(name)
426
+ for name in self.sparse_features.keys():
427
+ if name not in schema:
428
+ missing_features.add(name)
429
+ for name in self.sequence_features.keys():
430
+ if name not in schema:
431
+ missing_features.add(name)
432
+ for name in self.target_features.keys():
433
+ if name not in schema:
434
+ missing_features.add(name)
688
435
  if missing_features:
689
436
  logger.warning(
690
- f"The following configured features were not found in provided files: {sorted(missing_features)}"
437
+ f"The following configured features were not found in provided data: {sorted(missing_features)}"
691
438
  )
692
439
 
693
- # finalize numeric scalers
440
+ # numeric aggregates in a single pass
441
+ if self.numeric_features:
442
+ agg_exprs = []
443
+ for name in self.numeric_features.keys():
444
+ if name not in schema:
445
+ continue
446
+ col = pl.col(name).cast(pl.Float64)
447
+ agg_exprs.extend(
448
+ [
449
+ col.sum().alias(f"{name}__sum"),
450
+ (col * col).sum().alias(f"{name}__sumsq"),
451
+ col.count().alias(f"{name}__count"),
452
+ col.min().alias(f"{name}__min"),
453
+ col.max().alias(f"{name}__max"),
454
+ col.abs().max().alias(f"{name}__max_abs"),
455
+ ]
456
+ )
457
+ if self.numeric_features[name].get("scaler") == "robust":
458
+ agg_exprs.extend(
459
+ [
460
+ col.quantile(0.25).alias(f"{name}__q1"),
461
+ col.quantile(0.75).alias(f"{name}__q3"),
462
+ col.median().alias(f"{name}__median"),
463
+ ]
464
+ )
465
+ stats = lf.select(agg_exprs).collect().to_dicts()[0] if agg_exprs else {}
466
+ else:
467
+ stats = {}
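
A hypothetical standalone query in the spirit of the aggregate block above, showing the one-pass statistics it collects (note that count() skips nulls):

import polars as pl

lf = pl.LazyFrame({"price": [1.0, 2.0, None, 4.0]})
col = pl.col("price").cast(pl.Float64)
stats = (
    lf.select(
        col.sum().alias("price__sum"),            # 7.0
        (col * col).sum().alias("price__sumsq"),  # 21.0
        col.count().alias("price__count"),        # 3 (nulls excluded)
        col.min().alias("price__min"),
        col.max().alias("price__max"),
    )
    .collect()
    .to_dicts()[0]
)
print(stats)
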
468
+
694
469
  for name, config in self.numeric_features.items():
695
- acc = numeric_acc[name]
696
- if acc["count"] == 0:
470
+ if name not in schema:
471
+ continue
472
+ count = float(stats.get(f"{name}__count", 0) or 0)
473
+ if count == 0:
697
474
  logger.warning(
698
- f"Numeric feature {name} has no valid values in provided files"
475
+ f"Numeric feature {name} has no valid values in provided data"
699
476
  )
700
477
  continue
701
- mean_val = acc["sum"] / acc["count"]
478
+ sum_val = float(stats.get(f"{name}__sum", 0) or 0)
479
+ sumsq = float(stats.get(f"{name}__sumsq", 0) or 0)
480
+ mean_val = sum_val / count
702
481
  if config["fill_na"] is not None:
703
482
  config["fill_na_value"] = config["fill_na"]
704
483
  else:
705
484
  config["fill_na_value"] = mean_val
706
485
  scaler_type = config["scaler"]
707
486
  if scaler_type == "standard":
708
- var = max(acc["sumsq"] / acc["count"] - mean_val * mean_val, 0.0)
487
+ var = max(sumsq / count - mean_val * mean_val, 0.0)
709
488
  scaler = StandardScaler()
710
489
  scaler.mean_ = np.array([mean_val], dtype=np.float64)
711
490
  scaler.var_ = np.array([var], dtype=np.float64)
712
491
  scaler.scale_ = np.array(
713
492
  [np.sqrt(var) if var > 0 else 1.0], dtype=np.float64
714
493
  )
715
- scaler.n_samples_seen_ = np.array([int(acc["count"])], dtype=np.int64)
494
+ scaler.n_samples_seen_ = np.array([int(count)], dtype=np.int64)
716
495
  self.scalers[name] = scaler
717
-
718
496
  elif scaler_type == "minmax":
719
- data_min = acc["min"] if np.isfinite(acc["min"]) else 0.0
720
- data_max = acc["max"] if np.isfinite(acc["max"]) else data_min
497
+ data_min = float(stats.get(f"{name}__min", 0) or 0)
498
+ data_max = float(stats.get(f"{name}__max", data_min) or data_min)
721
499
  scaler = MinMaxScaler()
722
500
  scaler.data_min_ = np.array([data_min], dtype=np.float64)
723
501
  scaler.data_max_ = np.array([data_max], dtype=np.float64)
724
502
  scaler.data_range_ = scaler.data_max_ - scaler.data_min_
725
503
  scaler.data_range_[scaler.data_range_ == 0] = 1.0
726
- # Manually set scale_/min_ for streaming fit to mirror sklearn's internal fit logic
727
504
  feature_min, feature_max = scaler.feature_range
728
505
  scale = (feature_max - feature_min) / scaler.data_range_
729
506
  scaler.scale_ = scale
730
507
  scaler.min_ = feature_min - scaler.data_min_ * scale
731
- scaler.n_samples_seen_ = np.array([int(acc["count"])], dtype=np.int64)
508
+ scaler.n_samples_seen_ = np.array([int(count)], dtype=np.int64)
732
509
  self.scalers[name] = scaler
733
-
734
510
  elif scaler_type == "maxabs":
511
+ max_abs = float(stats.get(f"{name}__max_abs", 1.0) or 1.0)
735
512
  scaler = MaxAbsScaler()
736
- scaler.max_abs_ = np.array([acc["max_abs"]], dtype=np.float64)
737
- scaler.n_samples_seen_ = np.array([int(acc["count"])], dtype=np.int64)
513
+ scaler.max_abs_ = np.array([max_abs], dtype=np.float64)
514
+ scaler.n_samples_seen_ = np.array([int(count)], dtype=np.int64)
738
515
  self.scalers[name] = scaler
739
-
740
- elif scaler_type in ("log", "none", "robust"):
741
- # log and none do not require fitting; robust requires full data and is handled earlier
516
+ elif scaler_type == "robust":
517
+ q1 = float(stats.get(f"{name}__q1", 0) or 0)
518
+ q3 = float(stats.get(f"{name}__q3", q1) or q1)
519
+ median = float(stats.get(f"{name}__median", 0) or 0)
520
+ scale = q3 - q1
521
+ if scale == 0:
522
+ scale = 1.0
523
+ scaler = RobustScaler()
524
+ scaler.center_ = np.array([median], dtype=np.float64)
525
+ scaler.scale_ = np.array([scale], dtype=np.float64)
526
+ scaler.n_samples_seen_ = np.array([int(count)], dtype=np.int64)
527
+ self.scalers[name] = scaler
528
+ elif scaler_type in ("log", "none"):
742
529
  continue
743
530
  else:
744
531
  raise ValueError(f"Unknown scaler type: {scaler_type}")
745
532
 
746
- # finalize sparse label encoders
533
+ # sparse features
747
534
  for name, config in self.sparse_features.items():
748
- if config["encode_method"] == "label":
535
+ if name not in schema:
536
+ continue
537
+ encode_method = config["encode_method"]
538
+ fill_na = config["fill_na"]
539
+ col = pl.col(name).cast(pl.Utf8).fill_null(fill_na)
540
+ counts_df = (
541
+ lf.select(col.alias(name))
542
+ .group_by(name)
543
+ .agg(pl.len().alias("count"))
544
+ .collect()
545
+ )
546
+ counts = (
547
+ dict(zip(counts_df[name].to_list(), counts_df["count"].to_list()))
548
+ if counts_df.height > 0
549
+ else {}
550
+ )
551
+ if encode_method == "label":
749
552
  min_freq = config.get("min_freq")
750
553
  if min_freq is not None:
751
- token_counts = sparse_label_counts.get(name, {})
752
- config["_token_counts"] = token_counts
554
+ config["_token_counts"] = counts
753
555
  vocab = {
754
- token
755
- for token, count in token_counts.items()
756
- if count >= min_freq
556
+ token for token, count in counts.items() if count >= min_freq
757
557
  }
758
558
  low_freq_types = sum(
759
- 1 for count in token_counts.values() if count < min_freq
559
+ 1 for count in counts.values() if count < min_freq
760
560
  )
761
- total_types = len(token_counts)
561
+ total_types = len(counts)
762
562
  kept_types = total_types - low_freq_types
763
563
  if not config.get("_min_freq_logged"):
764
564
  logger.info(
@@ -769,29 +569,29 @@ class DataProcessor(FeatureSet):
769
569
  )
770
570
  config["_min_freq_logged"] = True
771
571
  else:
772
- vocab = sparse_vocab[name]
572
+ vocab = set(counts.keys())
773
573
  if not vocab:
774
574
  logger.warning(f"Sparse feature {name} has empty vocabulary")
775
575
  continue
776
- vocab_list = sorted(vocab)
576
+ # Filter out None values before sorting to avoid comparison errors
577
+ vocab_list = sorted(v for v in vocab if v is not None)
777
578
  if "<UNK>" not in vocab_list:
778
579
  vocab_list.append("<UNK>")
779
580
  token_to_idx = {token: idx for idx, token in enumerate(vocab_list)}
780
581
  config["_token_to_idx"] = token_to_idx
781
582
  config["_unk_index"] = token_to_idx["<UNK>"]
782
583
  config["vocab_size"] = len(vocab_list)
783
- elif config["encode_method"] == "hash":
584
+ elif encode_method == "hash":
784
585
  min_freq = config.get("min_freq")
785
586
  if min_freq is not None:
786
- token_counts = sparse_hash_counts.get(name, {})
787
- config["_token_counts"] = token_counts
587
+ config["_token_counts"] = counts
788
588
  config["_unk_hash"] = self.hash_string(
789
589
  "<UNK>", int(config["hash_size"])
790
590
  )
791
591
  low_freq_types = sum(
792
- 1 for count in token_counts.values() if count < min_freq
592
+ 1 for count in counts.values() if count < min_freq
793
593
  )
794
- total_types = len(token_counts)
594
+ total_types = len(counts)
795
595
  kept_types = total_types - low_freq_types
796
596
  if not config.get("_min_freq_logged"):
797
597
  logger.info(
@@ -803,22 +603,37 @@ class DataProcessor(FeatureSet):
803
603
  config["_min_freq_logged"] = True
804
604
  config["vocab_size"] = config["hash_size"]
805
605
 
806
- # finalize sequence vocabularies
606
+ # sequence features
807
607
  for name, config in self.sequence_features.items():
808
- if config["encode_method"] == "label":
608
+ if name not in schema:
609
+ continue
610
+ encode_method = config["encode_method"]
611
+ seq_col = self.sequence_expr(pl, name, config, schema)
612
+ tokens_df = (
613
+ lf.select(seq_col.alias("seq"))
614
+ .explode("seq")
615
+ .select(pl.col("seq").cast(pl.Utf8).alias("seq"))
616
+ .drop_nulls("seq")
617
+ .group_by("seq")
618
+ .agg(pl.len().alias("count"))
619
+ .collect()
620
+ )
621
+ counts = (
622
+ dict(zip(tokens_df["seq"].to_list(), tokens_df["count"].to_list()))
623
+ if tokens_df.height > 0
624
+ else {}
625
+ )
626
+ if encode_method == "label":
809
627
  min_freq = config.get("min_freq")
810
628
  if min_freq is not None:
811
- token_counts = seq_label_counts.get(name, {})
812
- config["_token_counts"] = token_counts
629
+ config["_token_counts"] = counts
813
630
  vocab_set = {
814
- token
815
- for token, count in token_counts.items()
816
- if count >= min_freq
631
+ token for token, count in counts.items() if count >= min_freq
817
632
  }
818
633
  low_freq_types = sum(
819
- 1 for count in token_counts.values() if count < min_freq
634
+ 1 for count in counts.values() if count < min_freq
820
635
  )
821
- total_types = len(token_counts)
636
+ total_types = len(counts)
822
637
  kept_types = total_types - low_freq_types
823
638
  if not config.get("_min_freq_logged"):
824
639
  logger.info(
@@ -829,26 +644,30 @@ class DataProcessor(FeatureSet):
829
644
  )
830
645
  config["_min_freq_logged"] = True
831
646
  else:
832
- vocab_set = seq_vocab[name]
833
- vocab_list = sorted(vocab_set) if vocab_set else ["<PAD>"]
647
+ vocab_set = set(counts.keys())
648
+ # Filter out None values before sorting to avoid comparison errors
649
+ vocab_list = (
650
+ sorted(v for v in vocab_set if v is not None)
651
+ if vocab_set
652
+ else ["<PAD>"]
653
+ )
834
654
  if "<UNK>" not in vocab_list:
835
655
  vocab_list.append("<UNK>")
836
656
  token_to_idx = {token: idx for idx, token in enumerate(vocab_list)}
837
657
  config["_token_to_idx"] = token_to_idx
838
658
  config["_unk_index"] = token_to_idx["<UNK>"]
839
659
  config["vocab_size"] = len(vocab_list)
840
- elif config["encode_method"] == "hash":
660
+ elif encode_method == "hash":
841
661
  min_freq = config.get("min_freq")
842
662
  if min_freq is not None:
843
- token_counts = seq_hash_counts.get(name, {})
844
- config["_token_counts"] = token_counts
663
+ config["_token_counts"] = counts
845
664
  config["_unk_hash"] = self.hash_string(
846
665
  "<UNK>", int(config["hash_size"])
847
666
  )
848
667
  low_freq_types = sum(
849
- 1 for count in token_counts.values() if count < min_freq
668
+ 1 for count in counts.values() if count < min_freq
850
669
  )
851
- total_types = len(token_counts)
670
+ total_types = len(counts)
852
671
  kept_types = total_types - low_freq_types
853
672
  if not config.get("_min_freq_logged"):
854
673
  logger.info(
@@ -860,14 +679,18 @@ class DataProcessor(FeatureSet):
860
679
  config["_min_freq_logged"] = True
861
680
  config["vocab_size"] = config["hash_size"]
862
681
 
863
- # finalize targets
682
+ # targets
864
683
  for name, config in self.target_features.items():
865
- if not target_values[name]:
866
- logger.warning(f"Target {name} has no valid values in provided files")
684
+ if name not in schema:
867
685
  continue
868
- self.process_target_fit(
869
- pd.Series(list(target_values[name]), name=name), config
870
- )
686
+ if config.get("target_type") == "binary":
687
+ unique_vals = (
688
+ lf.select(pl.col(name).drop_nulls().unique())
689
+ .collect()
690
+ .to_series()
691
+ .to_list()
692
+ )
693
+ self.process_target_fit(unique_vals, config, name)
871
694
 
872
695
  self.is_fitted = True
873
696
  logger.info(
@@ -879,18 +702,11 @@ class DataProcessor(FeatureSet):
879
702
  )
880
703
  return self
881
704
 
882
- def fit_from_files(
883
- self, file_paths: list[str], file_type: str, chunk_size: int
884
- ) -> "DataProcessor":
885
- """Fit processor statistics by streaming an explicit list of files.
886
-
887
- This is useful when you want to fit statistics on training files only (exclude
888
- validation files) in streaming mode.
889
- """
705
+ def fit_from_files(self, file_paths: list[str], file_type: str) -> "DataProcessor":
890
706
  logger = logging.getLogger()
891
707
  logger.info(
892
708
  colorize(
893
- "Fitting DataProcessor (streaming files mode)...",
709
+ "Fitting DataProcessor...",
894
710
  color="cyan",
895
711
  bold=True,
896
712
  )
@@ -899,34 +715,15 @@ class DataProcessor(FeatureSet):
899
715
  config.pop("_min_freq_logged", None)
900
716
  for config in self.sequence_features.values():
901
717
  config.pop("_min_freq_logged", None)
902
- uses_robust = any(
903
- cfg.get("scaler") == "robust" for cfg in self.numeric_features.values()
904
- )
905
- if uses_robust:
906
- logger.warning(
907
- "Robust scaler requires full data; loading provided files into memory. "
908
- "Consider smaller chunk_size or different scaler if memory is limited."
909
- )
910
- frames = [read_table(p, file_type) for p in file_paths]
911
- df = pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]
912
- return self.fit(df)
913
- return self.fit_from_file_paths(file_paths=file_paths, file_type=file_type, chunk_size=chunk_size)
914
-
915
- def fit_from_path(self, path: str, chunk_size: int) -> "DataProcessor":
916
- """
917
- Fit processor statistics by streaming files to reduce memory usage.
718
+ lf = self.polars_scan(file_paths, file_type)
719
+ schema = lf.collect_schema()
720
+ return self.polars_fit_from_lazy(lf, schema)
918
721
 
919
- Args:
920
- path (str): File or directory path.
921
- chunk_size (int): Number of rows per chunk.
922
-
923
- Returns:
924
- DataProcessor: Fitted DataProcessor instance.
925
- """
722
+ def fit_from_path(self, path: str) -> "DataProcessor":
926
723
  logger = logging.getLogger()
927
724
  logger.info(
928
725
  colorize(
929
- "Fitting DataProcessor (streaming path mode)...",
726
+ "Fitting DataProcessor...",
930
727
  color="cyan",
931
728
  bold=True,
932
729
  )
@@ -936,118 +733,35 @@ class DataProcessor(FeatureSet):
936
733
  for config in self.sequence_features.values():
937
734
  config.pop("_min_freq_logged", None)
938
735
  file_paths, file_type = resolve_file_paths(path)
939
- return self.fit_from_file_paths(
940
- file_paths=file_paths,
941
- file_type=file_type,
942
- chunk_size=chunk_size,
943
- )
944
-
945
- @overload
946
- def transform_in_memory(
947
- self,
948
- data: Union[pd.DataFrame, Dict[str, Any]],
949
- return_dict: Literal[True],
950
- persist: bool,
951
- save_format: Optional[str],
952
- output_path: Optional[str],
953
- warn_missing: bool = True,
954
- ) -> Dict[str, np.ndarray]: ...
736
+ return self.fit_from_files(file_paths=file_paths, file_type=file_type)
955
737
 
956
- @overload
957
738
  def transform_in_memory(
958
739
  self,
959
- data: Union[pd.DataFrame, Dict[str, Any]],
960
- return_dict: Literal[False],
961
- persist: bool,
962
- save_format: Optional[str],
963
- output_path: Optional[str],
964
- warn_missing: bool = True,
965
- ) -> pd.DataFrame: ...
966
-
967
- def transform_in_memory(
968
- self,
969
- data: Union[pd.DataFrame, Dict[str, Any]],
740
+ data: Union[pl.DataFrame, pd.DataFrame, Dict[str, Any]],
970
741
  return_dict: bool,
971
742
  persist: bool,
972
743
  save_format: Optional[str],
973
744
  output_path: Optional[str],
974
745
  warn_missing: bool = True,
975
746
  ):
976
- """
977
- Transform in-memory data and optionally persist the transformed data.
978
-
979
- Args:
980
- data (Union[pd.DataFrame, Dict[str, Any]]): Input data.
981
- return_dict (bool): Whether to return a dictionary of numpy arrays.
982
- persist (bool): Whether to persist the transformed data to disk.
983
- save_format (Optional[str]): Format to save the data if persisting.
984
- output_path (Optional[str]): Output path to save the data if persisting.
985
- warn_missing (bool): Whether to warn about missing features in the data.
986
-
987
- Returns:
988
- Union[pd.DataFrame, Dict[str, np.ndarray]]: Transformed data.
989
- """
990
-
991
747
  logger = logging.getLogger()
992
- data_dict = data if isinstance(data, dict) else None
993
748
 
994
- result_dict = {}
995
- if isinstance(data, pd.DataFrame):
996
- df = data # type: ignore[assignment]
997
- for col in df.columns:
998
- result_dict[col] = df[col].to_numpy(copy=False)
749
+ if isinstance(data, dict):
750
+ df = pl.DataFrame(data)
751
+ elif isinstance(data, pd.DataFrame):
752
+ df = pl.from_pandas(data)
999
753
  else:
1000
- if data_dict is None:
1001
- raise ValueError(
1002
- f"[Data Processor Error] Unsupported data type: {type(data)}"
1003
- )
1004
- for key, value in data_dict.items():
1005
- if isinstance(value, pd.Series):
1006
- result_dict[key] = value.to_numpy(copy=False)
1007
- else:
1008
- result_dict[key] = np.asarray(value)
1009
-
1010
- data_columns = data.columns if isinstance(data, pd.DataFrame) else data_dict
1011
- feature_groups = [
1012
- ("Numeric", self.numeric_features, self.process_numeric_feature_transform),
1013
- ("Sparse", self.sparse_features, self.process_sparse_feature_transform),
1014
- (
1015
- "Sequence",
1016
- self.sequence_features,
1017
- self.process_sequence_feature_transform,
1018
- ),
1019
- ("Target", self.target_features, self.process_target_transform),
1020
- ]
1021
- for label, features, transform_fn in feature_groups:
1022
- for name, config in features.items():
1023
- present = name in data_columns # type: ignore[operator]
1024
- if not present:
1025
- if warn_missing:
1026
- logger.warning(f"{label} feature {name} not found in data")
1027
- continue
1028
- series_data = (
1029
- data[name]
1030
- if isinstance(data, pd.DataFrame)
1031
- else pd.Series(result_dict[name], name=name)
1032
- )
1033
- result_dict[name] = transform_fn(series_data, config)
1034
-
1035
- def dict_to_dataframe(result: Dict[str, np.ndarray]) -> pd.DataFrame:
1036
- # Convert all arrays to Series/lists at once to avoid fragmentation
1037
- columns_dict = {}
1038
- for key, value in result.items():
1039
- if key in self.sequence_features:
1040
- columns_dict[key] = np.asarray(value).tolist()
1041
- else:
1042
- columns_dict[key] = value
1043
- return pd.DataFrame(columns_dict)
754
+ df = data
755
+
756
+ schema = df.schema
757
+ lf = df.lazy()
758
+ lf = self.apply_transforms(lf, schema, warn_missing=warn_missing)
759
+ out_df = lf.collect()
1044
760
 
1045
761
  effective_format = save_format
1046
762
  if persist:
1047
763
  effective_format = save_format or "parquet"
1048
- result_df = None
1049
- if (not return_dict) or persist:
1050
- result_df = dict_to_dataframe(result_dict)
764
+
1051
765
  if persist:
1052
766
  if effective_format not in FILE_FORMAT_CONFIG:
1053
767
  raise ValueError(f"Unsupported save format: {effective_format}")
@@ -1059,68 +773,63 @@ class DataProcessor(FeatureSet):
1059
773
  if output_dir.suffix:
1060
774
  output_dir = output_dir.parent
1061
775
  output_dir.mkdir(parents=True, exist_ok=True)
1062
-
1063
776
  suffix = FILE_FORMAT_CONFIG[effective_format]["extension"][0]
1064
777
  save_path = output_dir / f"transformed_data{suffix}"
1065
- assert result_df is not None, "DataFrame conversion failed"
1066
-
1067
- # Save based on format
1068
778
  if effective_format == "csv":
1069
- result_df.to_csv(save_path, index=False)
779
+ out_df.write_csv(save_path)
1070
780
  elif effective_format == "parquet":
1071
- result_df.to_parquet(save_path, index=False)
781
+ out_df.write_parquet(save_path)
1072
782
  elif effective_format == "feather":
1073
- result_df.to_feather(save_path)
1074
- elif effective_format == "excel":
1075
- result_df.to_excel(save_path, index=False)
1076
- elif effective_format == "hdf5":
1077
- result_df.to_hdf(save_path, key="data", mode="w")
783
+ out_df.write_ipc(save_path)
1078
784
  else:
1079
- raise ValueError(f"Unsupported save format: {effective_format}")
1080
-
785
+ raise ValueError(
786
+ f"Format '{effective_format}' is not supported by the polars-only pipeline."
787
+ )
1081
788
  logger.info(
1082
789
  colorize(
1083
- f"Transformed data saved to: {save_path.resolve()}", color="green"
790
+ f"Transformed data saved to: {save_path.resolve()}",
791
+ color="green",
1084
792
  )
1085
793
  )
794
+
1086
795
  if return_dict:
796
+ result_dict = {}
797
+ for col in out_df.columns:
798
+ series = out_df.get_column(col)
799
+ if col in self.sequence_features:
800
+ result_dict[col] = np.asarray(series.to_list(), dtype=np.int64)
801
+ else:
802
+ result_dict[col] = series.to_numpy()
1087
803
  return result_dict
1088
- assert result_df is not None, "DataFrame is None after transform"
1089
- return result_df
804
+
805
+ return out_df
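
In the return_dict branch above, sequence columns come back from Polars as equal-length padded lists, so they convert cleanly to a dense (n, max_len) int64 matrix. A toy illustration:

import numpy as np
import polars as pl

s = pl.Series("hist", [[1, 2, 0], [3, 0, 0]])
matrix = np.asarray(s.to_list(), dtype=np.int64)
print(matrix.shape)  # (2, 3)
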
1090
806
 
1091
807
  def transform_path(
1092
808
  self,
1093
809
  input_path: str,
1094
810
  output_path: Optional[str],
1095
811
  save_format: Optional[str],
1096
- chunk_size: int = 200000,
1097
812
  ):
1098
- """Transform data from files under a path and save them to a new location.
1099
-
1100
- Uses chunked reading/writing to keep peak memory bounded for large files.
1101
-
1102
- Args:
1103
- input_path (str): Input file or directory path.
1104
- output_path (Optional[str]): Output directory path. If None, defaults to input_path/transformed_data.
1105
- save_format (Optional[str]): Format to save transformed files. If None, uses input file format.
1106
- chunk_size (int): Number of rows per chunk.
1107
- """
813
+ """Transform data from files under a path and save them using polars lazy pipeline."""
1108
814
  logger = logging.getLogger()
1109
815
  file_paths, file_type = resolve_file_paths(input_path)
1110
816
  target_format = save_format or file_type
1111
817
  if target_format not in FILE_FORMAT_CONFIG:
1112
818
  raise ValueError(f"Unsupported format: {target_format}")
1113
- if chunk_size > 0 and not check_streaming_support(file_type):
819
+ if target_format not in {"csv", "parquet", "feather"}:
820
+ raise ValueError(
821
+ f"Format '{target_format}' is not supported by the polars-only pipeline."
822
+ )
823
+ if not check_streaming_support(file_type):
1114
824
  raise ValueError(
1115
825
  f"Input format '{file_type}' does not support streaming reads. "
1116
- "Set chunk_size<=0 to use full-load transform."
826
+ "Polars backend supports csv/parquet only."
1117
827
  )
1118
828
 
1119
- # Warn about streaming support
1120
829
  if not check_streaming_support(target_format):
1121
830
  logger.warning(
1122
831
  f"[Data Processor Warning] Format '{target_format}' does not support streaming writes. "
1123
- "Large files may require more memory. Use csv or parquet for better streaming support."
832
+ "Data will be collected in memory before saving."
1124
833
  )
1125
834
 
1126
835
  base_output_dir = (
@@ -1131,122 +840,48 @@ class DataProcessor(FeatureSet):
1131
840
  output_root = base_output_dir / "transformed_data"
1132
841
  output_root.mkdir(parents=True, exist_ok=True)
1133
842
  saved_paths = []
843
+
1134
844
  for file_path in progress(file_paths, description="Transforming files"):
1135
845
  source_path = Path(file_path)
1136
846
  suffix = FILE_FORMAT_CONFIG[target_format]["extension"][0]
1137
847
  target_file = output_root / f"{source_path.stem}{suffix}"
1138
848
 
1139
- # Stream transform for large files
1140
- if chunk_size <= 0:
1141
- # fallback to full load behavior
1142
- df = read_table(file_path, file_type)
1143
- transformed_df = self.transform_in_memory(
1144
- df,
1145
- return_dict=False,
1146
- persist=False,
1147
- save_format=None,
1148
- output_path=None,
1149
- warn_missing=True,
1150
- )
1151
- assert isinstance(
1152
- transformed_df, pd.DataFrame
1153
- ), "[Data Processor Error] Expected DataFrame when return_dict=False"
1154
-
1155
- # Save based on format
1156
- if target_format == "csv":
1157
- transformed_df.to_csv(target_file, index=False)
1158
- elif target_format == "parquet":
1159
- transformed_df.to_parquet(target_file, index=False)
1160
- elif target_format == "feather":
1161
- transformed_df.to_feather(target_file)
1162
- elif target_format == "excel":
1163
- transformed_df.to_excel(target_file, index=False)
1164
- elif target_format == "hdf5":
1165
- transformed_df.to_hdf(target_file, key="data", mode="w")
1166
- else:
1167
- raise ValueError(f"Unsupported format: {target_format}")
1168
-
1169
- saved_paths.append(str(target_file.resolve()))
1170
- continue
849
+ lf = self.polars_scan([file_path], file_type)
850
+ schema = lf.collect_schema()
851
+ lf = self.apply_transforms(lf, schema, warn_missing=True)
1171
852
 
1172
- first_chunk = True
1173
- # Streaming write for supported formats
1174
853
  if target_format == "parquet":
1175
- parquet_writer = None
1176
- try:
1177
- for chunk in iter_file_chunks(file_path, file_type, chunk_size):
1178
- transformed_df = self.transform_in_memory(
1179
- chunk,
1180
- return_dict=False,
1181
- persist=False,
1182
- save_format=None,
1183
- output_path=None,
1184
- warn_missing=first_chunk,
1185
- )
1186
- assert isinstance(
1187
- transformed_df, pd.DataFrame
1188
- ), "[Data Processor Error] Expected DataFrame when return_dict=False"
1189
- table = pa.Table.from_pandas(
1190
- transformed_df, preserve_index=False
1191
- )
1192
- if parquet_writer is None:
1193
- parquet_writer = pq.ParquetWriter(target_file, table.schema)
1194
- parquet_writer.write_table(table)
1195
- first_chunk = False
1196
- finally:
1197
- if parquet_writer is not None:
1198
- parquet_writer.close()
854
+ lf.sink_parquet(target_file)
1199
855
  elif target_format == "csv":
1200
- # CSV: append chunks; header only once
1201
- target_file.parent.mkdir(parents=True, exist_ok=True)
1202
- with open(target_file, "w", encoding="utf-8", newline="") as f:
1203
- f.write("")
1204
- for chunk in iter_file_chunks(file_path, file_type, chunk_size):
1205
- transformed_df = self.transform_in_memory(
1206
- chunk,
1207
- return_dict=False,
1208
- persist=False,
1209
- save_format=None,
1210
- output_path=None,
1211
- warn_missing=first_chunk,
1212
- )
1213
- assert isinstance(
1214
- transformed_df, pd.DataFrame
1215
- ), "[Data Processor Error] Expected DataFrame when return_dict=False"
1216
- transformed_df.to_csv(
1217
- target_file, index=False, mode="a", header=first_chunk
1218
- )
1219
- first_chunk = False
856
+ # CSV doesn't support nested data (lists), so convert list columns to string
857
+ transformed_schema = lf.collect_schema()
858
+ list_cols = [
859
+ name
860
+ for name, dtype in transformed_schema.items()
861
+ if isinstance(dtype, pl.List)
862
+ ]
863
+ if list_cols:
864
+ # Convert list columns to string representation for CSV
865
+ # Format as [1, 2, 3] by casting elements to string, joining with ", ", and adding brackets
866
+ list_exprs = []
867
+ for name in list_cols:
868
+ # Build the bracketed string expression for this column
869
+ list_exprs.append(
870
+ (
871
+ pl.lit("[")
872
+ + pl.col(name)
873
+ .list.eval(pl.element().cast(pl.String))
874
+ .list.join(", ")
875
+ + pl.lit("]")
876
+ ).alias(name)
877
+ )
878
+ lf = lf.with_columns(list_exprs)
879
+ lf.sink_csv(target_file)
1220
880
  else:
1221
- # Non-streaming formats: collect all chunks and save once
1222
- logger.warning(
1223
- f"Format '{target_format}' doesn't support streaming writes. "
1224
- f"Collecting all chunks in memory before saving."
1225
- )
1226
- all_chunks = []
1227
- for chunk in iter_file_chunks(file_path, file_type, chunk_size):
1228
- transformed_df = self.transform_in_memory(
1229
- chunk,
1230
- return_dict=False,
1231
- persist=False,
1232
- save_format=None,
1233
- output_path=None,
1234
- warn_missing=first_chunk,
1235
- )
1236
- assert isinstance(transformed_df, pd.DataFrame)
1237
- all_chunks.append(transformed_df)
1238
- first_chunk = False
1239
-
1240
- if all_chunks:
1241
- combined_df = pd.concat(all_chunks, ignore_index=True)
1242
- if target_format == "feather":
1243
- combined_df.to_feather(target_file)
1244
- elif target_format == "excel":
1245
- combined_df.to_excel(target_file, index=False)
1246
- elif target_format == "hdf5":
1247
- combined_df.to_hdf(target_file, key="data", mode="w")
1248
-
881
+ df = lf.collect()
882
+ df.write_ipc(target_file)
1249
883
  saved_paths.append(str(target_file.resolve()))
884
+
1250
885
  logger.info(
1251
886
  colorize(
1252
887
  f"Transformed {len(saved_paths)} file(s) saved to: {output_root.resolve()}",
@@ -1258,74 +893,51 @@ class DataProcessor(FeatureSet):
1258
893
  # fit only registers statistics from the data so that it can be transformed later
1259
894
  def fit(
1260
895
  self,
1261
- data: Union[pd.DataFrame, Dict[str, Any], str, os.PathLike],
1262
- chunk_size: int = 200000,
896
+ data: Union[pl.DataFrame, pd.DataFrame, Dict[str, Any], str, os.PathLike],
1263
897
  ):
1264
898
  """
1265
899
  Fit the DataProcessor to the provided data.
1266
900
 
1267
901
  Args:
1268
- data (Union[pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data for fitting.
1269
- chunk_size (int): Number of rows per chunk when streaming from path.
902
+ data (Union[pl.DataFrame, pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data for fitting.
1270
903
 
1271
904
  Returns:
1272
905
  DataProcessor: Fitted DataProcessor instance.
1273
906
  """
1274
907
 
1275
- logger = logging.getLogger()
1276
908
  for config in self.sparse_features.values():
1277
909
  config.pop("_min_freq_logged", None)
1278
910
  for config in self.sequence_features.values():
1279
911
  config.pop("_min_freq_logged", None)
1280
912
  if isinstance(data, (str, os.PathLike)):
1281
- path_str = str(data)
1282
- uses_robust = any(
1283
- cfg.get("scaler") == "robust" for cfg in self.numeric_features.values()
1284
- )
1285
- if uses_robust:
1286
- logger.warning(
1287
- "Robust scaler requires full data; loading all files into memory. Consider smaller chunk_size or different scaler if memory is limited."
1288
- )
1289
- data = self.load_dataframe_from_path(path_str)
1290
- else:
1291
- return self.fit_from_path(path_str, chunk_size)
913
+ return self.fit_from_path(str(data))
1292
914
  if isinstance(data, dict):
1293
- data = pd.DataFrame(data)
1294
- logger.info(colorize("Fitting DataProcessor...", color="cyan", bold=True))
1295
- feature_groups = [
1296
- ("Numeric", self.numeric_features, self.process_numeric_feature_fit),
1297
- ("Sparse", self.sparse_features, self.process_sparse_feature_fit),
1298
- ("Sequence", self.sequence_features, self.process_sequence_feature_fit),
1299
- ("Target", self.target_features, self.process_target_fit),
1300
- ]
1301
- for label, features, fit_fn in feature_groups:
1302
- for name, config in features.items():
1303
- if name not in data.columns:
1304
- logger.warning(f"{label} feature {name} not found in data")
1305
- continue
1306
- fit_fn(data[name], config)
1307
- self.is_fitted = True
1308
- return self
915
+ df = pl.DataFrame(data)
916
+ elif isinstance(data, pd.DataFrame):
917
+ df = pl.from_pandas(data)
918
+ else:
919
+ df = data
920
+ lf = df.lazy()
921
+ schema = df.schema
922
+ return self.polars_fit_from_lazy(lf, schema)
1309
923
 
1310
924
  @overload
1311
925
  def transform(
1312
926
  self,
1313
- data: Union[pd.DataFrame, Dict[str, Any]],
927
+ data: Union[pl.DataFrame, pd.DataFrame, Dict[str, Any]],
1314
928
  return_dict: Literal[True] = True,
1315
929
  save_format: Optional[str] = None,
1316
930
  output_path: Optional[str] = None,
1317
- chunk_size: int = 200000,
1318
931
  ) -> Dict[str, np.ndarray]: ...
1319
932
 
1320
933
  @overload
1321
934
  def transform(
1322
935
  self,
1323
- data: Union[pd.DataFrame, Dict[str, Any]],
936
+ data: Union[pl.DataFrame, pd.DataFrame, Dict[str, Any]],
1324
937
  return_dict: Literal[False] = False,
1325
938
  save_format: Optional[str] = None,
1326
939
  output_path: Optional[str] = None,
1327
- chunk_size: int = 200000,
1328
- ) -> pd.DataFrame: ...
940
+ ) -> pl.DataFrame: ...
1329
941
 
1330
942
  @overload
1331
943
  def transform(
@@ -1334,28 +946,25 @@ class DataProcessor(FeatureSet):
1334
946
  return_dict: Literal[False] = False,
1335
947
  save_format: Optional[str] = None,
1336
948
  output_path: Optional[str] = None,
1337
- chunk_size: int = 200000,
1338
949
  ) -> list[str]: ...
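The Literal-typed overloads above only affect static typing: a checker narrows the return type of transform from the value passed as return_dict. A minimal standalone sketch of the same pattern, unrelated to the package's API:

from typing import Literal, Union, overload


@overload
def load(as_dict: Literal[True] = True) -> dict[str, int]: ...
@overload
def load(as_dict: Literal[False]) -> list[int]: ...
def load(as_dict: bool = True) -> Union[dict[str, int], list[int]]:
    # Type checkers match the Literal overloads, so load(True) is typed as
    # dict[str, int] and load(False) as list[int]; at runtime only this body runs.
    return {"a": 1} if as_dict else [1]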
1339
950
 
1340
951
  def transform(
1341
952
  self,
1342
- data: Union[pd.DataFrame, Dict[str, Any], str, os.PathLike],
953
+ data: Union[pl.DataFrame, pd.DataFrame, Dict[str, Any], str, os.PathLike],
1343
954
  return_dict: bool = True,
1344
955
  save_format: Optional[str] = None,
1345
956
  output_path: Optional[str] = None,
1346
- chunk_size: int = 200000,
1347
957
  ):
1348
958
  """
1349
959
  Transform the provided data using the fitted DataProcessor.
1350
960
 
1351
961
  Args:
1352
- data (Union[pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data to transform.
962
+ data (Union[pl.DataFrame, pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data to transform.
1353
963
  return_dict (bool): Whether to return a dictionary of numpy arrays.
1354
964
  save_format (Optional[str]): Format to save the data if output_path is provided.
1355
965
  output_path (Optional[str]): Output path to save the transformed data.
1356
- chunk_size (int): Number of rows per chunk when streaming from path.
1357
966
  Returns:
1358
- Union[pd.DataFrame, Dict[str, np.ndarray], List[str]]: Transformed data or list of saved file paths.
967
+ Union[pl.DataFrame, Dict[str, np.ndarray], List[str]]: Transformed data or list of saved file paths.
1359
968
  """
1360
969
 
1361
970
  if not self.is_fitted:
@@ -1367,9 +976,7 @@ class DataProcessor(FeatureSet):
1367
976
  raise ValueError(
1368
977
  "[Data Processor Error] Path transform writes files only; set return_dict=False when passing a path."
1369
978
  )
1370
- return self.transform_path(
1371
- str(data), output_path, save_format, chunk_size=chunk_size
1372
- )
979
+ return self.transform_path(str(data), output_path, save_format)
1373
980
  return self.transform_in_memory(
1374
981
  data=data,
1375
982
  return_dict=return_dict,
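The dispatch above either returns in-memory results or, for path inputs, writes files and returns their paths. A usage sketch, assuming a fitted processor; the paths, output directory, and save_format value are illustrative.

import polars as pl

from nextrec.data.preprocessor import DataProcessor


def transform_examples(processor: DataProcessor, df: pl.DataFrame) -> None:
    # In-memory: dict of numpy arrays keyed by feature/target name (default).
    arrays = processor.transform(df, return_dict=True)

    # In-memory: a single polars DataFrame with the transformed columns.
    frame = processor.transform(df, return_dict=False)

    # Path input writes files and returns their paths; return_dict must be False here,
    # otherwise the guard above raises "[Data Processor Error] Path transform writes files only".
    paths = processor.transform(
        "data/eval.parquet",
        return_dict=False,
        save_format="parquet",
        output_path="outputs/transformed",
    )
    print(len(arrays), frame.height, paths)

fit_transform, shown in the next hunk, simply chains fit and transform with the same return conventions.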
@@ -1380,26 +987,24 @@ class DataProcessor(FeatureSet):
1380
987
 
1381
988
  def fit_transform(
1382
989
  self,
1383
- data: Union[pd.DataFrame, Dict[str, Any], str, os.PathLike],
990
+ data: Union[pl.DataFrame, pd.DataFrame, Dict[str, Any], str, os.PathLike],
1384
991
  return_dict: bool = True,
1385
992
  save_format: Optional[str] = None,
1386
993
  output_path: Optional[str] = None,
1387
- chunk_size: int = 200000,
1388
994
  ):
1389
995
  """
1390
996
  Fit the DataProcessor to the provided data and then transform it.
1391
997
 
1392
998
  Args:
1393
- data (Union[pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data for fitting and transforming.
999
+ data (Union[pl.DataFrame, pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data for fitting and transforming.
1394
1000
  return_dict (bool): Whether to return a dictionary of numpy arrays.
1395
1001
  save_format (Optional[str]): Format to save the data if output_path is provided.
1396
- output_path (Optional[str]): Output path to save the transformed data.
1397
- chunk_size (int): Number of rows per chunk when streaming from path.
1002
+ output_path (Optional[str]): Output path to save the data.
1398
1003
  Returns:
1399
- Union[pd.DataFrame, Dict[str, np.ndarray], List[str]]: Transformed data or list of saved file paths.
1004
+ Union[pl.DataFrame, Dict[str, np.ndarray], List[str]]: Transformed data or list of saved file paths.
1400
1005
  """
1401
1006
 
1402
- self.fit(data, chunk_size=chunk_size)
1007
+ self.fit(data)
1403
1008
  return self.transform(
1404
1009
  data,
1405
1010
  return_dict=return_dict,