nextrec 0.1.10__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/activation.py +1 -2
  3. nextrec/basic/callback.py +1 -2
  4. nextrec/basic/features.py +39 -8
  5. nextrec/basic/layers.py +1 -2
  6. nextrec/basic/loggers.py +15 -10
  7. nextrec/basic/metrics.py +1 -2
  8. nextrec/basic/model.py +87 -84
  9. nextrec/basic/session.py +150 -0
  10. nextrec/data/__init__.py +13 -2
  11. nextrec/data/data_utils.py +74 -22
  12. nextrec/data/dataloader.py +513 -0
  13. nextrec/data/preprocessor.py +494 -134
  14. nextrec/loss/listwise.py +6 -0
  15. nextrec/loss/loss_utils.py +1 -2
  16. nextrec/loss/match_losses.py +4 -5
  17. nextrec/loss/pairwise.py +6 -0
  18. nextrec/loss/pointwise.py +6 -0
  19. nextrec/models/match/dssm.py +2 -2
  20. nextrec/models/match/dssm_v2.py +2 -2
  21. nextrec/models/match/mind.py +2 -2
  22. nextrec/models/match/sdm.py +2 -2
  23. nextrec/models/match/youtube_dnn.py +2 -2
  24. nextrec/models/multi_task/esmm.py +3 -3
  25. nextrec/models/multi_task/mmoe.py +3 -3
  26. nextrec/models/multi_task/ple.py +3 -3
  27. nextrec/models/multi_task/share_bottom.py +3 -3
  28. nextrec/models/ranking/afm.py +2 -3
  29. nextrec/models/ranking/autoint.py +3 -3
  30. nextrec/models/ranking/dcn.py +3 -3
  31. nextrec/models/ranking/deepfm.py +2 -3
  32. nextrec/models/ranking/dien.py +3 -3
  33. nextrec/models/ranking/din.py +3 -3
  34. nextrec/models/ranking/fibinet.py +3 -3
  35. nextrec/models/ranking/fm.py +3 -3
  36. nextrec/models/ranking/masknet.py +3 -3
  37. nextrec/models/ranking/pnn.py +3 -3
  38. nextrec/models/ranking/widedeep.py +3 -3
  39. nextrec/models/ranking/xdeepfm.py +3 -3
  40. nextrec/utils/__init__.py +4 -8
  41. nextrec/utils/embedding.py +2 -4
  42. nextrec/utils/initializer.py +1 -2
  43. nextrec/utils/optimizer.py +1 -2
  44. {nextrec-0.1.10.dist-info → nextrec-0.2.1.dist-info}/METADATA +4 -5
  45. nextrec-0.2.1.dist-info/RECORD +54 -0
  46. nextrec/basic/dataloader.py +0 -447
  47. nextrec/utils/common.py +0 -14
  48. nextrec-0.1.10.dist-info/RECORD +0 -51
  49. {nextrec-0.1.10.dist-info → nextrec-0.2.1.dist-info}/WHEEL +0 -0
  50. {nextrec-0.1.10.dist-info → nextrec-0.2.1.dist-info}/licenses/LICENSE +0 -0
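
The hunks below are from nextrec/data/preprocessor.py (entry 13 above), which carries most of the 0.2.1 data-pipeline changes: DataProcessor now takes an optional session_id, inherits from FeatureConfig, and fit / transform / fit_transform accept a CSV or Parquet file or directory path in addition to a DataFrame or dict, streaming chunks during fit and writing preprocessed files during transform. The sketch below is a minimal usage example inferred only from the signatures visible in this diff; the session id, file paths, and the elided feature-registration calls are illustrative assumptions, not documented package behavior.

```python
from nextrec.data.preprocessor import DataProcessor

# New in 0.2.1: an optional session_id scopes where logs and artifacts are written.
processor = DataProcessor(session_id="demo-run")  # hypothetical session id

# Feature registration (add_numeric_feature and friends) is elided here;
# those method signatures are unchanged in this diff.

# fit() now also accepts a file or directory of CSV/Parquet files and streams
# them chunk_size rows at a time instead of loading everything into memory.
processor.fit("data/train/", chunk_size=200_000)  # hypothetical path

# Path-mode transform writes "<stem>_preprocessed.<ext>" files and returns the
# list of saved paths. Note return_dict=False is required: the path branch in
# transform() rejects the in-memory options, and return_dict defaults to True.
saved_files = processor.transform(
    "data/train/",                # hypothetical input path
    return_dict=False,
    output_path="preprocessed/",  # hypothetical output dir, resolved under the session dir
)

# save() pickles the fitted state; the destination is resolved through the new
# resolve_save_path helper against the session's processor directory.
processor.save("processor.pkl")  # hypothetical file name
```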
@@ -2,17 +2,17 @@
 DataProcessor for data preprocessing including numeric, sparse, sequence features and target processing.
 
 Date: create on 13/11/2025
-Author:
-    Yang Zhou, zyaztec@gmail.com
+Author: Yang Zhou, zyaztec@gmail.com
 """
-
+from __future__ import annotations
 import os
-import pandas as pd
-import numpy as np
 import pickle
 import hashlib
 import logging
+import numpy as np
+import pandas as pd
 
+from pathlib import Path
 from typing import Dict, Union, Optional, Literal, Any
 from sklearn.preprocessing import (
     StandardScaler,
@@ -22,11 +22,18 @@ from sklearn.preprocessing import (
     LabelEncoder
 )
 
-
 from nextrec.basic.loggers import setup_logger, colorize
+from nextrec.data.data_utils import (
+    resolve_file_paths,
+    iter_file_chunks,
+    read_table,
+    load_dataframes,
+    default_output_dir,
+)
+from nextrec.basic.session import create_session, resolve_save_path
+from nextrec.basic.features import FeatureConfig
 
-
-class DataProcessor:
+class DataProcessor(FeatureConfig):
     """DataProcessor for data preprocessing including numeric, sparse, sequence features and target processing.
 
     Examples:
@@ -47,23 +54,26 @@ class DataProcessor:
     >>> # Get vocabulary sizes for embedding layers
     >>> vocab_sizes = processor.get_vocab_sizes()
     """
-    def __init__(self):
+    def __init__(self, session_id: str | None = None ):
        self.numeric_features: Dict[str, Dict[str, Any]] = {}
        self.sparse_features: Dict[str, Dict[str, Any]] = {}
        self.sequence_features: Dict[str, Dict[str, Any]] = {}
        self.target_features: Dict[str, Dict[str, Any]] = {}
-
+        self.session_id = session_id
+        self.session = create_session(session_id)
+
        self.is_fitted = False
        self._transform_summary_printed = False  # Track if summary has been printed during transform
 
        self.scalers: Dict[str, Any] = {}
        self.label_encoders: Dict[str, LabelEncoder] = {}
        self.target_encoders: Dict[str, Dict[str, int]] = {}
+        self._set_target_config([], [])
 
        # Initialize logger if not already initialized
        self._logger_initialized = False
        if not logging.getLogger().hasHandlers():
-            setup_logger()
+            setup_logger(session_id=self.session_id)
            self._logger_initialized = True
 
    def add_numeric_feature(
@@ -126,6 +136,7 @@ class DataProcessor:
            'target_type': target_type,
            'label_map': label_map
        }
+        self._set_target_config(list(self.target_features.keys()), [])
 
    def _hash_string(self, s: str, hash_size: int) -> int:
        return int(hashlib.md5(str(s).encode()).hexdigest(), 16) % hash_size
@@ -212,30 +223,35 @@ class DataProcessor:
        data: pd.Series,
        config: Dict[str, Any]
    ) -> np.ndarray:
-
+        """Fast path sparse feature transform using cached dict mapping or hashing."""
        name = str(data.name)
        encode_method = config['encode_method']
        fill_na = config['fill_na']
 
-        filled_data = data.fillna(fill_na).astype(str)
-
+        sparse_series = pd.Series(data, name=name).fillna(fill_na).astype(str)
+
        if encode_method == 'label':
            le = self.label_encoders.get(name)
            if le is None:
                raise ValueError(f"LabelEncoder for {name} not fitted")
 
-            result = []
-            for val in filled_data:
-                if val in le.classes_:
-                    encoded = le.transform([val])
-                    result.append(int(encoded[0]))
-                else:
-                    result.append(0)
-            return np.array(result, dtype=np.int64)
-
-        elif encode_method == 'hash':
+            class_to_idx = config.get('_class_to_idx')
+            if class_to_idx is None:
+                class_to_idx = {cls: idx for idx, cls in enumerate(le.classes_)}
+                config['_class_to_idx'] = class_to_idx
+
+            encoded = sparse_series.map(class_to_idx)
+            encoded = encoded.fillna(0).astype(np.int64)
+            return encoded.to_numpy()
+
+        if encode_method == 'hash':
            hash_size = config['hash_size']
-            return np.array([self._hash_string(val, hash_size) for val in filled_data], dtype=np.int64)
+            hash_fn = self._hash_string
+            return np.fromiter(
+                (hash_fn(v, hash_size) for v in sparse_series.to_numpy()),
+                dtype=np.int64,
+                count=sparse_series.size,
+            )
 
        return np.array([], dtype=np.int64)
 
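Worth noting about the hunk above: the per-value `le.transform([val])` loop is replaced by a dictionary built once from `le.classes_`, cached on the feature config, and applied with `Series.map`, with unseen values falling back to 0. A small self-contained illustration of why the two encodings agree (the vocabulary and values are made up):

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder

le = LabelEncoder().fit(["a", "b", "c"])  # stand-in for a fitted sparse-feature encoder
class_to_idx = {cls: idx for idx, cls in enumerate(le.classes_)}

values = pd.Series(["b", "a", "zzz", "c"])  # "zzz" is out-of-vocabulary
encoded = values.map(class_to_idx).fillna(0).astype(np.int64)

# Matches le.transform for known values; unknown values map to 0 instead of raising,
# which is the same fallback the old per-value loop used.
assert encoded.tolist() == [1, 0, 0, 2]
print(encoded.to_numpy())
```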
@@ -282,64 +298,78 @@ class DataProcessor:
        data: pd.Series,
        config: Dict[str, Any]
    ) -> np.ndarray:
+        """Optimized sequence transform with preallocation and cached vocab map."""
        name = str(data.name)
        encode_method = config['encode_method']
        max_len = config['max_len']
        pad_value = config['pad_value']
        truncate = config['truncate']
        separator = config['separator']
-
-        result = []
-        for seq in data:
+
+        arr = np.asarray(data, dtype=object)
+        n = arr.shape[0]
+        output = np.full((n, max_len), pad_value, dtype=np.int64)
+
+        # Shared helpers cached locally for speed and cross-platform consistency
+        split_fn = str.split
+        is_nan = np.isnan
+
+        if encode_method == 'label':
+            le = self.label_encoders.get(name)
+            if le is None:
+                raise ValueError(f"LabelEncoder for {name} not fitted")
+            class_to_idx = config.get('_class_to_idx')
+            if class_to_idx is None:
+                class_to_idx = {cls: idx for idx, cls in enumerate(le.classes_)}
+                config['_class_to_idx'] = class_to_idx
+        else:
+            class_to_idx = None  # type: ignore
+
+        hash_fn = self._hash_string
+        hash_size = config.get('hash_size')
+
+        for i, seq in enumerate(arr):
+            # normalize sequence to a list of strings
            tokens = []
-
            if seq is None:
                tokens = []
-            elif isinstance(seq, (float, np.floating)) and np.isnan(seq):
-                tokens = []
+            elif isinstance(seq, (float, np.floating)):
+                tokens = [] if is_nan(seq) else [str(seq)]
            elif isinstance(seq, str):
-                if seq.strip() == '':
-                    tokens = []
-                else:
-                    tokens = seq.split(separator)
-            elif isinstance(seq, (list, tuple)):
+                seq_str = seq.strip()
+                tokens = [] if not seq_str else split_fn(seq_str, separator)
+            elif isinstance(seq, (list, tuple, np.ndarray)):
                tokens = [str(t) for t in seq]
-            elif isinstance(seq, np.ndarray):
-                tokens = [str(t) for t in seq.tolist()]
            else:
                tokens = []
-
+
            if encode_method == 'label':
-                le = self.label_encoders.get(name)
-                if le is None:
-                    raise ValueError(f"LabelEncoder for {name} not fitted")
-
-                encoded = []
-                for token in tokens:
-                    token_str = str(token).strip()
-                    if token_str and token_str in le.classes_:
-                        encoded_val = le.transform([token_str])
-                        encoded.append(int(encoded_val[0]))
-                    else:
-                        encoded.append(0)  # UNK
+                encoded = [
+                    class_to_idx.get(token.strip(), 0)  # type: ignore[union-attr]
+                    for token in tokens
+                    if token is not None and token != ''
+                ]
+
            elif encode_method == 'hash':
-                hash_size = config['hash_size']
-                encoded = [self._hash_string(str(token), hash_size) for token in tokens if str(token).strip()]
+                if hash_size is None:
+                    raise ValueError("hash_size must be set for hash encoding")
+                encoded = [
+                    hash_fn(str(token), hash_size)
+                    for token in tokens
+                    if str(token).strip()
+                ]
            else:
                encoded = []
-
+
+            if not encoded:
+                continue
+
            if len(encoded) > max_len:
-                if truncate == 'pre':  # keep last max_len items
-                    encoded = encoded[-max_len:]
-                else:  # keep first max_len items
-                    encoded = encoded[:max_len]
-            elif len(encoded) < max_len:
-                padding = [pad_value] * (max_len - len(encoded))
-                encoded = encoded + padding
-
-            result.append(encoded)
-
-        return np.array(result, dtype=np.int64)
+                encoded = encoded[-max_len:] if truncate == 'pre' else encoded[:max_len]
+
+            output[i, : len(encoded)] = encoded
+
+        return output
 
    def _process_target_fit(self, data: pd.Series, config: Dict[str, Any]):
        name = str(data.name)
@@ -393,54 +423,212 @@ class DataProcessor:
 
        return np.array(result, dtype=np.int64 if target_type == 'multiclass' else np.float32)
 
-    # fit is nothing but registering the statistics from data so that we can transform the data later
-    def fit(self, data: Union[pd.DataFrame, Dict[str, Any]]):
+    def _load_dataframe_from_path(self, path: str) -> pd.DataFrame:
+        """Load all data from a file or directory path into a single DataFrame."""
+        file_paths, file_type = resolve_file_paths(path)
+        frames = load_dataframes(file_paths, file_type)
+        return pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]
+
+    def _extract_sequence_tokens(self, value: Any, separator: str) -> list[str]:
+        """Extract sequence tokens from a single value."""
+        if value is None:
+            return []
+        if isinstance(value, (float, np.floating)) and np.isnan(value):
+            return []
+        if isinstance(value, str):
+            stripped = value.strip()
+            return [] if not stripped else stripped.split(separator)
+        if isinstance(value, (list, tuple, np.ndarray)):
+            return [str(v) for v in value]
+        return [str(value)]
+
+    def _fit_from_path(self, path: str, chunk_size: int) -> 'DataProcessor':
+        """Fit processor statistics by streaming files to reduce memory usage."""
        logger = logging.getLogger()
-
-        if isinstance(data, dict):
-            data = pd.DataFrame(data)
-
-        logger.info(colorize("Fitting DataProcessor...", color="cyan", bold=True))
+        logger.info(colorize("Fitting DataProcessor (streaming path mode)...", color="cyan", bold=True))
+        file_paths, file_type = resolve_file_paths(path)
+
+        numeric_acc: Dict[str, Dict[str, float]] = {}
+        for name in self.numeric_features.keys():
+            numeric_acc[name] = {
+                "sum": 0.0,
+                "sumsq": 0.0,
+                "count": 0.0,
+                "min": np.inf,
+                "max": -np.inf,
+                "max_abs": 0.0,
+            }
+
+        sparse_vocab: Dict[str, set[str]] = {name: set() for name in self.sparse_features.keys()}
+        seq_vocab: Dict[str, set[str]] = {name: set() for name in self.sequence_features.keys()}
+        target_values: Dict[str, set[Any]] = {name: set() for name in self.target_features.keys()}
 
+        missing_features = set()
+
+        for file_path in file_paths:
+            for chunk in iter_file_chunks(file_path, file_type, chunk_size):
+                # numeric features
+                for name, config in self.numeric_features.items():
+                    if name not in chunk.columns:
+                        missing_features.add(name)
+                        continue
+                    series = chunk[name]
+                    values = pd.to_numeric(series, errors="coerce")
+                    values = values.dropna()
+                    if values.empty:
+                        continue
+                    acc = numeric_acc[name]
+                    arr = values.to_numpy(dtype=np.float64, copy=False)
+                    acc["count"] += arr.size
+                    acc["sum"] += float(arr.sum())
+                    acc["sumsq"] += float(np.square(arr).sum())
+                    acc["min"] = min(acc["min"], float(arr.min()))
+                    acc["max"] = max(acc["max"], float(arr.max()))
+                    acc["max_abs"] = max(acc["max_abs"], float(np.abs(arr).max()))
+
+                # sparse features
+                for name, config in self.sparse_features.items():
+                    if name not in chunk.columns:
+                        missing_features.add(name)
+                        continue
+                    fill_na = config["fill_na"]
+                    series = chunk[name].fillna(fill_na).astype(str)
+                    sparse_vocab[name].update(series.tolist())
+
+                # sequence features
+                for name, config in self.sequence_features.items():
+                    if name not in chunk.columns:
+                        missing_features.add(name)
+                        continue
+                    separator = config["separator"]
+                    series = chunk[name]
+                    tokens = []
+                    for val in series:
+                        tokens.extend(self._extract_sequence_tokens(val, separator))
+                    seq_vocab[name].update(tokens)
+
+                # target features
+                for name in self.target_features.keys():
+                    if name not in chunk.columns:
+                        missing_features.add(name)
+                        continue
+                    vals = chunk[name].dropna().tolist()
+                    target_values[name].update(vals)
+
+        if missing_features:
+            logger.warning(
+                f"The following configured features were not found in provided files: {sorted(missing_features)}"
+            )
+
+        # finalize numeric scalers
        for name, config in self.numeric_features.items():
-            if name not in data.columns:
-                logger.warning(f"Numeric feature {name} not found in data")
+            acc = numeric_acc[name]
+            if acc["count"] == 0:
+                logger.warning(f"Numeric feature {name} has no valid values in provided files")
                continue
-            self._process_numeric_feature_fit(data[name], config)
-
-        for name, config in self.sparse_features.items():
-            if name not in data.columns:
-                logger.warning(f"Sparse feature {name} not found in data")
+
+            mean_val = acc["sum"] / acc["count"]
+            if config["fill_na"] is not None:
+                config["fill_na_value"] = config["fill_na"]
+            else:
+                config["fill_na_value"] = mean_val
+
+            scaler_type = config["scaler"]
+            if scaler_type == "standard":
+                var = max(acc["sumsq"] / acc["count"] - mean_val * mean_val, 0.0)
+                scaler = StandardScaler()
+                scaler.mean_ = np.array([mean_val], dtype=np.float64)
+                scaler.var_ = np.array([var], dtype=np.float64)
+                scaler.scale_ = np.array([np.sqrt(var) if var > 0 else 1.0], dtype=np.float64)
+                scaler.n_samples_seen_ = np.array([int(acc["count"])], dtype=np.int64)
+                self.scalers[name] = scaler
+            elif scaler_type == "minmax":
+                data_min = acc["min"] if np.isfinite(acc["min"]) else 0.0
+                data_max = acc["max"] if np.isfinite(acc["max"]) else data_min
+                scaler = MinMaxScaler()
+                scaler.data_min_ = np.array([data_min], dtype=np.float64)
+                scaler.data_max_ = np.array([data_max], dtype=np.float64)
+                scaler.data_range_ = scaler.data_max_ - scaler.data_min_
+                scaler.data_range_[scaler.data_range_ == 0] = 1.0
+                scaler.n_samples_seen_ = np.array([int(acc["count"])], dtype=np.int64)
+                self.scalers[name] = scaler
+            elif scaler_type == "maxabs":
+                scaler = MaxAbsScaler()
+                scaler.max_abs_ = np.array([acc["max_abs"]], dtype=np.float64)
+                scaler.n_samples_seen_ = np.array([int(acc["count"])], dtype=np.int64)
+                self.scalers[name] = scaler
+            elif scaler_type in ("log", "none", "robust"):
+                # log and none do not require fitting; robust requires full data and is handled earlier
                continue
-            self._process_sparse_feature_fit(data[name], config)
-
+            else:
+                raise ValueError(f"Unknown scaler type: {scaler_type}")
+
+        # finalize sparse label encoders
+        for name, config in self.sparse_features.items():
+            if config["encode_method"] == "label":
+                vocab = sparse_vocab[name]
+                if not vocab:
+                    logger.warning(f"Sparse feature {name} has empty vocabulary")
+                    continue
+                le = LabelEncoder()
+                le.fit(list(vocab))
+                self.label_encoders[name] = le
+                config["vocab_size"] = len(le.classes_)
+            elif config["encode_method"] == "hash":
+                config["vocab_size"] = config["hash_size"]
+
+        # finalize sequence vocabularies
        for name, config in self.sequence_features.items():
-            if name not in data.columns:
-                logger.warning(f"Sequence feature {name} not found in data")
-                continue
-            self._process_sequence_feature_fit(data[name], config)
+            if config["encode_method"] == "label":
+                vocab = seq_vocab[name] or {"<PAD>"}
+                le = LabelEncoder()
+                le.fit(list(vocab))
+                self.label_encoders[name] = le
+                config["vocab_size"] = len(le.classes_)
+            elif config["encode_method"] == "hash":
+                config["vocab_size"] = config["hash_size"]
 
+        # finalize targets
        for name, config in self.target_features.items():
-            if name not in data.columns:
-                logger.warning(f"Target {name} not found in data")
+            if not target_values[name]:
+                logger.warning(f"Target {name} has no valid values in provided files")
                continue
-            self._process_target_fit(data[name], config)
-
+
+            target_type = config["target_type"]
+            if target_type in ["binary", "multiclass"]:
+                unique_values = list(target_values[name])
+                try:
+                    sorted_values = sorted(unique_values)
+                except TypeError:
+                    sorted_values = sorted(unique_values, key=lambda x: str(x))
+
+                label_map = config["label_map"]
+                if label_map is None:
+                    try:
+                        int_values = [int(v) for v in sorted_values]
+                        if int_values == list(range(len(int_values))):
+                            label_map = {str(val): int(val) for val in sorted_values}
+                        else:
+                            label_map = {str(val): idx for idx, val in enumerate(sorted_values)}
+                    except (ValueError, TypeError):
+                        label_map = {str(val): idx for idx, val in enumerate(sorted_values)}
+                    config["label_map"] = label_map
+
+                self.target_encoders[name] = label_map
+
        self.is_fitted = True
-        logger.info(colorize("DataProcessor fitted successfully", color="green", bold=True))
+        logger.info(colorize("DataProcessor fitted successfully (streaming path mode)", color="green", bold=True))
        return self
-
-    def transform(
-        self,
+
+    def _transform_in_memory(
+        self,
        data: Union[pd.DataFrame, Dict[str, Any]],
-        return_dict: bool = True
+        return_dict: bool,
+        persist: bool,
+        save_format: Optional[Literal["csv", "parquet"]],
    ) -> Union[pd.DataFrame, Dict[str, np.ndarray]]:
        logger = logging.getLogger()
-
 
-        if not self.is_fitted:
-            raise ValueError("DataProcessor must be fitted before transform")
-
        # Convert input to dict format for unified processing
        if isinstance(data, pd.DataFrame):
            data_dict = {col: data[col] for col in data.columns}
@@ -494,61 +682,233 @@ class DataProcessor:
                series_data = pd.Series(data_dict[name], name=name)
                processed = self._process_target_transform(series_data, config)
                result_dict[name] = processed
-
-        if return_dict:
-            return result_dict
-        else:
+
+        def _dict_to_dataframe(result: Dict[str, np.ndarray]) -> pd.DataFrame:
            # Convert all arrays to Series/lists at once to avoid fragmentation
            columns_dict = {}
-            for key, value in result_dict.items():
+            for key, value in result.items():
                if key in self.sequence_features:
                    columns_dict[key] = [list(seq) for seq in value]
                else:
                    columns_dict[key] = value
+            return pd.DataFrame(columns_dict)
+
+        assert save_format in [None, "csv", "parquet"], "save_format must be either 'csv', 'parquet', or None"
+        if persist and save_format is None:
+            save_format = "parquet"
+
+        result_df = None
+        if (not return_dict) or (save_format is not None):
+            result_df = _dict_to_dataframe(result_dict)
+            assert result_df is not None, "DataFrame is None after transform"
+
+        if save_format is not None:
+            save_path = resolve_save_path(
+                path=None,
+                default_dir=self.session_dir / "processor" / "preprocessed_data",
+                default_name="data_processed",
+                suffix=f".{save_format}",
+                add_timestamp=True,
+            )
+
+            if save_format == "parquet":
+                result_df.to_parquet(save_path, index=False)
+            else:
+                result_df.to_csv(save_path, index=False)
+
+            logger.info(colorize(
+                f"Transformed data saved to: {save_path}",
+                color="green"
+            ))
+
+        if return_dict:
+            return result_dict
+        return result_df
+
+    def _transform_path(self, path: str, output_path: Optional[str]) -> list[str]:
+        """Transform data from files under a path and save them to a new location."""
+        logger = logging.getLogger()
+
+        file_paths, file_type = resolve_file_paths(path)
+        default_root = self.session_dir / "processor" / default_output_dir(path).name
+        output_root = default_root
+        target_file_override: Optional[Path] = None
+
+        if output_path:
+            output_path_obj = Path(output_path)
+            if not output_path_obj.is_absolute():
+                output_path_obj = self.session_dir / output_path_obj
+            if output_path_obj.suffix.lower() in {".csv", ".parquet"}:
+                if len(file_paths) != 1:
+                    raise ValueError("output_path points to a file but multiple input files were provided.")
+                target_file_override = output_path_obj
+                output_root = output_path_obj.parent
+            else:
+                output_root = output_path_obj
+
+        output_root.mkdir(parents=True, exist_ok=True)
+
+        saved_paths: list[str] = []
+        for file_path in file_paths:
+            df = read_table(file_path, file_type)
+
+            transformed_df = self._transform_in_memory(
+                df,
+                return_dict=False,
+                persist=False,
+                save_format=None,
+            )
+            assert isinstance(transformed_df, pd.DataFrame), "Expected DataFrame when return_dict=False"
+
+            source_path = Path(file_path)
+            target_file = (
+                target_file_override
+                if target_file_override is not None
+                else output_root / f"{source_path.stem}_preprocessed{source_path.suffix}"
+            )
+
+            if file_type == "csv":
+                transformed_df.to_csv(target_file, index=False)
+            else:
+                transformed_df.to_parquet(target_file, index=False)
+
+            saved_paths.append(str(target_file.resolve()))
+
+        logger.info(colorize(
+            f"Transformed {len(saved_paths)} file(s) saved to: {output_root.resolve()}",
+            color="green",
+        ))
+        return saved_paths
+
+    # fit is nothing but registering the statistics from data so that we can transform the data later
+    def fit(
+        self,
+        data: Union[pd.DataFrame, Dict[str, Any], str, os.PathLike],
+        chunk_size: int = 200000,
+    ):
+        logger = logging.getLogger()
+
+        if isinstance(data, (str, os.PathLike)):
+            path_str = str(data)
+            uses_robust = any(cfg.get("scaler") == "robust" for cfg in self.numeric_features.values())
+            if uses_robust:
+                logger.warning(
+                    "Robust scaler requires full data; loading all files into memory. "
+                    "Consider smaller chunk_size or different scaler if memory is limited."
+                )
+                data = self._load_dataframe_from_path(path_str)
+            else:
+                return self._fit_from_path(path_str, chunk_size)
+        if isinstance(data, dict):
+            data = pd.DataFrame(data)
 
-            result_df = pd.DataFrame(columns_dict)
-            return result_df
+        logger.info(colorize("Fitting DataProcessor...", color="cyan", bold=True))
+
+        for name, config in self.numeric_features.items():
+            if name not in data.columns:
+                logger.warning(f"Numeric feature {name} not found in data")
+                continue
+            self._process_numeric_feature_fit(data[name], config)
+
+        for name, config in self.sparse_features.items():
+            if name not in data.columns:
+                logger.warning(f"Sparse feature {name} not found in data")
+                continue
+            self._process_sparse_feature_fit(data[name], config)
+
+        for name, config in self.sequence_features.items():
+            if name not in data.columns:
+                logger.warning(f"Sequence feature {name} not found in data")
+                continue
+            self._process_sequence_feature_fit(data[name], config)
+
+        for name, config in self.target_features.items():
+            if name not in data.columns:
+                logger.warning(f"Target {name} not found in data")
+                continue
+            self._process_target_fit(data[name], config)
+
+        self.is_fitted = True
+        logger.info(colorize("DataProcessor fitted successfully", color="green", bold=True))
+        return self
+
+    def transform(
+        self,
+        data: Union[pd.DataFrame, Dict[str, Any], str, os.PathLike],
+        return_dict: bool = True,
+        persist: bool = False,
+        save_format: Optional[Literal["csv", "parquet"]] = None,
+        output_path: Optional[str] = None,
+    ) -> Union[pd.DataFrame, Dict[str, np.ndarray], list[str]]:
+        logger = logging.getLogger()
+
+        if not self.is_fitted:
+            raise ValueError("DataProcessor must be fitted before transform")
+
+        if isinstance(data, (str, os.PathLike)):
+            if return_dict or persist or save_format is not None:
+                raise ValueError("Path transform writes files only; use output_path and leave return_dict/persist/save_format defaults.")
+            return self._transform_path(str(data), output_path)
+
+        return self._transform_in_memory(
+            data=data,
+            return_dict=return_dict,
+            persist=persist,
+            save_format=save_format,
+        )
 
    def fit_transform(
        self,
-        data: Union[pd.DataFrame, Dict[str, Any]],
-        return_dict: bool = True
-    ) -> Union[pd.DataFrame, Dict[str, np.ndarray]]:
-        self.fit(data)
-        return self.transform(data, return_dict=return_dict)
-
-    def save(self, filepath: str):
+        data: Union[pd.DataFrame, Dict[str, Any], str, os.PathLike],
+        return_dict: bool = True,
+        save_format: Optional[Literal["csv", "parquet"]] = None,
+        output_path: Optional[str] = None,
+        chunk_size: int = 200000,
+    ) -> Union[pd.DataFrame, Dict[str, np.ndarray], list[str]]:
+        self.fit(data, chunk_size=chunk_size)
+        return self.transform(
+            data,
+            return_dict=return_dict,
+            save_format=save_format,
+            output_path=output_path,
+        )
+
+    def save(self, save_path: str):
        logger = logging.getLogger()
-
+
        if not self.is_fitted:
            logger.warning("Saving unfitted DataProcessor")
 
-        dir_path = os.path.dirname(filepath)
-        if dir_path and not os.path.exists(dir_path):
-            os.makedirs(dir_path, exist_ok=True)
-            logger.info(f"Created directory: {dir_path}")
-
+        target_path = resolve_save_path(
+            path=save_path,
+            default_dir=self.session.processor_dir,
+            default_name="processor",
+            suffix=".pkl",
+        )
+
+        # Prepare state dict
        state = {
-            'numeric_features': self.numeric_features,
-            'sparse_features': self.sparse_features,
-            'sequence_features': self.sequence_features,
-            'target_features': self.target_features,
-            'is_fitted': self.is_fitted,
-            'scalers': self.scalers,
-            'label_encoders': self.label_encoders,
-            'target_encoders': self.target_encoders
+            "numeric_features": self.numeric_features,
+            "sparse_features": self.sparse_features,
+            "sequence_features": self.sequence_features,
+            "target_features": self.target_features,
+            "is_fitted": self.is_fitted,
+            "scalers": self.scalers,
+            "label_encoders": self.label_encoders,
+            "target_encoders": self.target_encoders,
        }
-
-        with open(filepath, 'wb') as f:
+
+        # Save with pickle
+        with open(target_path, "wb") as f:
            pickle.dump(state, f)
-
-        logger.info(f"DataProcessor saved to {filepath}")
+
+        logger.info(colorize(f"DataProcessor saved to: {target_path}", color="green"))
 
    @classmethod
-    def load(cls, filepath: str) -> 'DataProcessor':
+    def load(cls, load_path: str) -> 'DataProcessor':
        logger = logging.getLogger()
 
-        with open(filepath, 'rb') as f:
+        with open(load_path, 'rb') as f:
            state = pickle.load(f)
 
        processor = cls()
@@ -561,7 +921,7 @@ class DataProcessor:
        processor.label_encoders = state['label_encoders']
        processor.target_encoders = state['target_encoders']
 
-        logger.info(f"DataProcessor loaded from {filepath}")
+        logger.info(f"DataProcessor loaded from {load_path}")
        return processor
 
    def get_vocab_sizes(self) -> Dict[str, int]:
@@ -659,4 +1019,4 @@ class DataProcessor:
 
        logger.info("")
        logger.info("")
-        logger.info(colorize("=" * 80, color="bright_blue", bold=True))
+        logger.info(colorize("=" * 80, color="bright_blue", bold=True))
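
One more note on the streaming fit added in this release: for the standard scaler, _fit_from_path never calls StandardScaler.fit. It accumulates sum, sumsq, and count per chunk and then fills in the scaler's fitted attributes (mean_, var_, scale_, n_samples_seen_) by hand, using the one-pass identity Var[x] = E[x²] − E[x]², which matches sklearn's population variance (ddof=0). The sketch below is a small self-contained illustration of that idea with made-up data, checked against a scaler fitted on the concatenated chunks.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Two "chunks" standing in for what iter_file_chunks would yield (illustrative data).
chunks = [np.array([1.0, 2.0, 3.0]), np.array([4.0, 5.0])]

# Accumulate the same per-feature statistics _fit_from_path tracks.
acc = {"sum": 0.0, "sumsq": 0.0, "count": 0}
for arr in chunks:
    acc["count"] += arr.size
    acc["sum"] += float(arr.sum())
    acc["sumsq"] += float(np.square(arr).sum())

mean = acc["sum"] / acc["count"]
var = max(acc["sumsq"] / acc["count"] - mean * mean, 0.0)  # population variance, as in the diff

# Populate the fitted attributes by hand instead of calling fit(), as the diff does.
scaler = StandardScaler()
scaler.mean_ = np.array([mean])
scaler.var_ = np.array([var])
scaler.scale_ = np.array([np.sqrt(var) if var > 0 else 1.0])
scaler.n_samples_seen_ = np.array([acc["count"]])

# Same result as fitting on all the data at once.
full = np.concatenate(chunks).reshape(-1, 1)
reference = StandardScaler().fit(full)
assert np.allclose(scaler.mean_, reference.mean_)
assert np.allclose(scaler.scale_, reference.scale_)
print(scaler.transform(full).ravel())
```

sklearn's partial_fit on StandardScaler (and on MinMaxScaler/MaxAbsScaler) offers the same chunk-by-chunk behavior out of the box; the hand-rolled accumulators in the diff presumably exist so that a single pass over each chunk can feed every configured scaler type plus the min/max/max-abs statistics at once.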