nextrec 0.4.31__py3-none-any.whl → 0.4.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/model.py +60 -12
- nextrec/basic/summary.py +2 -1
- nextrec/cli.py +56 -41
- nextrec/data/batch_utils.py +2 -2
- nextrec/data/preprocessor.py +125 -26
- nextrec/models/multi_task/[pre]aitm.py +3 -3
- nextrec/models/multi_task/[pre]snr_trans.py +3 -3
- nextrec/models/multi_task/[pre]star.py +3 -3
- nextrec/models/multi_task/apg.py +3 -3
- nextrec/models/multi_task/cross_stitch.py +3 -3
- nextrec/models/multi_task/escm.py +3 -3
- nextrec/models/multi_task/esmm.py +3 -3
- nextrec/models/multi_task/hmoe.py +3 -3
- nextrec/models/multi_task/mmoe.py +3 -3
- nextrec/models/multi_task/pepnet.py +4 -4
- nextrec/models/multi_task/ple.py +3 -3
- nextrec/models/multi_task/poso.py +3 -3
- nextrec/models/multi_task/share_bottom.py +3 -3
- nextrec/models/ranking/afm.py +3 -2
- nextrec/models/ranking/autoint.py +3 -2
- nextrec/models/ranking/dcn.py +3 -2
- nextrec/models/ranking/dcn_v2.py +3 -2
- nextrec/models/ranking/deepfm.py +3 -2
- nextrec/models/ranking/dien.py +3 -2
- nextrec/models/ranking/din.py +3 -2
- nextrec/models/ranking/eulernet.py +3 -2
- nextrec/models/ranking/ffm.py +3 -2
- nextrec/models/ranking/fibinet.py +3 -2
- nextrec/models/ranking/fm.py +3 -2
- nextrec/models/ranking/lr.py +3 -2
- nextrec/models/ranking/masknet.py +3 -2
- nextrec/models/ranking/pnn.py +3 -2
- nextrec/models/ranking/widedeep.py +3 -2
- nextrec/models/ranking/xdeepfm.py +3 -2
- nextrec/models/tree_base/__init__.py +15 -0
- nextrec/models/tree_base/base.py +693 -0
- nextrec/models/tree_base/catboost.py +97 -0
- nextrec/models/tree_base/lightgbm.py +69 -0
- nextrec/models/tree_base/xgboost.py +61 -0
- nextrec/utils/config.py +1 -0
- nextrec/utils/types.py +2 -0
- {nextrec-0.4.31.dist-info → nextrec-0.4.33.dist-info}/METADATA +5 -5
- {nextrec-0.4.31.dist-info → nextrec-0.4.33.dist-info}/RECORD +47 -42
- {nextrec-0.4.31.dist-info → nextrec-0.4.33.dist-info}/licenses/LICENSE +1 -1
- {nextrec-0.4.31.dist-info → nextrec-0.4.33.dist-info}/WHEEL +0 -0
- {nextrec-0.4.31.dist-info → nextrec-0.4.33.dist-info}/entry_points.txt +0 -0
nextrec/__version__.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = "0.4.
|
|
1
|
+
__version__ = "0.4.33"
|
nextrec/basic/model.py
CHANGED
|
@@ -13,7 +13,7 @@ import sys
|
|
|
13
13
|
import pickle
|
|
14
14
|
import socket
|
|
15
15
|
from pathlib import Path
|
|
16
|
-
from typing import Any, Literal
|
|
16
|
+
from typing import Any, Literal, cast, overload
|
|
17
17
|
|
|
18
18
|
import numpy as np
|
|
19
19
|
import pandas as pd
|
|
@@ -97,6 +97,7 @@ from nextrec.utils.types import (
|
|
|
97
97
|
SchedulerName,
|
|
98
98
|
TrainingModeName,
|
|
99
99
|
TaskTypeName,
|
|
100
|
+
TaskTypeInput,
|
|
100
101
|
MetricsName,
|
|
101
102
|
)
|
|
102
103
|
|
|
@@ -119,7 +120,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
119
120
|
sequence_features: list[SequenceFeature] | None = None,
|
|
120
121
|
target: list[str] | str | None = None,
|
|
121
122
|
id_columns: list[str] | str | None = None,
|
|
122
|
-
task:
|
|
123
|
+
task: TaskTypeInput | list[TaskTypeInput] | None = None,
|
|
123
124
|
training_mode: TrainingModeName | list[TrainingModeName] | None = None,
|
|
124
125
|
embedding_l1_reg: float = 0.0,
|
|
125
126
|
dense_l1_reg: float = 0.0,
|
|
@@ -193,7 +194,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
193
194
|
dense_features, sparse_features, sequence_features, target, id_columns
|
|
194
195
|
)
|
|
195
196
|
|
|
196
|
-
self.task = task or self.default_task
|
|
197
|
+
self.task = cast(TaskTypeName | list[TaskTypeName], task or self.default_task)
|
|
197
198
|
self.nums_task = len(self.task) if isinstance(self.task, list) else 1
|
|
198
199
|
|
|
199
200
|
training_mode = training_mode or "pointwise"
|
|
@@ -932,6 +933,13 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
932
933
|
|
|
933
934
|
existing_callbacks = self.callbacks.callbacks
|
|
934
935
|
|
|
936
|
+
has_validation = valid_data is not None or valid_split is not None
|
|
937
|
+
checkpoint_monitor = monitor_metric
|
|
938
|
+
checkpoint_mode = self.best_metrics_mode
|
|
939
|
+
if not has_validation:
|
|
940
|
+
checkpoint_monitor = "loss"
|
|
941
|
+
checkpoint_mode = "min"
|
|
942
|
+
|
|
935
943
|
if self.early_stop_patience > 0 and not any(
|
|
936
944
|
isinstance(cb, EarlyStopper) for cb in existing_callbacks
|
|
937
945
|
):
|
|
@@ -945,6 +953,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
945
953
|
)
|
|
946
954
|
)
|
|
947
955
|
|
|
956
|
+
has_validation = valid_data is not None or valid_split is not None
|
|
957
|
+
|
|
948
958
|
if self.is_main_process and not any(
|
|
949
959
|
isinstance(cb, CheckpointSaver) for cb in existing_callbacks
|
|
950
960
|
):
|
|
@@ -952,9 +962,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
952
962
|
CheckpointSaver(
|
|
953
963
|
best_path=self.best_path,
|
|
954
964
|
checkpoint_path=self.checkpoint_path,
|
|
955
|
-
monitor=
|
|
956
|
-
mode=
|
|
957
|
-
save_best_only=
|
|
965
|
+
monitor=checkpoint_monitor,
|
|
966
|
+
mode=checkpoint_mode,
|
|
967
|
+
save_best_only=has_validation,
|
|
958
968
|
verbose=1,
|
|
959
969
|
)
|
|
960
970
|
)
|
|
@@ -1245,11 +1255,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
1245
1255
|
epoch_logs[f"val_{k}"] = v
|
|
1246
1256
|
else:
|
|
1247
1257
|
epoch_logs = {**train_log_payload}
|
|
1248
|
-
if self.is_main_process:
|
|
1249
|
-
self.save_model(
|
|
1250
|
-
self.checkpoint_path, add_timestamp=False, verbose=False
|
|
1251
|
-
)
|
|
1252
|
-
self.best_checkpoint_path = self.checkpoint_path
|
|
1253
1258
|
|
|
1254
1259
|
# Call on_epoch_end for all callbacks (handles early stopping, checkpointing, lr scheduling)
|
|
1255
1260
|
self.callbacks.on_epoch_end(epoch, epoch_logs)
|
|
@@ -1623,6 +1628,49 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
1623
1628
|
)
|
|
1624
1629
|
return metrics_dict
|
|
1625
1630
|
|
|
1631
|
+
@overload
|
|
1632
|
+
def predict(
|
|
1633
|
+
self,
|
|
1634
|
+
data: str | dict | pd.DataFrame | DataLoader,
|
|
1635
|
+
batch_size: int = 32,
|
|
1636
|
+
save_path: str | os.PathLike | None = None,
|
|
1637
|
+
save_format: str = "csv",
|
|
1638
|
+
include_ids: bool | None = None,
|
|
1639
|
+
id_columns: str | list[str] | None = None,
|
|
1640
|
+
return_dataframe: Literal[True] = True,
|
|
1641
|
+
stream_chunk_size: int = 10000,
|
|
1642
|
+
num_workers: int = 0,
|
|
1643
|
+
) -> pd.DataFrame: ...
|
|
1644
|
+
|
|
1645
|
+
@overload
|
|
1646
|
+
def predict(
|
|
1647
|
+
self,
|
|
1648
|
+
data: str | dict | pd.DataFrame | DataLoader,
|
|
1649
|
+
batch_size: int = 32,
|
|
1650
|
+
save_path: None = None,
|
|
1651
|
+
save_format: str = "csv",
|
|
1652
|
+
include_ids: bool | None = None,
|
|
1653
|
+
id_columns: str | list[str] | None = None,
|
|
1654
|
+
return_dataframe: Literal[False] = False,
|
|
1655
|
+
stream_chunk_size: int = 10000,
|
|
1656
|
+
num_workers: int = 0,
|
|
1657
|
+
) -> np.ndarray: ...
|
|
1658
|
+
|
|
1659
|
+
@overload
|
|
1660
|
+
def predict(
|
|
1661
|
+
self,
|
|
1662
|
+
data: str | dict | pd.DataFrame | DataLoader,
|
|
1663
|
+
batch_size: int = 32,
|
|
1664
|
+
*,
|
|
1665
|
+
save_path: str | os.PathLike,
|
|
1666
|
+
save_format: str = "csv",
|
|
1667
|
+
include_ids: bool | None = None,
|
|
1668
|
+
id_columns: str | list[str] | None = None,
|
|
1669
|
+
return_dataframe: Literal[False] = False,
|
|
1670
|
+
stream_chunk_size: int = 10000,
|
|
1671
|
+
num_workers: int = 0,
|
|
1672
|
+
) -> Path: ...
|
|
1673
|
+
|
|
1626
1674
|
def predict(
|
|
1627
1675
|
self,
|
|
1628
1676
|
data: str | dict | pd.DataFrame | DataLoader,
|
|
@@ -2225,7 +2273,7 @@ class BaseMatchModel(BaseModel):
|
|
|
2225
2273
|
dense_l2_reg: float = 0.0,
|
|
2226
2274
|
target: list[str] | str | None = "label",
|
|
2227
2275
|
id_columns: list[str] | str | None = None,
|
|
2228
|
-
task:
|
|
2276
|
+
task: TaskTypeInput | list[TaskTypeInput] | None = None,
|
|
2229
2277
|
session_id: str | None = None,
|
|
2230
2278
|
distributed: bool = False,
|
|
2231
2279
|
rank: int | None = None,
|
nextrec/basic/summary.py
CHANGED
|
@@ -73,7 +73,8 @@ class SummarySet:
|
|
|
73
73
|
def build_data_summary(
|
|
74
74
|
self, data: Any, data_loader: DataLoader | None, sample_key: str
|
|
75
75
|
):
|
|
76
|
-
|
|
76
|
+
|
|
77
|
+
dataset = data_loader.dataset if data_loader is not None else None
|
|
77
78
|
|
|
78
79
|
train_size = get_data_length(dataset)
|
|
79
80
|
if train_size is None:
|
nextrec/cli.py
CHANGED
|
@@ -152,16 +152,19 @@ def train_model(train_config_path: str) -> None:
|
|
|
152
152
|
)
|
|
153
153
|
if data_cfg.get("valid_ratio") is not None:
|
|
154
154
|
logger.info(format_kv("Valid ratio", data_cfg.get("valid_ratio")))
|
|
155
|
-
if data_cfg.get("
|
|
155
|
+
if data_cfg.get("valid_path"):
|
|
156
156
|
logger.info(
|
|
157
157
|
format_kv(
|
|
158
158
|
"Validation path",
|
|
159
159
|
resolve_path(
|
|
160
|
-
data_cfg.get("
|
|
160
|
+
data_cfg.get("valid_path"), config_dir
|
|
161
161
|
),
|
|
162
162
|
)
|
|
163
163
|
)
|
|
164
164
|
|
|
165
|
+
# Determine validation dataset path early for streaming split / fitting
|
|
166
|
+
val_data_path = data_cfg.get("valid_path")
|
|
167
|
+
|
|
165
168
|
if streaming:
|
|
166
169
|
file_paths, file_type = resolve_file_paths(str(data_path))
|
|
167
170
|
log_kv_lines(
|
|
@@ -180,6 +183,34 @@ def train_model(train_config_path: str) -> None:
|
|
|
180
183
|
raise ValueError(f"Data file is empty: {first_file}") from exc
|
|
181
184
|
df_columns = list(first_chunk.columns)
|
|
182
185
|
|
|
186
|
+
# Decide training/validation file lists before fitting processor, to avoid
|
|
187
|
+
# leaking validation statistics into preprocessing (scalers/encoders).
|
|
188
|
+
streaming_train_files = file_paths
|
|
189
|
+
streaming_valid_ratio = data_cfg.get("valid_ratio")
|
|
190
|
+
if val_data_path:
|
|
191
|
+
streaming_valid_files = None
|
|
192
|
+
elif streaming_valid_ratio is not None:
|
|
193
|
+
ratio = float(streaming_valid_ratio)
|
|
194
|
+
if not (0 < ratio < 1):
|
|
195
|
+
raise ValueError(
|
|
196
|
+
f"[NextRec CLI Error] Valid_ratio must be between 0 and 1, current value is {streaming_valid_ratio}"
|
|
197
|
+
)
|
|
198
|
+
total_files = len(file_paths)
|
|
199
|
+
if total_files < 2:
|
|
200
|
+
raise ValueError(
|
|
201
|
+
"[NextRec CLI Error] Must provide valid_path or increase the number of data files. At least 2 files are required for streaming validation split."
|
|
202
|
+
)
|
|
203
|
+
val_count = max(1, int(round(total_files * ratio)))
|
|
204
|
+
if val_count >= total_files:
|
|
205
|
+
val_count = total_files - 1
|
|
206
|
+
streaming_valid_files = file_paths[-val_count:]
|
|
207
|
+
streaming_train_files = file_paths[:-val_count]
|
|
208
|
+
logger.info(
|
|
209
|
+
f"Split files for streaming training and validation using valid_ratio={ratio:.3f}: training {len(streaming_train_files)} files, validation {len(streaming_valid_files)} files"
|
|
210
|
+
)
|
|
211
|
+
else:
|
|
212
|
+
streaming_valid_files = None
|
|
213
|
+
|
|
183
214
|
else:
|
|
184
215
|
df = read_table(data_path, data_cfg.get("format"))
|
|
185
216
|
logger.info(format_kv("Rows", len(df)))
|
|
@@ -215,7 +246,13 @@ def train_model(train_config_path: str) -> None:
|
|
|
215
246
|
)
|
|
216
247
|
|
|
217
248
|
if streaming:
|
|
218
|
-
|
|
249
|
+
if file_type is None:
|
|
250
|
+
raise ValueError("[NextRec CLI Error] Streaming mode requires a valid file_type")
|
|
251
|
+
processor.fit_from_files(
|
|
252
|
+
file_paths=streaming_train_files or file_paths,
|
|
253
|
+
file_type=file_type,
|
|
254
|
+
chunk_size=dataloader_chunk_size,
|
|
255
|
+
)
|
|
219
256
|
processed = None
|
|
220
257
|
df = None # type: ignore[assignment]
|
|
221
258
|
else:
|
|
@@ -232,34 +269,6 @@ def train_model(train_config_path: str) -> None:
|
|
|
232
269
|
sequence_names,
|
|
233
270
|
)
|
|
234
271
|
|
|
235
|
-
# Check if validation dataset path is specified
|
|
236
|
-
val_data_path = data_cfg.get("val_path") or data_cfg.get("valid_path")
|
|
237
|
-
if streaming:
|
|
238
|
-
if not file_paths:
|
|
239
|
-
file_paths, file_type = resolve_file_paths(str(data_path))
|
|
240
|
-
streaming_train_files = file_paths
|
|
241
|
-
streaming_valid_ratio = data_cfg.get("valid_ratio")
|
|
242
|
-
if val_data_path:
|
|
243
|
-
streaming_valid_files = None
|
|
244
|
-
elif streaming_valid_ratio is not None:
|
|
245
|
-
ratio = float(streaming_valid_ratio)
|
|
246
|
-
if not (0 < ratio < 1):
|
|
247
|
-
raise ValueError(
|
|
248
|
-
f"[NextRec CLI Error] Valid_ratio must be between 0 and 1, current value is {streaming_valid_ratio}"
|
|
249
|
-
)
|
|
250
|
-
total_files = len(file_paths)
|
|
251
|
-
if total_files < 2:
|
|
252
|
-
raise ValueError(
|
|
253
|
-
"[NextRec CLI Error] Must provide val_path or increase the number of data files. At least 2 files are required for streaming validation split."
|
|
254
|
-
)
|
|
255
|
-
val_count = max(1, int(round(total_files * ratio)))
|
|
256
|
-
if val_count >= total_files:
|
|
257
|
-
val_count = total_files - 1
|
|
258
|
-
streaming_valid_files = file_paths[-val_count:]
|
|
259
|
-
streaming_train_files = file_paths[:-val_count]
|
|
260
|
-
logger.info(
|
|
261
|
-
f"Split files for streaming training and validation using valid_ratio={ratio:.3f}: training {len(streaming_train_files)} files, validation {len(streaming_valid_files)} files"
|
|
262
|
-
)
|
|
263
272
|
train_data: Dict[str, Any]
|
|
264
273
|
valid_data: Dict[str, Any] | None
|
|
265
274
|
|
|
@@ -682,16 +691,22 @@ Examples:
|
|
|
682
691
|
if not args.mode:
|
|
683
692
|
parser.error("[NextRec CLI Error] --mode is required (train|predict)")
|
|
684
693
|
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
694
|
+
try:
|
|
695
|
+
if args.mode == "train":
|
|
696
|
+
config_path = args.train_config
|
|
697
|
+
if not config_path:
|
|
698
|
+
parser.error("[NextRec CLI Error] train mode requires --train_config")
|
|
699
|
+
train_model(config_path)
|
|
700
|
+
else:
|
|
701
|
+
config_path = args.predict_config
|
|
702
|
+
if not config_path:
|
|
703
|
+
parser.error(
|
|
704
|
+
"[NextRec CLI Error] predict mode requires --predict_config"
|
|
705
|
+
)
|
|
706
|
+
predict_model(config_path)
|
|
707
|
+
except Exception:
|
|
708
|
+
logging.getLogger(__name__).exception("[NextRec CLI Error] Unhandled exception")
|
|
709
|
+
raise
|
|
695
710
|
|
|
696
711
|
|
|
697
712
|
if __name__ == "__main__":
|
nextrec/data/batch_utils.py
CHANGED
|
@@ -12,7 +12,7 @@ import torch
|
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
def stack_section(batch: list[dict], section: Literal["features", "labels", "ids"]):
|
|
15
|
-
"""
|
|
15
|
+
"""
|
|
16
16
|
input example:
|
|
17
17
|
batch = [
|
|
18
18
|
{"features": {"f1": tensor1, "f2": tensor2}, "labels": {"label": tensor3}},
|
|
@@ -24,7 +24,7 @@ def stack_section(batch: list[dict], section: Literal["features", "labels", "ids
|
|
|
24
24
|
"f1": torch.stack([tensor1, tensor4], dim=0),
|
|
25
25
|
"f2": torch.stack([tensor2, tensor5], dim=0),
|
|
26
26
|
}
|
|
27
|
-
|
|
27
|
+
|
|
28
28
|
"""
|
|
29
29
|
entries = [item.get(section) for item in batch if item.get(section) is not None]
|
|
30
30
|
if not entries:
|
nextrec/data/preprocessor.py
CHANGED
|
@@ -13,7 +13,7 @@ import logging
|
|
|
13
13
|
import os
|
|
14
14
|
import pickle
|
|
15
15
|
from pathlib import Path
|
|
16
|
-
from typing import Any, Dict, Literal, Optional, Union
|
|
16
|
+
from typing import Any, Dict, Literal, Optional, Union, overload
|
|
17
17
|
|
|
18
18
|
import numpy as np
|
|
19
19
|
import pandas as pd
|
|
@@ -566,35 +566,16 @@ class DataProcessor(FeatureSet):
|
|
|
566
566
|
return [str(v) for v in value]
|
|
567
567
|
return [str(value)]
|
|
568
568
|
|
|
569
|
-
def
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
Args:
|
|
574
|
-
path (str): File or directory path.
|
|
575
|
-
chunk_size (int): Number of rows per chunk.
|
|
576
|
-
|
|
577
|
-
Returns:
|
|
578
|
-
DataProcessor: Fitted DataProcessor instance.
|
|
579
|
-
"""
|
|
569
|
+
def fit_from_file_paths(
|
|
570
|
+
self, file_paths: list[str], file_type: str, chunk_size: int
|
|
571
|
+
) -> "DataProcessor":
|
|
580
572
|
logger = logging.getLogger()
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
"Fitting DataProcessor (streaming path mode)...",
|
|
584
|
-
color="cyan",
|
|
585
|
-
bold=True,
|
|
586
|
-
)
|
|
587
|
-
)
|
|
588
|
-
for config in self.sparse_features.values():
|
|
589
|
-
config.pop("_min_freq_logged", None)
|
|
590
|
-
for config in self.sequence_features.values():
|
|
591
|
-
config.pop("_min_freq_logged", None)
|
|
592
|
-
file_paths, file_type = resolve_file_paths(path)
|
|
573
|
+
if not file_paths:
|
|
574
|
+
raise ValueError("[DataProcessor Error] Empty file list for streaming fit")
|
|
593
575
|
if not check_streaming_support(file_type):
|
|
594
576
|
raise ValueError(
|
|
595
577
|
f"[DataProcessor Error] Format '{file_type}' does not support streaming. "
|
|
596
|
-
"
|
|
597
|
-
"Use fit(dataframe) with in-memory data or convert the data format."
|
|
578
|
+
"Streaming fit only supports csv, parquet to avoid high memory usage."
|
|
598
579
|
)
|
|
599
580
|
|
|
600
581
|
numeric_acc = {}
|
|
@@ -636,6 +617,7 @@ class DataProcessor(FeatureSet):
|
|
|
636
617
|
target_values: Dict[str, set[Any]] = {
|
|
637
618
|
name: set() for name in self.target_features.keys()
|
|
638
619
|
}
|
|
620
|
+
|
|
639
621
|
missing_features = set()
|
|
640
622
|
for file_path in file_paths:
|
|
641
623
|
for chunk in iter_file_chunks(file_path, file_type, chunk_size):
|
|
@@ -702,10 +684,12 @@ class DataProcessor(FeatureSet):
|
|
|
702
684
|
for name in self.target_features.keys() & columns:
|
|
703
685
|
vals = chunk[name].dropna().tolist()
|
|
704
686
|
target_values[name].update(vals)
|
|
687
|
+
|
|
705
688
|
if missing_features:
|
|
706
689
|
logger.warning(
|
|
707
690
|
f"The following configured features were not found in provided files: {sorted(missing_features)}"
|
|
708
691
|
)
|
|
692
|
+
|
|
709
693
|
# finalize numeric scalers
|
|
710
694
|
for name, config in self.numeric_features.items():
|
|
711
695
|
acc = numeric_acc[name]
|
|
@@ -895,6 +879,91 @@ class DataProcessor(FeatureSet):
|
|
|
895
879
|
)
|
|
896
880
|
return self
|
|
897
881
|
|
|
882
|
+
def fit_from_files(
|
|
883
|
+
self, file_paths: list[str], file_type: str, chunk_size: int
|
|
884
|
+
) -> "DataProcessor":
|
|
885
|
+
"""Fit processor statistics by streaming an explicit list of files.
|
|
886
|
+
|
|
887
|
+
This is useful when you want to fit statistics on training files only (exclude
|
|
888
|
+
validation files) in streaming mode.
|
|
889
|
+
"""
|
|
890
|
+
logger = logging.getLogger()
|
|
891
|
+
logger.info(
|
|
892
|
+
colorize(
|
|
893
|
+
"Fitting DataProcessor (streaming files mode)...",
|
|
894
|
+
color="cyan",
|
|
895
|
+
bold=True,
|
|
896
|
+
)
|
|
897
|
+
)
|
|
898
|
+
for config in self.sparse_features.values():
|
|
899
|
+
config.pop("_min_freq_logged", None)
|
|
900
|
+
for config in self.sequence_features.values():
|
|
901
|
+
config.pop("_min_freq_logged", None)
|
|
902
|
+
uses_robust = any(
|
|
903
|
+
cfg.get("scaler") == "robust" for cfg in self.numeric_features.values()
|
|
904
|
+
)
|
|
905
|
+
if uses_robust:
|
|
906
|
+
logger.warning(
|
|
907
|
+
"Robust scaler requires full data; loading provided files into memory. "
|
|
908
|
+
"Consider smaller chunk_size or different scaler if memory is limited."
|
|
909
|
+
)
|
|
910
|
+
frames = [read_table(p, file_type) for p in file_paths]
|
|
911
|
+
df = pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]
|
|
912
|
+
return self.fit(df)
|
|
913
|
+
return self.fit_from_file_paths(file_paths=file_paths, file_type=file_type, chunk_size=chunk_size)
|
|
914
|
+
|
|
915
|
+
def fit_from_path(self, path: str, chunk_size: int) -> "DataProcessor":
|
|
916
|
+
"""
|
|
917
|
+
Fit processor statistics by streaming files to reduce memory usage.
|
|
918
|
+
|
|
919
|
+
Args:
|
|
920
|
+
path (str): File or directory path.
|
|
921
|
+
chunk_size (int): Number of rows per chunk.
|
|
922
|
+
|
|
923
|
+
Returns:
|
|
924
|
+
DataProcessor: Fitted DataProcessor instance.
|
|
925
|
+
"""
|
|
926
|
+
logger = logging.getLogger()
|
|
927
|
+
logger.info(
|
|
928
|
+
colorize(
|
|
929
|
+
"Fitting DataProcessor (streaming path mode)...",
|
|
930
|
+
color="cyan",
|
|
931
|
+
bold=True,
|
|
932
|
+
)
|
|
933
|
+
)
|
|
934
|
+
for config in self.sparse_features.values():
|
|
935
|
+
config.pop("_min_freq_logged", None)
|
|
936
|
+
for config in self.sequence_features.values():
|
|
937
|
+
config.pop("_min_freq_logged", None)
|
|
938
|
+
file_paths, file_type = resolve_file_paths(path)
|
|
939
|
+
return self.fit_from_file_paths(
|
|
940
|
+
file_paths=file_paths,
|
|
941
|
+
file_type=file_type,
|
|
942
|
+
chunk_size=chunk_size,
|
|
943
|
+
)
|
|
944
|
+
|
|
945
|
+
@overload
|
|
946
|
+
def transform_in_memory(
|
|
947
|
+
self,
|
|
948
|
+
data: Union[pd.DataFrame, Dict[str, Any]],
|
|
949
|
+
return_dict: Literal[True],
|
|
950
|
+
persist: bool,
|
|
951
|
+
save_format: Optional[str],
|
|
952
|
+
output_path: Optional[str],
|
|
953
|
+
warn_missing: bool = True,
|
|
954
|
+
) -> Dict[str, np.ndarray]: ...
|
|
955
|
+
|
|
956
|
+
@overload
|
|
957
|
+
def transform_in_memory(
|
|
958
|
+
self,
|
|
959
|
+
data: Union[pd.DataFrame, Dict[str, Any]],
|
|
960
|
+
return_dict: Literal[False],
|
|
961
|
+
persist: bool,
|
|
962
|
+
save_format: Optional[str],
|
|
963
|
+
output_path: Optional[str],
|
|
964
|
+
warn_missing: bool = True,
|
|
965
|
+
) -> pd.DataFrame: ...
|
|
966
|
+
|
|
898
967
|
def transform_in_memory(
|
|
899
968
|
self,
|
|
900
969
|
data: Union[pd.DataFrame, Dict[str, Any]],
|
|
@@ -1238,6 +1307,36 @@ class DataProcessor(FeatureSet):
|
|
|
1238
1307
|
self.is_fitted = True
|
|
1239
1308
|
return self
|
|
1240
1309
|
|
|
1310
|
+
@overload
|
|
1311
|
+
def transform(
|
|
1312
|
+
self,
|
|
1313
|
+
data: Union[pd.DataFrame, Dict[str, Any]],
|
|
1314
|
+
return_dict: Literal[True] = True,
|
|
1315
|
+
save_format: Optional[str] = None,
|
|
1316
|
+
output_path: Optional[str] = None,
|
|
1317
|
+
chunk_size: int = 200000,
|
|
1318
|
+
) -> Dict[str, np.ndarray]: ...
|
|
1319
|
+
|
|
1320
|
+
@overload
|
|
1321
|
+
def transform(
|
|
1322
|
+
self,
|
|
1323
|
+
data: Union[pd.DataFrame, Dict[str, Any]],
|
|
1324
|
+
return_dict: Literal[False] = False,
|
|
1325
|
+
save_format: Optional[str] = None,
|
|
1326
|
+
output_path: Optional[str] = None,
|
|
1327
|
+
chunk_size: int = 200000,
|
|
1328
|
+
) -> pd.DataFrame: ...
|
|
1329
|
+
|
|
1330
|
+
@overload
|
|
1331
|
+
def transform(
|
|
1332
|
+
self,
|
|
1333
|
+
data: str | os.PathLike,
|
|
1334
|
+
return_dict: Literal[False] = False,
|
|
1335
|
+
save_format: Optional[str] = None,
|
|
1336
|
+
output_path: Optional[str] = None,
|
|
1337
|
+
chunk_size: int = 200000,
|
|
1338
|
+
) -> list[str]: ...
|
|
1339
|
+
|
|
1241
1340
|
def transform(
|
|
1242
1341
|
self,
|
|
1243
1342
|
data: Union[pd.DataFrame, Dict[str, Any], str, os.PathLike],
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Date: create on 01/01/2026 - prerelease version: need to overwrite compute_loss later
|
|
3
|
-
Checkpoint: edit on 01/
|
|
3
|
+
Checkpoint: edit on 01/14/2026
|
|
4
4
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
5
5
|
Reference:
|
|
6
6
|
- [1] Xi D, Chen Z, Yan P, Zhang Y, Zhu Y, Zhuang F, Chen Y. Modeling the Sequential Dependence among Audience Multi-step Conversions with Multi-task Learning in Targeted Display Advertising. Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery & Data Mining (KDD ’21), 2021, pp. 3745–3755.
|
|
@@ -20,7 +20,7 @@ from nextrec.basic.layers import MLP, EmbeddingLayer
|
|
|
20
20
|
from nextrec.basic.heads import TaskHead
|
|
21
21
|
from nextrec.basic.model import BaseModel
|
|
22
22
|
from nextrec.utils.model import get_mlp_output_dim
|
|
23
|
-
from nextrec.utils.types import
|
|
23
|
+
from nextrec.utils.types import TaskTypeInput
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
class AITMTransfer(nn.Module):
|
|
@@ -76,7 +76,7 @@ class AITM(BaseModel):
|
|
|
76
76
|
tower_mlp_params_list: list[dict] | None = None,
|
|
77
77
|
calibrator_alpha: float = 0.1,
|
|
78
78
|
target: list[str] | str | None = None,
|
|
79
|
-
task: list[
|
|
79
|
+
task: list[TaskTypeInput] | None = None,
|
|
80
80
|
**kwargs,
|
|
81
81
|
):
|
|
82
82
|
dense_features = dense_features or []
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Date: create on 01/01/2026 - prerelease version: still need to align with the source paper
|
|
3
|
-
Checkpoint: edit on 01/
|
|
3
|
+
Checkpoint: edit on 01/14/2026
|
|
4
4
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
5
5
|
Reference:
|
|
6
6
|
- [1] Ma J, Zhao Z, Chen J, Li A, Hong L, Chi EH. SNR: Sub-Network Routing for Flexible Parameter Sharing in Multi-Task Learning in E-Commerce by Exploiting Task Relationships in the Label Space. Proceedings of the 33rd AAAI Conference on Artificial Intelligence (AAAI 2019), 2019, pp. 216-223.
|
|
@@ -22,7 +22,7 @@ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
|
|
|
22
22
|
from nextrec.basic.layers import EmbeddingLayer, MLP
|
|
23
23
|
from nextrec.basic.heads import TaskHead
|
|
24
24
|
from nextrec.basic.model import BaseModel
|
|
25
|
-
from nextrec.utils.types import TaskTypeName
|
|
25
|
+
from nextrec.utils.types import TaskTypeInput, TaskTypeName
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
class SNRTransGate(nn.Module):
|
|
@@ -101,7 +101,7 @@ class SNRTrans(BaseModel):
|
|
|
101
101
|
num_experts: int = 4,
|
|
102
102
|
tower_mlp_params_list: list[dict] | None = None,
|
|
103
103
|
target: list[str] | str | None = None,
|
|
104
|
-
task:
|
|
104
|
+
task: TaskTypeInput | list[TaskTypeInput] | None = None,
|
|
105
105
|
**kwargs,
|
|
106
106
|
) -> None:
|
|
107
107
|
dense_features = dense_features or []
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Date: create on 01/01/2026 - prerelease version: still need to align with the source paper
|
|
3
|
-
Checkpoint: edit on 01/
|
|
3
|
+
Checkpoint: edit on 01/14/2026
|
|
4
4
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
5
5
|
Reference:
|
|
6
6
|
- [1] Sheng XR, Zhao L, Zhou G, Ding X, Dai B, Luo Q, Yang S, Lv J, Zhang C, Deng H, Zhu X. One Model to Serve All: Star Topology Adaptive Recommender for Multi-Domain CTR Prediction. arXiv preprint arXiv:2101.11427, 2021.
|
|
@@ -22,7 +22,7 @@ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
|
|
|
22
22
|
from nextrec.basic.heads import TaskHead
|
|
23
23
|
from nextrec.basic.layers import DomainBatchNorm, EmbeddingLayer
|
|
24
24
|
from nextrec.basic.model import BaseModel
|
|
25
|
-
from nextrec.utils.types import TaskTypeName
|
|
25
|
+
from nextrec.utils.types import TaskTypeInput, TaskTypeName
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
class SharedSpecificLinear(nn.Module):
|
|
@@ -73,7 +73,7 @@ class STAR(BaseModel):
|
|
|
73
73
|
sparse_features: list[SparseFeature] | None = None,
|
|
74
74
|
sequence_features: list[SequenceFeature] | None = None,
|
|
75
75
|
target: list[str] | str | None = None,
|
|
76
|
-
task:
|
|
76
|
+
task: TaskTypeInput | list[TaskTypeInput] | None = None,
|
|
77
77
|
mlp_params: dict | None = None,
|
|
78
78
|
use_shared: bool = True,
|
|
79
79
|
**kwargs,
|
nextrec/models/multi_task/apg.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Date: create on 01/01/2026
|
|
3
|
-
Checkpoint: edit on 01/
|
|
3
|
+
Checkpoint: edit on 01/14/2026
|
|
4
4
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
5
5
|
Reference:
|
|
6
6
|
- [1] Yan B, Wang P, Zhang K, Li F, Deng H, Xu J, Zheng B. APG: Adaptive Parameter Generation Network for Click-Through Rate Prediction. Advances in Neural Information Processing Systems 35 (NeurIPS 2022), 2022.
|
|
@@ -20,7 +20,7 @@ from nextrec.basic.layers import EmbeddingLayer, MLP
|
|
|
20
20
|
from nextrec.basic.heads import TaskHead
|
|
21
21
|
from nextrec.basic.model import BaseModel
|
|
22
22
|
from nextrec.utils.model import select_features
|
|
23
|
-
from nextrec.utils.types import ActivationName, TaskTypeName
|
|
23
|
+
from nextrec.utils.types import ActivationName, TaskTypeInput, TaskTypeName
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
class APGLayer(nn.Module):
|
|
@@ -233,7 +233,7 @@ class APG(BaseModel):
|
|
|
233
233
|
sparse_features: list[SparseFeature] | None = None,
|
|
234
234
|
sequence_features: list[SequenceFeature] | None = None,
|
|
235
235
|
target: list[str] | str | None = None,
|
|
236
|
-
task:
|
|
236
|
+
task: TaskTypeInput | list[TaskTypeInput] | None = None,
|
|
237
237
|
mlp_params: dict | None = None,
|
|
238
238
|
inner_activation: ActivationName | None = None,
|
|
239
239
|
generate_activation: ActivationName | None = None,
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Date: create on 01/01/2026
|
|
3
|
-
Checkpoint: edit on 01/
|
|
3
|
+
Checkpoint: edit on 01/14/2026
|
|
4
4
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
5
5
|
Reference:
|
|
6
6
|
- [1] Misra I, Shrivastava A, Gupta A, Hebert M. Cross-Stitch Networks for Multi-Task Learning. Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR 2016), 2016, pp. 3994–4003.
|
|
@@ -21,7 +21,7 @@ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
|
|
|
21
21
|
from nextrec.basic.layers import EmbeddingLayer, MLP
|
|
22
22
|
from nextrec.basic.heads import TaskHead
|
|
23
23
|
from nextrec.basic.model import BaseModel
|
|
24
|
-
from nextrec.utils.types import TaskTypeName
|
|
24
|
+
from nextrec.utils.types import TaskTypeInput, TaskTypeName
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
class CrossStitchLayer(nn.Module):
|
|
@@ -76,7 +76,7 @@ class CrossStitch(BaseModel):
|
|
|
76
76
|
sparse_features: list[SparseFeature] | None = None,
|
|
77
77
|
sequence_features: list[SequenceFeature] | None = None,
|
|
78
78
|
target: list[str] | str | None = None,
|
|
79
|
-
task:
|
|
79
|
+
task: TaskTypeInput | list[TaskTypeInput] | None = None,
|
|
80
80
|
shared_mlp_params: dict | None = None,
|
|
81
81
|
task_mlp_params: dict | None = None,
|
|
82
82
|
tower_mlp_params: dict | None = None,
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Date: create on 01/01/2026
|
|
3
|
-
Checkpoint: edit on 01/
|
|
3
|
+
Checkpoint: edit on 01/14/2026
|
|
4
4
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
5
5
|
Reference:
|
|
6
6
|
- [1] Wang H, Chang T-W, Liu T, Huang J, Chen Z, Yu C, Li R, Chu W. ESCM²: Entire Space Counterfactual Multi-Task Model for Post-Click Conversion Rate Estimation. Proceedings of the 45th International ACM SIGIR Conference on Research and Development in Information Retrieval (SIGIR ’22), 2022:363–372.
|
|
@@ -23,7 +23,7 @@ from nextrec.basic.layers import EmbeddingLayer, MLP
|
|
|
23
23
|
from nextrec.basic.model import BaseModel
|
|
24
24
|
from nextrec.loss.grad_norm import get_grad_norm_shared_params
|
|
25
25
|
from nextrec.utils.model import compute_ranking_loss
|
|
26
|
-
from nextrec.utils.types import TaskTypeName
|
|
26
|
+
from nextrec.utils.types import TaskTypeInput, TaskTypeName
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
class ESCM(BaseModel):
|
|
@@ -52,7 +52,7 @@ class ESCM(BaseModel):
|
|
|
52
52
|
imp_mlp_params: dict | None = None,
|
|
53
53
|
use_dr: bool = False,
|
|
54
54
|
target: list[str] | str | None = None,
|
|
55
|
-
task:
|
|
55
|
+
task: TaskTypeInput | list[TaskTypeInput] | None = None,
|
|
56
56
|
**kwargs,
|
|
57
57
|
) -> None:
|
|
58
58
|
dense_features = dense_features or []
|