nextrec 0.4.22-py3-none-any.whl → 0.4.24-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/layers.py +96 -46
  3. nextrec/basic/metrics.py +128 -114
  4. nextrec/basic/model.py +94 -91
  5. nextrec/basic/summary.py +36 -2
  6. nextrec/data/dataloader.py +2 -0
  7. nextrec/data/preprocessor.py +137 -5
  8. nextrec/loss/listwise.py +19 -6
  9. nextrec/loss/pairwise.py +6 -4
  10. nextrec/loss/pointwise.py +8 -6
  11. nextrec/models/multi_task/aitm.py +0 -0
  12. nextrec/models/multi_task/apg.py +0 -0
  13. nextrec/models/multi_task/cross_stitch.py +0 -0
  14. nextrec/models/multi_task/esmm.py +5 -28
  15. nextrec/models/multi_task/mmoe.py +6 -28
  16. nextrec/models/multi_task/pepnet.py +335 -0
  17. nextrec/models/multi_task/ple.py +21 -40
  18. nextrec/models/multi_task/poso.py +17 -39
  19. nextrec/models/multi_task/share_bottom.py +5 -28
  20. nextrec/models/multi_task/snr_trans.py +0 -0
  21. nextrec/models/ranking/afm.py +3 -27
  22. nextrec/models/ranking/autoint.py +5 -38
  23. nextrec/models/ranking/dcn.py +1 -26
  24. nextrec/models/ranking/dcn_v2.py +6 -34
  25. nextrec/models/ranking/deepfm.py +2 -29
  26. nextrec/models/ranking/dien.py +2 -28
  27. nextrec/models/ranking/din.py +2 -27
  28. nextrec/models/ranking/eulernet.py +3 -30
  29. nextrec/models/ranking/ffm.py +0 -26
  30. nextrec/models/ranking/fibinet.py +8 -32
  31. nextrec/models/ranking/fm.py +0 -29
  32. nextrec/models/ranking/lr.py +0 -30
  33. nextrec/models/ranking/masknet.py +4 -30
  34. nextrec/models/ranking/pnn.py +4 -28
  35. nextrec/models/ranking/widedeep.py +0 -32
  36. nextrec/models/ranking/xdeepfm.py +0 -30
  37. nextrec/models/retrieval/dssm.py +4 -28
  38. nextrec/models/retrieval/dssm_v2.py +4 -28
  39. nextrec/models/retrieval/mind.py +2 -22
  40. nextrec/models/retrieval/sdm.py +4 -24
  41. nextrec/models/retrieval/youtube_dnn.py +4 -25
  42. nextrec/models/sequential/hstu.py +0 -18
  43. nextrec/utils/model.py +91 -4
  44. nextrec/utils/types.py +35 -0
  45. {nextrec-0.4.22.dist-info → nextrec-0.4.24.dist-info}/METADATA +8 -6
  46. nextrec-0.4.24.dist-info/RECORD +86 -0
  47. nextrec-0.4.22.dist-info/RECORD +0 -81
  48. {nextrec-0.4.22.dist-info → nextrec-0.4.24.dist-info}/WHEEL +0 -0
  49. {nextrec-0.4.22.dist-info → nextrec-0.4.24.dist-info}/entry_points.txt +0 -0
  50. {nextrec-0.4.22.dist-info → nextrec-0.4.24.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py CHANGED
@@ -2,13 +2,14 @@
 Base Model & Base Match Model Class
 
 Date: create on 27/10/2025
-Checkpoint: edit on 28/12/2025
+Checkpoint: edit on 30/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """
 
 import getpass
 import logging
 import os
+import sys
 import pickle
 import socket
 from pathlib import Path
@@ -16,6 +17,16 @@ from typing import Any, Literal
 
 import numpy as np
 import pandas as pd
+
+try:
+    import swanlab  # type: ignore
+except ModuleNotFoundError:
+    swanlab = None
+try:
+    import wandb  # type: ignore
+except ModuleNotFoundError:
+    wandb = None
+
 import torch
 import torch.distributed as dist
 import torch.nn as nn
@@ -74,13 +85,19 @@ from nextrec.utils.torch_utils import (
     to_tensor,
 )
 from nextrec.utils.config import safe_value
-from nextrec.utils.model import compute_ranking_loss
+from nextrec.utils.model import (
+    compute_ranking_loss,
+    get_loss_list,
+    resolve_loss_weights,
+    get_training_modes,
+)
 from nextrec.utils.types import (
     LossName,
     OptimizerName,
     SchedulerName,
     TrainingModeName,
     TaskTypeName,
+    MetricsName,
 )
 
 
@@ -90,7 +107,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         raise NotImplementedError
 
     @property
-    def default_task(self) -> str | list[str]:
+    def default_task(self) -> TaskTypeName | list[TaskTypeName]:
        raise NotImplementedError
 
     def __init__(
@@ -139,6 +156,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             world_size: Number of processes (defaults to env WORLD_SIZE).
             local_rank: Local rank for selecting CUDA device (defaults to env LOCAL_RANK).
             ddp_find_unused_parameters: Default False, set it True only when exist unused parameters in ddp model, in most cases should be False.
+
+        Note:
+            Optimizer, scheduler, and loss are configured via compile().
         """
         super(BaseModel, self).__init__()
 
@@ -171,24 +191,12 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             dense_features, sparse_features, sequence_features, target, id_columns
         )
 
-        self.task = self.default_task if task is None else task
+        self.task = task or self.default_task
         self.nums_task = len(self.task) if isinstance(self.task, list) else 1
-        if isinstance(training_mode, list):
-            training_modes = list(training_mode)
-            if len(training_modes) != self.nums_task:
-                raise ValueError(
-                    "[BaseModel-init Error] training_mode list length must match number of tasks."
-                )
-        else:
-            training_modes = [training_mode] * self.nums_task
-        if any(
-            mode not in {"pointwise", "pairwise", "listwise"} for mode in training_modes
-        ):
-            raise ValueError(
-                "[BaseModel-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
-            )
-        self.training_modes = training_modes
-        self.training_mode = training_modes if self.nums_task > 1 else training_modes[0]
+        self.training_modes = get_training_modes(training_mode, self.nums_task)
+        self.training_mode = (
+            self.training_modes if self.nums_task > 1 else self.training_modes[0]
+        )
 
         self.embedding_l1_reg = embedding_l1_reg
         self.dense_l1_reg = dense_l1_reg
@@ -196,8 +204,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         self.dense_l2_reg = dense_l2_reg
         self.regularization_weights = []
         self.embedding_params = []
-        self.loss_weight = None
+
         self.ignore_label = None
+        self.compiled = False
 
         self.max_gradient_norm = 1.0
         self.logger_initialized = False
@@ -431,28 +440,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             "pairwise": "bpr",
             "listwise": "listnet",
         }
-        effective_loss = loss
-        if effective_loss is None:
-            loss_list = [default_losses[mode] for mode in self.training_modes]
-        elif isinstance(effective_loss, list):
-            if not effective_loss:
-                loss_list = [default_losses[mode] for mode in self.training_modes]
-            else:
-                if len(effective_loss) != self.nums_task:
-                    raise ValueError(
-                        f"[BaseModel-compile Error] Number of loss functions ({len(effective_loss)}) must match number of tasks ({self.nums_task})."
-                    )
-                loss_list = list(effective_loss)
-        else:
-            loss_list = [effective_loss] * self.nums_task
-
-        for idx, mode in enumerate(self.training_modes):
-            if isinstance(loss_list[idx], str) and loss_list[idx] in {
-                "bce",
-                "binary_crossentropy",
-            }:
-                if mode in {"pairwise", "listwise"}:
-                    loss_list[idx] = default_losses[mode]
+        loss_list = get_loss_list(
+            loss, self.training_modes, self.nums_task, default_losses
+        )
         self.loss_params = loss_params or {}
         optimizer_params = optimizer_params or {}
         self.optimizer_name = (
@@ -516,30 +506,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 nums_task=self.nums_task, device=self.device, **grad_norm_params
             )
            self.loss_weights = None
-        elif loss_weights is None:
-            self.loss_weights = None
-        elif self.nums_task == 1:
-            if isinstance(loss_weights, (list, tuple)):
-                if len(loss_weights) != 1:
-                    raise ValueError(
-                        "[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup."
-                    )
-                loss_weights = loss_weights[0]
-            self.loss_weights = [float(loss_weights)]  # type: ignore
         else:
-            if isinstance(loss_weights, (int, float)):
-                weights = [float(loss_weights)] * self.nums_task
-            elif isinstance(loss_weights, (list, tuple)):
-                weights = [float(w) for w in loss_weights]
-                if len(weights) != self.nums_task:
-                    raise ValueError(
-                        f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task})."
-                    )
-            else:
-                raise TypeError(
-                    f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}"
-                )
-            self.loss_weights = weights
+            self.loss_weights = resolve_loss_weights(loss_weights, self.nums_task)
+        self.compiled = True
 
     def compute_loss(self, y_pred, y_true):
         if y_true is None:
@@ -602,9 +571,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         for i, (start, end) in enumerate(slices):  # type: ignore
             y_pred_i = y_pred[:, start:end]
             y_true_i = y_true[:, start:end]
-            total_count = y_true_i.shape[0]
-            # valid_count = None
-
             # mask ignored labels
             if self.ignore_label is not None:
                 valid_mask = y_true_i != self.ignore_label
@@ -613,11 +579,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 if not torch.any(valid_mask):
                     task_losses.append(y_pred_i.sum() * 0.0)
                     continue
-                # valid_count = valid_mask.sum().to(dtype=y_true_i.dtype)
                 y_pred_i = y_pred_i[valid_mask]
                 y_true_i = y_true_i[valid_mask]
-            # else:
-            #     valid_count = y_true_i.new_tensor(float(total_count))
 
             mode = self.training_modes[i]
 
@@ -683,6 +646,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             sampler=sampler,
             collate_fn=collate_fn,
             num_workers=num_workers,
+            pin_memory=self.device.type == "cuda",
+            persistent_workers=num_workers > 0,
         )
         return (loader, dataset) if return_dataset else loader
 
@@ -691,7 +656,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         train_data=None,
         valid_data=None,
         metrics: (
-            list[str] | dict[str, list[str]] | None
+            list[MetricsName] | dict[str, list[MetricsName]] | None
        ) = None,  # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
         epochs: int = 1,
         shuffle: bool = True,
@@ -705,6 +670,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         use_tensorboard: bool = True,
         use_wandb: bool = False,
         use_swanlab: bool = False,
+        wandb_api: str | None = None,
+        swanlab_api: str | None = None,
         wandb_kwargs: dict | None = None,
         swanlab_kwargs: dict | None = None,
         auto_ddp_sampler: bool = True,
@@ -734,6 +701,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             use_tensorboard: Enable tensorboard logging.
             use_wandb: Enable Weights & Biases logging.
             use_swanlab: Enable SwanLab logging.
+            wandb_api: W&B API key for non-tty login.
+            swanlab_api: SwanLab API key for non-tty login.
             wandb_kwargs: Optional kwargs for wandb.init(...).
             swanlab_kwargs: Optional kwargs for swanlab.init(...).
             auto_ddp_sampler: Attach DistributedSampler automatically when distributed, set False to when data is already sharded per rank.
@@ -751,6 +720,16 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         )
         self.to(self.device)
 
+        if not self.compiled:
+            self.compile(
+                optimizer="adam",
+                optimizer_params={},
+                scheduler=None,
+                scheduler_params={},
+                loss=None,
+                loss_params={},
+            )
+
         if (
             self.distributed
             and dist.is_available()
@@ -825,6 +804,24 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         }
         training_config: dict = safe_value(training_config)  # type: ignore
 
+        if self.is_main_process:
+            is_tty = sys.stdin.isatty() and sys.stdout.isatty()
+            if not is_tty:
+                if use_wandb and wandb_api:
+                    if wandb is None:
+                        logging.warning(
+                            "[BaseModel-fit] wandb not installed, skip wandb login."
+                        )
+                    else:
+                        wandb.login(key=wandb_api)
+                if use_swanlab and swanlab_api:
+                    if swanlab is None:
+                        logging.warning(
+                            "[BaseModel-fit] swanlab not installed, skip swanlab login."
+                        )
+                    else:
+                        swanlab.login(api_key=swanlab_api)
+
         self.training_logger = (
             TrainingLogger(
                 session=self.session,
@@ -1124,16 +1121,17 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                     train_log_payload, step=epoch + 1, split="train"
                 )
             if valid_loader is not None:
-                self.callbacks.on_validation_begin()
-                val_metrics = self.evaluate(
-                    valid_loader,
-                    user_ids=valid_user_ids if self.needs_user_ids else None,
-                    num_workers=num_workers,
-                )
-                should_log_valid = (epoch + 1) % log_interval == 0 or (
+                should_eval_valid = (epoch + 1) % log_interval == 0 or (
                     epoch + 1
                 ) == epochs
-                if should_log_valid:
+                val_metrics = None
+                if should_eval_valid:
+                    self.callbacks.on_validation_begin()
+                    val_metrics = self.evaluate(
+                        valid_loader,
+                        user_ids=valid_user_ids if self.needs_user_ids else None,
+                        num_workers=num_workers,
+                    )
                     display_metrics_table(
                         epoch=epoch + 1,
                         epochs=epochs,
@@ -1147,23 +1145,24 @@
                         is_main_process=self.is_main_process,
                         colorize=lambda s: colorize(" " + s, color="cyan"),
                     )
-                self.callbacks.on_validation_end()
-                if should_log_valid and val_metrics and self.training_logger:
-                    self.training_logger.log_metrics(
-                        val_metrics, step=epoch + 1, split="valid"
-                    )
+                    self.callbacks.on_validation_end()
+                    if val_metrics and self.training_logger:
+                        self.training_logger.log_metrics(
+                            val_metrics, step=epoch + 1, split="valid"
+                        )
                 if not val_metrics:
-                    if self.is_main_process:
+                    if should_eval_valid and self.is_main_process:
                         logging.info(
                            colorize(
                                 "Warning: No validation metrics computed. Skipping validation for this epoch.",
                                 color="yellow",
                             )
                         )
-                    continue
-                epoch_logs = {**train_log_payload}
-                for k, v in val_metrics.items():
-                    epoch_logs[f"val_{k}"] = v
+                    epoch_logs = {**train_log_payload}
+                else:
+                    epoch_logs = {**train_log_payload}
+                    for k, v in val_metrics.items():
+                        epoch_logs[f"val_{k}"] = v
             else:
                 epoch_logs = {**train_log_payload}
             if self.is_main_process:
@@ -1345,6 +1344,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 target_names=self.target_columns,
                 task_specific_metrics=self.task_specific_metrics,
                 user_ids=combined_user_ids,
+                ignore_label=self.ignore_label,
             )
             return avg_loss, metrics_dict
         return avg_loss
@@ -1392,6 +1392,8 @@
             sampler=valid_sampler,
             collate_fn=collate_fn,
             num_workers=num_workers,
+            pin_memory=self.device.type == "cuda",
+            persistent_workers=num_workers > 0,
         )
         valid_user_ids = None
         if needs_user_ids:
@@ -1537,6 +1539,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             target_names=self.target_columns,
             task_specific_metrics=self.task_specific_metrics,
             user_ids=final_user_ids,
+            ignore_label=self.ignore_label,
         )
         return metrics_dict
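
The changes above mean fit() no longer requires an explicit compile() call: when self.compiled is still False it compiles itself with optimizer="adam" and the per-mode default losses, and the new wandb_api / swanlab_api arguments let non-interactive (non-tty) runs log in to W&B or SwanLab programmatically. A minimal sketch of that flow, assuming `model` is an already-constructed BaseModel subclass and train_df / valid_df are prepared datasets; the learning-rate value is illustrative, not taken from this diff:

# Option 1: explicit compile(), as before.
model.compile(
    optimizer="adam",
    optimizer_params={"lr": 1e-3},  # assumed optimizer kwargs
    loss=None,  # None falls back to per-mode defaults, e.g. "bpr" for pairwise, "listnet" for listwise
)

# Option 2 (new in 0.4.24): call fit() directly; it auto-compiles with the same defaults.
model.fit(
    train_data=train_df,
    valid_data=valid_df,
    metrics=["auc", "logloss"],
    epochs=5,
    use_wandb=True,
    wandb_api="<wandb-api-key>",  # only used for wandb.login() when stdin/stdout are not a tty
)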
nextrec/basic/summary.py CHANGED
@@ -1,5 +1,9 @@
 """
 Summary utilities for BaseModel.
+
+Date: create on 03/12/2025
+Checkpoint: edit on 29/12/2025
+Author: Yang Zhou,zyaztec@gmail.com
 """
 
 from __future__ import annotations
@@ -12,9 +16,39 @@ from torch.utils.data import DataLoader
 
 from nextrec.basic.loggers import colorize, format_kv
 from nextrec.data.data_processing import extract_label_arrays, get_data_length
+from nextrec.utils.types import TaskTypeName
 
 
 class SummarySet:
+    model_name: str
+    dense_features: list[Any]
+    sparse_features: list[Any]
+    sequence_features: list[Any]
+    task: TaskTypeName | list[TaskTypeName]
+    target_columns: list[str]
+    nums_task: int
+    metrics: Any
+    device: Any
+    optimizer_name: str
+    optimizer_params: dict[str, Any]
+    scheduler_name: str | None
+    scheduler_params: dict[str, Any]
+    loss_config: Any
+    loss_weights: Any
+    grad_norm: Any
+    embedding_l1_reg: float
+    embedding_l2_reg: float
+    dense_l1_reg: float
+    dense_l2_reg: float
+    early_stop_patience: int
+    max_gradient_norm: float | None
+    metrics_sample_limit: int | None
+    session_id: str | None
+    features_config_path: str
+    checkpoint_path: str
+    train_data_summary: dict[str, Any] | None
+    valid_data_summary: dict[str, Any] | None
+
     def build_data_summary(
         self, data: Any, data_loader: DataLoader | None, sample_key: str
     ):
@@ -305,7 +339,7 @@ class SummarySet:
                 lines = details.get("lines", [])
                 logger.info(f"{target_name}:")
                 for label, value in lines:
-                    logger.info(format_kv(label, value))
+                    logger.info(f" {format_kv(label, value)}")
 
         if self.valid_data_summary:
             if self.train_data_summary:
@@ -320,4 +354,4 @@
                 lines = details.get("lines", [])
                 logger.info(f"{target_name}:")
                 for label, value in lines:
-                    logger.info(format_kv(label, value))
+                    logger.info(f" {format_kv(label, value)}")
nextrec/data/dataloader.py CHANGED
@@ -282,6 +282,8 @@ class RecDataLoader(FeatureSet):
             sampler=sampler,
             collate_fn=collate_fn,
             num_workers=num_workers,
+            pin_memory=torch.cuda.is_available(),
+            persistent_workers=num_workers > 0,
         )
 
     def create_from_path(
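
The two keyword arguments added here are standard torch.utils.data.DataLoader options: pin_memory allocates page-locked host memory so host-to-GPU copies are faster, and persistent_workers keeps worker processes alive between epochs (PyTorch rejects the flag when num_workers is 0, hence the num_workers > 0 guard). A self-contained sketch with a placeholder dataset, unrelated to nextrec:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(1024, 8), torch.randint(0, 2, (1024,)))  # placeholder data
num_workers = 2
loader = DataLoader(
    dataset,
    batch_size=256,
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available(),  # same condition the diff uses
    persistent_workers=num_workers > 0,    # must stay False when num_workers == 0
)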
nextrec/data/preprocessor.py CHANGED
@@ -2,7 +2,7 @@
 DataProcessor for data preprocessing including numeric, sparse, sequence features and target processing.
 
 Date: create on 13/11/2025
-Checkpoint: edit on 24/12/2025
+Checkpoint: edit on 29/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -79,6 +79,14 @@ class DataProcessor(FeatureSet):
         ] = "standard",
         fill_na: Optional[float] = None,
     ):
+        """Add a numeric feature configuration.
+
+        Args:
+            name (str): Feature name.
+            scaler (Optional[Literal["standard", "minmax", "robust", "maxabs", "log", "none"]], optional): Scaler type. Defaults to "standard".
+            fill_na (Optional[float], optional): Fill value for missing entries. Defaults to None.
+        """
+
         self.numeric_features[name] = {"scaler": scaler, "fill_na": fill_na}
 
     def add_sparse_feature(
@@ -88,6 +96,14 @@ class DataProcessor(FeatureSet):
         hash_size: Optional[int] = None,
         fill_na: str = "<UNK>",
     ):
+        """Add a sparse feature configuration.
+
+        Args:
+            name (str): Feature name.
+            encode_method (Literal["hash", "label"], optional): Encoding method, including "hash encoding" and "label encoding". Defaults to "label".
+            hash_size (Optional[int], optional): Hash size for hash encoding. Required if encode_method is "hash".
+            fill_na (str, optional): Fill value for missing entries. Defaults to "<UNK>".
+        """
         if encode_method == "hash" and hash_size is None:
             raise ValueError(
                 "[Data Processor Error] hash_size must be specified when encode_method='hash'"
@@ -101,7 +117,7 @@
     def add_sequence_feature(
         self,
         name: str,
-        encode_method: Literal["hash", "label"] = "label",
+        encode_method: Literal["hash", "label"] = "hash",
         hash_size: Optional[int] = None,
         max_len: Optional[int] = 50,
         pad_value: int = 0,
@@ -110,6 +126,17 @@
         ] = "pre",  # pre: keep last max_len items, post: keep first max_len items
         separator: str = ",",
     ):
+        """Add a sequence feature configuration.
+
+        Args:
+            name (str): Feature name.
+            encode_method (Literal["hash", "label"], optional): Encoding method, including "hash encoding" and "label encoding". Defaults to "hash".
+            hash_size (Optional[int], optional): Hash size for hash encoding. Required if encode_method is "hash".
+            max_len (Optional[int], optional): Maximum sequence length. Defaults to 50.
+            pad_value (int, optional): Padding value for sequences shorter than max_len. Defaults to 0.
+            truncate (Literal["pre", "post"], optional): Truncation strategy for sequences longer than max_len, including "pre" (keep last max_len items) and "post" (keep first max_len items). Defaults to "pre".
+            separator (str, optional): Separator for string sequences. Defaults to ",".
+        """
         if encode_method == "hash" and hash_size is None:
             raise ValueError(
                 "[Data Processor Error] hash_size must be specified when encode_method='hash'"
@@ -131,6 +158,14 @@
             Dict[str, int]
         ] = None,  # example: {'click': 1, 'no_click': 0}
     ):
+        """Add a target configuration.
+
+        Args:
+            name (str): Target name.
+            target_type (Literal["binary", "regression"], optional): Target type. Defaults to "binary".
+            label_map (Optional[Dict[str, int]], optional): Label mapping for binary targets. Defaults to None.
+        """
+
         self.target_features[name] = {
             "target_type": target_type,
             "label_map": label_map,
@@ -392,7 +427,15 @@
         )
 
     def load_dataframe_from_path(self, path: str) -> pd.DataFrame:
-        """Load all data from a file or directory path into a single DataFrame."""
+        """
+        Load all data from a file or directory path into a single DataFrame.
+
+        Args:
+            path (str): File or directory path.
+
+        Returns:
+            pd.DataFrame: Loaded DataFrame.
+        """
         file_paths, file_type = resolve_file_paths(path)
         frames = load_dataframes(file_paths, file_type)
         return pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]
@@ -411,7 +454,16 @@
         return [str(value)]
 
     def fit_from_path(self, path: str, chunk_size: int) -> "DataProcessor":
-        """Fit processor statistics by streaming files to reduce memory usage."""
+        """
+        Fit processor statistics by streaming files to reduce memory usage.
+
+        Args:
+            path (str): File or directory path.
+            chunk_size (int): Number of rows per chunk.
+
+        Returns:
+            DataProcessor: Fitted DataProcessor instance.
+        """
         logger = logging.getLogger()
         logger.info(
             colorize(
@@ -428,7 +480,7 @@
                 "Use fit(dataframe) with in-memory data or convert the data format."
             )
 
-        numeric_acc: Dict[str, Dict[str, float]] = {}
+        numeric_acc = {}
         for name in self.numeric_features.keys():
             numeric_acc[name] = {
                 "sum": 0.0,
@@ -609,6 +661,21 @@
         output_path: Optional[str],
         warn_missing: bool = True,
     ):
+        """
+        Transform in-memory data and optionally persist the transformed data.
+
+        Args:
+            data (Union[pd.DataFrame, Dict[str, Any]]): Input data.
+            return_dict (bool): Whether to return a dictionary of numpy arrays.
+            persist (bool): Whether to persist the transformed data to disk.
+            save_format (Optional[str]): Format to save the data if persisting.
+            output_path (Optional[str]): Output path to save the data if persisting.
+            warn_missing (bool): Whether to warn about missing features in the data.
+
+        Returns:
+            Union[pd.DataFrame, Dict[str, np.ndarray]]: Transformed data.
+        """
+
         logger = logging.getLogger()
         data_dict = data if isinstance(data, dict) else None
 
@@ -719,6 +786,12 @@
         """Transform data from files under a path and save them to a new location.
 
         Uses chunked reading/writing to keep peak memory bounded for large files.
+
+        Args:
+            input_path (str): Input file or directory path.
+            output_path (Optional[str]): Output directory path. If None, defaults to input_path/transformed_data.
+            save_format (Optional[str]): Format to save transformed files. If None, uses input file format.
+            chunk_size (int): Number of rows per chunk.
         """
         logger = logging.getLogger()
         file_paths, file_type = resolve_file_paths(input_path)
@@ -876,6 +949,17 @@
         data: Union[pd.DataFrame, Dict[str, Any], str, os.PathLike],
         chunk_size: int = 200000,
     ):
+        """
+        Fit the DataProcessor to the provided data.
+
+        Args:
+            data (Union[pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data for fitting.
+            chunk_size (int): Number of rows per chunk when streaming from path.
+
+        Returns:
+            DataProcessor: Fitted DataProcessor instance.
+        """
+
         logger = logging.getLogger()
         if isinstance(data, (str, os.PathLike)):
             path_str = str(data)
@@ -915,6 +999,19 @@
         output_path: Optional[str] = None,
         chunk_size: int = 200000,
     ):
+        """
+        Transform the provided data using the fitted DataProcessor.
+
+        Args:
+            data (Union[pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data to transform.
+            return_dict (bool): Whether to return a dictionary of numpy arrays.
+            save_format (Optional[str]): Format to save the data if output_path is provided.
+            output_path (Optional[str]): Output path to save the transformed data.
+            chunk_size (int): Number of rows per chunk when streaming from path.
+        Returns:
+            Union[pd.DataFrame, Dict[str, np.ndarray], List[str]]: Transformed data or list of saved file paths.
+        """
+
         if not self.is_fitted:
             raise ValueError(
                 "[Data Processor Error] DataProcessor must be fitted before transform"
@@ -943,6 +1040,19 @@
         output_path: Optional[str] = None,
         chunk_size: int = 200000,
     ):
+        """
+        Fit the DataProcessor to the provided data and then transform it.
+
+        Args:
+            data (Union[pd.DataFrame, Dict[str, Any], str, os.PathLike]): Input data for fitting and transforming.
+            return_dict (bool): Whether to return a dictionary of numpy arrays.
+            save_format (Optional[str]): Format to save the data if output_path is provided.
+            output_path (Optional[str]): Output path to save the transformed data.
+            chunk_size (int): Number of rows per chunk when streaming from path.
+        Returns:
+            Union[pd.DataFrame, Dict[str, np.ndarray], List[str]]: Transformed data or list of saved file paths.
+        """
+
         self.fit(data, chunk_size=chunk_size)
         return self.transform(
             data,
@@ -952,6 +1062,12 @@
         )
 
     def save(self, save_path: str | Path):
+        """
+        Save the fitted DataProcessor to a file.
+
+        Args:
+            save_path (str | Path): Path to save the DataProcessor.
+        """
         logger = logging.getLogger()
         assert isinstance(save_path, (str, Path)), "save_path must be a string or Path"
         save_path = Path(save_path)
@@ -983,6 +1099,16 @@
 
     @classmethod
     def load(cls, load_path: str | Path) -> "DataProcessor":
+        """
+        Load a fitted DataProcessor from a file.
+
+        Args:
+            load_path (str | Path): Path to load the DataProcessor from.
+
+        Returns:
+            DataProcessor: Loaded DataProcessor instance.
+        """
+
         logger = logging.getLogger()
         load_path = Path(load_path)
         with open(load_path, "rb") as f:
@@ -1003,6 +1129,12 @@
         return processor
 
     def get_vocab_sizes(self) -> Dict[str, int]:
+        """
+        Get vocabulary sizes for all sparse and sequence features.
+
+        Returns:
+            Dict[str, int]: Mapping of feature names to vocabulary sizes.
+        """
         vocab_sizes = {}
         for name, config in self.sparse_features.items():
             vocab_sizes[name] = config.get("vocab_size", 0)
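
Taken together, the new docstrings describe the DataProcessor workflow: declare features and targets, fit, transform, then save or reload the fitted processor. A hedged usage sketch built only from the method names and parameters shown in this diff; the import path, the no-argument constructor, and the sample column names are assumptions:

import pandas as pd
from nextrec.data.preprocessor import DataProcessor  # assumed import path

df = pd.DataFrame(
    {
        "price": [3.5, 12.0, 7.25],
        "user_id": ["u1", "u2", "u1"],
        "hist_items": ["1,5,9", "2", "7,7,3"],
        "click": [1, 0, 1],
    }
)

processor = DataProcessor()
processor.add_numeric_feature("price", scaler="standard", fill_na=0.0)
processor.add_sparse_feature("user_id", encode_method="label")
# Sequence features now default to encode_method="hash", so hash_size must be given.
processor.add_sequence_feature("hist_items", hash_size=10000, max_len=50, separator=",")
processor.add_target("click", target_type="binary")

transformed = processor.fit_transform(df, return_dict=True)  # dict of numpy arrays
processor.save("processor.pkl")
restored = DataProcessor.load("processor.pkl")
print(restored.get_vocab_sizes())  # feature name -> vocabulary size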