nextrec 0.4.20__py3-none-any.whl → 0.4.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +9 -4
- nextrec/basic/callback.py +39 -87
- nextrec/basic/features.py +149 -28
- nextrec/basic/heads.py +4 -1
- nextrec/basic/layers.py +375 -94
- nextrec/basic/loggers.py +236 -39
- nextrec/basic/model.py +209 -316
- nextrec/basic/session.py +2 -2
- nextrec/basic/summary.py +323 -0
- nextrec/cli.py +3 -3
- nextrec/data/data_processing.py +45 -1
- nextrec/data/dataloader.py +2 -2
- nextrec/data/preprocessor.py +2 -2
- nextrec/loss/loss_utils.py +5 -30
- nextrec/models/multi_task/esmm.py +4 -6
- nextrec/models/multi_task/mmoe.py +4 -6
- nextrec/models/multi_task/ple.py +6 -8
- nextrec/models/multi_task/poso.py +5 -7
- nextrec/models/multi_task/share_bottom.py +6 -8
- nextrec/models/ranking/afm.py +4 -6
- nextrec/models/ranking/autoint.py +4 -6
- nextrec/models/ranking/dcn.py +8 -7
- nextrec/models/ranking/dcn_v2.py +4 -6
- nextrec/models/ranking/deepfm.py +5 -7
- nextrec/models/ranking/dien.py +8 -7
- nextrec/models/ranking/din.py +8 -7
- nextrec/models/ranking/eulernet.py +5 -7
- nextrec/models/ranking/ffm.py +5 -7
- nextrec/models/ranking/fibinet.py +4 -6
- nextrec/models/ranking/fm.py +4 -6
- nextrec/models/ranking/lr.py +4 -6
- nextrec/models/ranking/masknet.py +8 -9
- nextrec/models/ranking/pnn.py +4 -6
- nextrec/models/ranking/widedeep.py +5 -7
- nextrec/models/ranking/xdeepfm.py +8 -7
- nextrec/models/retrieval/dssm.py +4 -10
- nextrec/models/retrieval/dssm_v2.py +0 -6
- nextrec/models/retrieval/mind.py +4 -10
- nextrec/models/retrieval/sdm.py +4 -10
- nextrec/models/retrieval/youtube_dnn.py +4 -10
- nextrec/models/sequential/hstu.py +1 -3
- nextrec/utils/__init__.py +12 -14
- nextrec/utils/config.py +15 -5
- nextrec/utils/console.py +2 -2
- nextrec/utils/feature.py +2 -2
- nextrec/utils/torch_utils.py +57 -112
- nextrec/utils/types.py +59 -0
- {nextrec-0.4.20.dist-info → nextrec-0.4.21.dist-info}/METADATA +7 -5
- nextrec-0.4.21.dist-info/RECORD +81 -0
- nextrec-0.4.20.dist-info/RECORD +0 -79
- {nextrec-0.4.20.dist-info → nextrec-0.4.21.dist-info}/WHEEL +0 -0
- {nextrec-0.4.20.dist-info → nextrec-0.4.21.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.20.dist-info → nextrec-0.4.21.dist-info}/licenses/LICENSE +0 -0
nextrec/models/retrieval/sdm.py
CHANGED
|
@@ -49,12 +49,10 @@ class SDM(BaseMatchModel):
|
|
|
49
49
|
num_negative_samples: int = 4,
|
|
50
50
|
temperature: float = 1.0,
|
|
51
51
|
similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
dense_l2_reg: float = 0.0,
|
|
57
|
-
early_stop_patience: int = 20,
|
|
52
|
+
embedding_l1_reg=0.0,
|
|
53
|
+
dense_l1_reg=0.0,
|
|
54
|
+
embedding_l2_reg=0.0,
|
|
55
|
+
dense_l2_reg=0.0,
|
|
58
56
|
optimizer: str | torch.optim.Optimizer = "adam",
|
|
59
57
|
optimizer_params: dict | None = None,
|
|
60
58
|
scheduler: (
|
|
@@ -80,12 +78,10 @@ class SDM(BaseMatchModel):
|
|
|
80
78
|
num_negative_samples=num_negative_samples,
|
|
81
79
|
temperature=temperature,
|
|
82
80
|
similarity_metric=similarity_metric,
|
|
83
|
-
device=device,
|
|
84
81
|
embedding_l1_reg=embedding_l1_reg,
|
|
85
82
|
dense_l1_reg=dense_l1_reg,
|
|
86
83
|
embedding_l2_reg=embedding_l2_reg,
|
|
87
84
|
dense_l2_reg=dense_l2_reg,
|
|
88
|
-
early_stop_patience=early_stop_patience,
|
|
89
85
|
**kwargs,
|
|
90
86
|
)
|
|
91
87
|
|
|
@@ -202,8 +198,6 @@ class SDM(BaseMatchModel):
|
|
|
202
198
|
loss_params=loss_params,
|
|
203
199
|
)
|
|
204
200
|
|
|
205
|
-
self.to(device)
|
|
206
|
-
|
|
207
201
|
def user_tower(self, user_input: dict) -> torch.Tensor:
|
|
208
202
|
seq_feature = self.user_sequence_features[0]
|
|
209
203
|
seq_input = user_input[seq_feature.name]
|
|
@@ -50,12 +50,10 @@ class YoutubeDNN(BaseMatchModel):
|
|
|
50
50
|
num_negative_samples: int = 100,
|
|
51
51
|
temperature: float = 1.0,
|
|
52
52
|
similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
dense_l2_reg: float = 0.0,
|
|
58
|
-
early_stop_patience: int = 20,
|
|
53
|
+
embedding_l1_reg=0.0,
|
|
54
|
+
dense_l1_reg=0.0,
|
|
55
|
+
embedding_l2_reg=0.0,
|
|
56
|
+
dense_l2_reg=0.0,
|
|
59
57
|
optimizer: str | torch.optim.Optimizer = "adam",
|
|
60
58
|
optimizer_params: dict | None = None,
|
|
61
59
|
scheduler: (
|
|
@@ -81,12 +79,10 @@ class YoutubeDNN(BaseMatchModel):
|
|
|
81
79
|
num_negative_samples=num_negative_samples,
|
|
82
80
|
temperature=temperature,
|
|
83
81
|
similarity_metric=similarity_metric,
|
|
84
|
-
device=device,
|
|
85
82
|
embedding_l1_reg=embedding_l1_reg,
|
|
86
83
|
dense_l1_reg=dense_l1_reg,
|
|
87
84
|
embedding_l2_reg=embedding_l2_reg,
|
|
88
85
|
dense_l2_reg=dense_l2_reg,
|
|
89
|
-
early_stop_patience=early_stop_patience,
|
|
90
86
|
**kwargs,
|
|
91
87
|
)
|
|
92
88
|
|
|
@@ -169,8 +165,6 @@ class YoutubeDNN(BaseMatchModel):
|
|
|
169
165
|
loss_params=loss_params,
|
|
170
166
|
)
|
|
171
167
|
|
|
172
|
-
self.to(device)
|
|
173
|
-
|
|
174
168
|
def user_tower(self, user_input: dict) -> torch.Tensor:
|
|
175
169
|
"""
|
|
176
170
|
User tower to encode historical behavior sequences and user features.
|
|
@@ -332,7 +332,6 @@ class HSTU(BaseModel):
|
|
|
332
332
|
dense_l1_reg: float = 0.0,
|
|
333
333
|
embedding_l2_reg: float = 0.0,
|
|
334
334
|
dense_l2_reg: float = 0.0,
|
|
335
|
-
device: str = "cpu",
|
|
336
335
|
**kwargs,
|
|
337
336
|
):
|
|
338
337
|
raise NotImplementedError(
|
|
@@ -348,7 +347,7 @@ class HSTU(BaseModel):
|
|
|
348
347
|
)[0]
|
|
349
348
|
|
|
350
349
|
self.hidden_dim = hidden_dim or max(
|
|
351
|
-
int(
|
|
350
|
+
int(self.item_history_feature.embedding_dim or 0), 32
|
|
352
351
|
)
|
|
353
352
|
# Make hidden_dim divisible by num_heads
|
|
354
353
|
if self.hidden_dim % num_heads != 0:
|
|
@@ -368,7 +367,6 @@ class HSTU(BaseModel):
|
|
|
368
367
|
sequence_features=sequence_features,
|
|
369
368
|
target=target,
|
|
370
369
|
task=task or self.default_task,
|
|
371
|
-
device=device,
|
|
372
370
|
embedding_l1_reg=embedding_l1_reg,
|
|
373
371
|
dense_l1_reg=dense_l1_reg,
|
|
374
372
|
embedding_l2_reg=embedding_l2_reg,
|
nextrec/utils/__init__.py
CHANGED
|
@@ -14,6 +14,7 @@ from .config import (
|
|
|
14
14
|
load_model_class,
|
|
15
15
|
register_processor_features,
|
|
16
16
|
resolve_path,
|
|
17
|
+
safe_value,
|
|
17
18
|
select_features,
|
|
18
19
|
)
|
|
19
20
|
from .console import (
|
|
@@ -35,23 +36,19 @@ from .data import (
|
|
|
35
36
|
resolve_file_paths,
|
|
36
37
|
)
|
|
37
38
|
from .embedding import get_auto_embedding_dim
|
|
38
|
-
from .feature import
|
|
39
|
+
from .feature import to_list
|
|
39
40
|
from .model import compute_pair_scores, get_mlp_output_dim, merge_features
|
|
40
41
|
from .torch_utils import (
|
|
41
42
|
add_distributed_sampler,
|
|
42
|
-
|
|
43
|
-
configure_device,
|
|
43
|
+
get_device,
|
|
44
44
|
gather_numpy,
|
|
45
|
-
get_device_info,
|
|
46
45
|
get_initializer,
|
|
47
46
|
get_optimizer,
|
|
48
47
|
get_scheduler,
|
|
49
48
|
init_process_group,
|
|
50
|
-
pad_sequence_tensors,
|
|
51
|
-
resolve_device,
|
|
52
|
-
stack_tensors,
|
|
53
49
|
to_tensor,
|
|
54
50
|
)
|
|
51
|
+
from .types import LossName, OptimizerName, SchedulerName, ActivationName
|
|
55
52
|
|
|
56
53
|
__all__ = [
|
|
57
54
|
# Console utilities
|
|
@@ -67,17 +64,12 @@ __all__ = [
|
|
|
67
64
|
# Embedding utilities
|
|
68
65
|
"get_auto_embedding_dim",
|
|
69
66
|
# Device utilities (torch utils)
|
|
70
|
-
"
|
|
71
|
-
"get_device_info",
|
|
72
|
-
"configure_device",
|
|
67
|
+
"get_device",
|
|
73
68
|
"init_process_group",
|
|
74
69
|
"gather_numpy",
|
|
75
70
|
"add_distributed_sampler",
|
|
76
71
|
# Tensor utilities
|
|
77
72
|
"to_tensor",
|
|
78
|
-
"stack_tensors",
|
|
79
|
-
"concat_tensors",
|
|
80
|
-
"pad_sequence_tensors",
|
|
81
73
|
# Data utilities
|
|
82
74
|
"resolve_file_paths",
|
|
83
75
|
"read_table",
|
|
@@ -90,9 +82,10 @@ __all__ = [
|
|
|
90
82
|
"get_mlp_output_dim",
|
|
91
83
|
"compute_pair_scores",
|
|
92
84
|
# Feature utilities
|
|
93
|
-
"
|
|
85
|
+
"to_list",
|
|
94
86
|
# Config utilities
|
|
95
87
|
"resolve_path",
|
|
88
|
+
"safe_value",
|
|
96
89
|
"register_processor_features",
|
|
97
90
|
"build_feature_objects",
|
|
98
91
|
"extract_feature_groups",
|
|
@@ -109,4 +102,9 @@ __all__ = [
|
|
|
109
102
|
"data",
|
|
110
103
|
"embedding",
|
|
111
104
|
"torch_utils",
|
|
105
|
+
# Type aliases
|
|
106
|
+
"OptimizerName",
|
|
107
|
+
"SchedulerName",
|
|
108
|
+
"LossName",
|
|
109
|
+
"ActivationName",
|
|
112
110
|
]
|
nextrec/utils/config.py
CHANGED
|
@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple
|
|
|
21
21
|
import pandas as pd
|
|
22
22
|
import torch
|
|
23
23
|
|
|
24
|
-
from nextrec.utils.feature import
|
|
24
|
+
from nextrec.utils.feature import to_list
|
|
25
25
|
|
|
26
26
|
if TYPE_CHECKING:
|
|
27
27
|
from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
|
|
@@ -52,6 +52,16 @@ def resolve_path(
|
|
|
52
52
|
)
|
|
53
53
|
|
|
54
54
|
|
|
55
|
+
def safe_value(value: Any):
|
|
56
|
+
if isinstance(value, (str, int, float, bool)) or value is None:
|
|
57
|
+
return value
|
|
58
|
+
if isinstance(value, dict):
|
|
59
|
+
return {str(k): safe_value(v) for k, v in value.items()}
|
|
60
|
+
if isinstance(value, (list, tuple)):
|
|
61
|
+
return [safe_value(v) for v in value]
|
|
62
|
+
return str(value)
|
|
63
|
+
|
|
64
|
+
|
|
55
65
|
def select_features(
|
|
56
66
|
feature_cfg: Dict[str, Any], df_columns: List[str]
|
|
57
67
|
) -> Tuple[List[str], List[str], List[str]]:
|
|
@@ -152,9 +162,9 @@ def build_feature_objects(
|
|
|
152
162
|
dense_features.append(
|
|
153
163
|
DenseFeature(
|
|
154
164
|
name=name,
|
|
155
|
-
|
|
165
|
+
proj_dim=embed_cfg.get("proj_dim"),
|
|
156
166
|
input_dim=embed_cfg.get("input_dim", 1),
|
|
157
|
-
|
|
167
|
+
use_projection=embed_cfg.get("use_projection", False),
|
|
158
168
|
)
|
|
159
169
|
)
|
|
160
170
|
|
|
@@ -239,7 +249,7 @@ def extract_feature_groups(
|
|
|
239
249
|
collected: List[str] = []
|
|
240
250
|
|
|
241
251
|
for group_name, names in feature_groups.items():
|
|
242
|
-
name_list =
|
|
252
|
+
name_list = to_list(names)
|
|
243
253
|
filtered = []
|
|
244
254
|
missing_defined = [n for n in name_list if n not in defined]
|
|
245
255
|
missing_cols = [n for n in name_list if n not in available_cols]
|
|
@@ -441,7 +451,7 @@ def build_model_instance(
|
|
|
441
451
|
direct_features = binding.get("features") or binding.get("feature_names")
|
|
442
452
|
if direct_features and (accepts(param_name) or accepts_var_kwargs):
|
|
443
453
|
init_kwargs[param_name] = _select(
|
|
444
|
-
|
|
454
|
+
to_list(direct_features),
|
|
445
455
|
feature_pool,
|
|
446
456
|
f"feature_bindings.{param_name}",
|
|
447
457
|
)
|
nextrec/utils/console.py
CHANGED
|
@@ -36,7 +36,7 @@ from rich.progress import (
|
|
|
36
36
|
from rich.table import Table
|
|
37
37
|
from rich.text import Text
|
|
38
38
|
|
|
39
|
-
from nextrec.utils.feature import as_float,
|
|
39
|
+
from nextrec.utils.feature import as_float, to_list
|
|
40
40
|
|
|
41
41
|
T = TypeVar("T")
|
|
42
42
|
|
|
@@ -283,7 +283,7 @@ def display_metrics_table(
|
|
|
283
283
|
if not is_main_process:
|
|
284
284
|
return
|
|
285
285
|
|
|
286
|
-
target_list =
|
|
286
|
+
target_list = to_list(target_names)
|
|
287
287
|
task_order, grouped = group_metrics_by_task(metrics, target_names=target_names)
|
|
288
288
|
|
|
289
289
|
if isinstance(loss, np.ndarray) and target_list:
|
nextrec/utils/feature.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
Feature processing utilities for NextRec
|
|
3
3
|
|
|
4
4
|
Date: create on 03/12/2025
|
|
5
|
-
Checkpoint: edit on
|
|
5
|
+
Checkpoint: edit on 27/12/2025
|
|
6
6
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
7
7
|
"""
|
|
8
8
|
|
|
@@ -10,7 +10,7 @@ import numbers
|
|
|
10
10
|
from typing import Any
|
|
11
11
|
|
|
12
12
|
|
|
13
|
-
def
|
|
13
|
+
def to_list(value: str | list[str] | None) -> list[str]:
|
|
14
14
|
if value is None:
|
|
15
15
|
return []
|
|
16
16
|
if isinstance(value, str):
|
nextrec/utils/torch_utils.py
CHANGED
|
@@ -3,12 +3,16 @@ PyTorch-related utilities for NextRec.
|
|
|
3
3
|
|
|
4
4
|
This module groups device setup, distributed helpers, optimizers/schedulers,
|
|
5
5
|
initialization, and tensor helpers.
|
|
6
|
+
|
|
7
|
+
Date: create on 27/10/2025
|
|
8
|
+
Checkpoint: edit on 27/12/2025
|
|
9
|
+
Author: Yang Zhou, zyaztec@gmail.com
|
|
6
10
|
"""
|
|
7
11
|
|
|
8
12
|
from __future__ import annotations
|
|
9
13
|
|
|
10
14
|
import logging
|
|
11
|
-
from typing import Any, Dict, Iterable,
|
|
15
|
+
from typing import Any, Dict, Iterable, Literal
|
|
12
16
|
|
|
13
17
|
import numpy as np
|
|
14
18
|
import torch
|
|
@@ -18,26 +22,25 @@ from torch.utils.data import DataLoader, IterableDataset
|
|
|
18
22
|
from torch.utils.data.distributed import DistributedSampler
|
|
19
23
|
|
|
20
24
|
from nextrec.basic.loggers import colorize
|
|
21
|
-
|
|
22
|
-
KNOWN_NONLINEARITIES: Set[str] = {
|
|
23
|
-
"linear",
|
|
24
|
-
"conv1d",
|
|
25
|
-
"conv2d",
|
|
26
|
-
"conv3d",
|
|
27
|
-
"conv_transpose1d",
|
|
28
|
-
"conv_transpose2d",
|
|
29
|
-
"conv_transpose3d",
|
|
30
|
-
"sigmoid",
|
|
31
|
-
"tanh",
|
|
32
|
-
"relu",
|
|
33
|
-
"leaky_relu",
|
|
34
|
-
"selu",
|
|
35
|
-
"gelu",
|
|
36
|
-
}
|
|
25
|
+
from nextrec.utils.types import OptimizerName, SchedulerName
|
|
37
26
|
|
|
38
27
|
|
|
39
28
|
def resolve_nonlinearity(activation: str) -> str:
|
|
40
|
-
if activation in
|
|
29
|
+
if activation in [
|
|
30
|
+
"linear",
|
|
31
|
+
"conv1d",
|
|
32
|
+
"conv2d",
|
|
33
|
+
"conv3d",
|
|
34
|
+
"conv_transpose1d",
|
|
35
|
+
"conv_transpose2d",
|
|
36
|
+
"conv_transpose3d",
|
|
37
|
+
"sigmoid",
|
|
38
|
+
"tanh",
|
|
39
|
+
"relu",
|
|
40
|
+
"leaky_relu",
|
|
41
|
+
"selu",
|
|
42
|
+
"gelu",
|
|
43
|
+
]:
|
|
41
44
|
return activation
|
|
42
45
|
return "linear"
|
|
43
46
|
|
|
@@ -53,8 +56,30 @@ def resolve_gain(activation: str, param: Dict[str, Any]) -> float:
|
|
|
53
56
|
|
|
54
57
|
|
|
55
58
|
def get_initializer(
|
|
56
|
-
init_type:
|
|
57
|
-
|
|
59
|
+
init_type: Literal[
|
|
60
|
+
"xavier_uniform",
|
|
61
|
+
"xavier_normal",
|
|
62
|
+
"kaiming_uniform",
|
|
63
|
+
"kaiming_normal",
|
|
64
|
+
"orthogonal",
|
|
65
|
+
"normal",
|
|
66
|
+
"uniform",
|
|
67
|
+
] = "normal",
|
|
68
|
+
activation: Literal[
|
|
69
|
+
"linear",
|
|
70
|
+
"conv1d",
|
|
71
|
+
"conv2d",
|
|
72
|
+
"conv3d",
|
|
73
|
+
"conv_transpose1d",
|
|
74
|
+
"conv_transpose2d",
|
|
75
|
+
"conv_transpose3d",
|
|
76
|
+
"sigmoid",
|
|
77
|
+
"tanh",
|
|
78
|
+
"relu",
|
|
79
|
+
"leaky_relu",
|
|
80
|
+
"selu",
|
|
81
|
+
"gelu",
|
|
82
|
+
] = "linear",
|
|
58
83
|
param: Dict[str, Any] | None = None,
|
|
59
84
|
):
|
|
60
85
|
param = param or {}
|
|
@@ -89,47 +114,14 @@ def get_initializer(
|
|
|
89
114
|
return initializer_fn
|
|
90
115
|
|
|
91
116
|
|
|
92
|
-
def
|
|
93
|
-
if torch.cuda.is_available():
|
|
94
|
-
return "cuda"
|
|
95
|
-
if torch.backends.mps.is_available():
|
|
96
|
-
import platform
|
|
97
|
-
|
|
98
|
-
mac_ver = platform.mac_ver()[0]
|
|
99
|
-
try:
|
|
100
|
-
major, _ = (int(x) for x in mac_ver.split(".")[:2])
|
|
101
|
-
except Exception:
|
|
102
|
-
major, _ = 0, 0
|
|
103
|
-
if major >= 14:
|
|
104
|
-
return "mps"
|
|
105
|
-
return "cpu"
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
def get_device_info() -> dict:
|
|
109
|
-
info = {
|
|
110
|
-
"cuda_available": torch.cuda.is_available(),
|
|
111
|
-
"cuda_device_count": (
|
|
112
|
-
torch.cuda.device_count() if torch.cuda.is_available() else 0
|
|
113
|
-
),
|
|
114
|
-
"mps_available": torch.backends.mps.is_available(),
|
|
115
|
-
"current_device": resolve_device(),
|
|
116
|
-
}
|
|
117
|
-
|
|
118
|
-
if torch.cuda.is_available():
|
|
119
|
-
info["cuda_device_name"] = torch.cuda.get_device_name(0)
|
|
120
|
-
info["cuda_capability"] = torch.cuda.get_device_capability(0)
|
|
121
|
-
|
|
122
|
-
return info
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
def configure_device(
|
|
117
|
+
def get_device(
|
|
126
118
|
distributed: bool, local_rank: int, base_device: torch.device | str = "cpu"
|
|
127
119
|
) -> torch.device:
|
|
128
120
|
try:
|
|
129
121
|
device = torch.device(base_device)
|
|
130
122
|
except Exception:
|
|
131
123
|
logging.warning(
|
|
132
|
-
"[
|
|
124
|
+
"[get_device Warning] Invalid base_device, falling back to CPU."
|
|
133
125
|
)
|
|
134
126
|
return torch.device("cpu")
|
|
135
127
|
|
|
@@ -158,7 +150,7 @@ def configure_device(
|
|
|
158
150
|
|
|
159
151
|
|
|
160
152
|
def get_optimizer(
|
|
161
|
-
optimizer:
|
|
153
|
+
optimizer: OptimizerName | torch.optim.Optimizer = "adam",
|
|
162
154
|
params: Iterable[torch.nn.Parameter] | None = None,
|
|
163
155
|
**optimizer_params,
|
|
164
156
|
):
|
|
@@ -191,7 +183,7 @@ def get_optimizer(
|
|
|
191
183
|
|
|
192
184
|
def get_scheduler(
|
|
193
185
|
scheduler: (
|
|
194
|
-
|
|
186
|
+
SchedulerName
|
|
195
187
|
| torch.optim.lr_scheduler._LRScheduler
|
|
196
188
|
| torch.optim.lr_scheduler.LRScheduler
|
|
197
189
|
| type[torch.optim.lr_scheduler._LRScheduler]
|
|
@@ -241,51 +233,6 @@ def to_tensor(
|
|
|
241
233
|
return tensor
|
|
242
234
|
|
|
243
235
|
|
|
244
|
-
def stack_tensors(tensors: list[torch.Tensor], dim: int = 0) -> torch.Tensor:
|
|
245
|
-
if not tensors:
|
|
246
|
-
raise ValueError("[Tensor Utils Error] Cannot stack empty list of tensors.")
|
|
247
|
-
return torch.stack(tensors, dim=dim)
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
def concat_tensors(tensors: list[torch.Tensor], dim: int = 0) -> torch.Tensor:
|
|
251
|
-
if not tensors:
|
|
252
|
-
raise ValueError(
|
|
253
|
-
"[Tensor Utils Error] Cannot concatenate empty list of tensors."
|
|
254
|
-
)
|
|
255
|
-
return torch.cat(tensors, dim=dim)
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
def pad_sequence_tensors(
|
|
259
|
-
tensors: list[torch.Tensor],
|
|
260
|
-
max_len: int | None = None,
|
|
261
|
-
padding_value: float = 0.0,
|
|
262
|
-
padding_side: str = "right",
|
|
263
|
-
) -> torch.Tensor:
|
|
264
|
-
if not tensors:
|
|
265
|
-
raise ValueError("[Tensor Utils Error] Cannot pad empty list of tensors.")
|
|
266
|
-
if max_len is None:
|
|
267
|
-
max_len = max(t.size(0) for t in tensors)
|
|
268
|
-
batch_size = len(tensors)
|
|
269
|
-
padded = torch.full(
|
|
270
|
-
(batch_size, max_len),
|
|
271
|
-
padding_value,
|
|
272
|
-
dtype=tensors[0].dtype,
|
|
273
|
-
device=tensors[0].device,
|
|
274
|
-
)
|
|
275
|
-
|
|
276
|
-
for i, tensor in enumerate(tensors):
|
|
277
|
-
length = min(tensor.size(0), max_len)
|
|
278
|
-
if padding_side == "right":
|
|
279
|
-
padded[i, :length] = tensor[:length]
|
|
280
|
-
elif padding_side == "left":
|
|
281
|
-
padded[i, -length:] = tensor[:length]
|
|
282
|
-
else:
|
|
283
|
-
raise ValueError(
|
|
284
|
-
f"[Tensor Utils Error] padding_side must be 'right' or 'left', got {padding_side}"
|
|
285
|
-
)
|
|
286
|
-
return padded
|
|
287
|
-
|
|
288
|
-
|
|
289
236
|
def init_process_group(
|
|
290
237
|
distributed: bool, rank: int, world_size: int, device_id: int | None = None
|
|
291
238
|
) -> None:
|
|
@@ -350,7 +297,7 @@ def add_distributed_sampler(
|
|
|
350
297
|
# return if already has DistributedSampler
|
|
351
298
|
if isinstance(loader.sampler, DistributedSampler):
|
|
352
299
|
return loader, loader.sampler
|
|
353
|
-
dataset =
|
|
300
|
+
dataset = loader.dataset
|
|
354
301
|
if dataset is None:
|
|
355
302
|
return loader, None
|
|
356
303
|
if isinstance(dataset, IterableDataset):
|
|
@@ -379,25 +326,23 @@ def add_distributed_sampler(
|
|
|
379
326
|
"collate_fn": loader.collate_fn,
|
|
380
327
|
"drop_last": drop_last,
|
|
381
328
|
}
|
|
382
|
-
if
|
|
329
|
+
if loader.pin_memory:
|
|
383
330
|
loader_kwargs["pin_memory"] = True
|
|
384
|
-
pin_memory_device =
|
|
331
|
+
pin_memory_device = loader.pin_memory_device
|
|
385
332
|
if pin_memory_device:
|
|
386
333
|
loader_kwargs["pin_memory_device"] = pin_memory_device
|
|
387
|
-
timeout =
|
|
334
|
+
timeout = loader.timeout
|
|
388
335
|
if timeout:
|
|
389
336
|
loader_kwargs["timeout"] = timeout
|
|
390
|
-
worker_init_fn =
|
|
337
|
+
worker_init_fn = loader.worker_init_fn
|
|
391
338
|
if worker_init_fn is not None:
|
|
392
339
|
loader_kwargs["worker_init_fn"] = worker_init_fn
|
|
393
|
-
generator =
|
|
340
|
+
generator = loader.generator
|
|
394
341
|
if generator is not None:
|
|
395
342
|
loader_kwargs["generator"] = generator
|
|
396
343
|
if loader.num_workers > 0:
|
|
397
|
-
loader_kwargs["persistent_workers"] =
|
|
398
|
-
|
|
399
|
-
)
|
|
400
|
-
prefetch_factor = getattr(loader, "prefetch_factor", None)
|
|
344
|
+
loader_kwargs["persistent_workers"] = loader.persistent_workers
|
|
345
|
+
prefetch_factor = loader.prefetch_factor
|
|
401
346
|
if prefetch_factor is not None:
|
|
402
347
|
loader_kwargs["prefetch_factor"] = prefetch_factor
|
|
403
348
|
distributed_loader = DataLoader(dataset, **loader_kwargs)
|
nextrec/utils/types.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Shared type aliases for NextRec.
|
|
3
|
+
|
|
4
|
+
Keep Literal-based public string options centralized to avoid drift.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Literal
|
|
8
|
+
|
|
9
|
+
OptimizerName = Literal["adam", "sgd", "adamw", "adagrad", "rmsprop"]
|
|
10
|
+
|
|
11
|
+
SchedulerName = Literal["step", "cosine"]
|
|
12
|
+
|
|
13
|
+
LossName = Literal[
|
|
14
|
+
"bce",
|
|
15
|
+
"binary_crossentropy",
|
|
16
|
+
"weighted_bce",
|
|
17
|
+
"focal",
|
|
18
|
+
"focal_loss",
|
|
19
|
+
"cb_focal",
|
|
20
|
+
"class_balanced_focal",
|
|
21
|
+
"crossentropy",
|
|
22
|
+
"ce",
|
|
23
|
+
"mse",
|
|
24
|
+
"mae",
|
|
25
|
+
"bpr",
|
|
26
|
+
"hinge",
|
|
27
|
+
"triplet",
|
|
28
|
+
"sampled_softmax",
|
|
29
|
+
"softmax",
|
|
30
|
+
"infonce",
|
|
31
|
+
"listnet",
|
|
32
|
+
"listmle",
|
|
33
|
+
"approx_ndcg",
|
|
34
|
+
]
|
|
35
|
+
|
|
36
|
+
ActivationName = Literal[
|
|
37
|
+
"dice",
|
|
38
|
+
"relu",
|
|
39
|
+
"relu6",
|
|
40
|
+
"elu",
|
|
41
|
+
"selu",
|
|
42
|
+
"leaky_relu",
|
|
43
|
+
"prelu",
|
|
44
|
+
"gelu",
|
|
45
|
+
"sigmoid",
|
|
46
|
+
"tanh",
|
|
47
|
+
"softplus",
|
|
48
|
+
"softsign",
|
|
49
|
+
"hardswish",
|
|
50
|
+
"mish",
|
|
51
|
+
"silu",
|
|
52
|
+
"swish",
|
|
53
|
+
"hardsigmoid",
|
|
54
|
+
"tanhshrink",
|
|
55
|
+
"softshrink",
|
|
56
|
+
"none",
|
|
57
|
+
"linear",
|
|
58
|
+
"identity",
|
|
59
|
+
]
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: nextrec
|
|
3
|
-
Version: 0.4.20
|
|
3
|
+
Version: 0.4.21
|
|
4
4
|
Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
|
|
5
5
|
Project-URL: Homepage, https://github.com/zerolovesea/NextRec
|
|
6
6
|
Project-URL: Repository, https://github.com/zerolovesea/NextRec
|
|
@@ -42,9 +42,11 @@ Requires-Dist: scipy<1.12,>=1.8; sys_platform == 'linux' and python_version < '3
|
|
|
42
42
|
Requires-Dist: scipy>=1.10.0; sys_platform == 'darwin'
|
|
43
43
|
Requires-Dist: scipy>=1.10.0; sys_platform == 'win32'
|
|
44
44
|
Requires-Dist: scipy>=1.11.0; sys_platform == 'linux' and python_version >= '3.12'
|
|
45
|
+
Requires-Dist: swanlab>=0.7.2
|
|
45
46
|
Requires-Dist: torch>=2.0.0
|
|
46
47
|
Requires-Dist: torchvision>=0.15.0
|
|
47
48
|
Requires-Dist: transformers>=4.38.0
|
|
49
|
+
Requires-Dist: wandb>=0.23.1
|
|
48
50
|
Provides-Extra: dev
|
|
49
51
|
Requires-Dist: jupyter>=1.0.0; extra == 'dev'
|
|
50
52
|
Requires-Dist: matplotlib>=3.7.0; extra == 'dev'
|
|
@@ -67,7 +69,7 @@ Description-Content-Type: text/markdown
|
|
|
67
69
|

|
|
68
70
|

|
|
69
71
|

|
|
70
|
-

|
|
71
73
|
[](https://deepwiki.com/zerolovesea/NextRec)
|
|
72
74
|
|
|
73
75
|
中文文档 | [English Version](README_en.md)
|
|
@@ -100,7 +102,7 @@ NextRec是一个基于PyTorch的现代推荐系统框架,旨在为研究工程
|
|
|
100
102
|
- **高效训练与评估**:内置多种优化器、学习率调度、早停、模型检查点与详细的日志管理,开箱即用。
|
|
101
103
|
|
|
102
104
|
## NextRec近期进展
|
|
103
|
-
|
|
105
|
+
- **28/12/2025** 在v0.4.21中加入了对SwanLab和Wandb的支持,通过model的`fit`方法进行配置:`use_swanlab=True, swanlab_kwargs={"project": "NextRec","name":"tutorial_movielens_deepfm"},`
|
|
104
106
|
- **21/12/2025** 在v0.4.16中加入了对[GradNorm](/nextrec/loss/grad_norm.py)的支持,通过compile的`loss_weight='grad_norm'`进行配置
|
|
105
107
|
- **12/12/2025** 在v0.4.9中加入了[RQ-VAE](/nextrec/models/representation/rqvae.py)模块。配套的[数据集](/dataset/ecommerce_task.csv)和[代码](tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynb)已经同步在仓库中
|
|
106
108
|
- **07/12/2025** 发布了NextRec CLI命令行工具,它允许用户根据配置文件进行一键训练和推理,我们提供了相关的[教程](/nextrec_cli_preset/NextRec-CLI_zh.md)和[教学代码](/nextrec_cli_preset)
|
|
@@ -245,11 +247,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
|
|
|
245
247
|
|
|
246
248
|
预测结果固定保存到 `{checkpoint_path}/predictions/{name}.{save_data_format}`。
|
|
247
249
|
|
|
248
|
-
> 截止当前版本0.4.20,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
|
|
250
|
+
> 截止当前版本0.4.21,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
|
|
249
251
|
|
|
250
252
|
## 兼容平台
|
|
251
253
|
|
|
252
|
-
当前最新版本为0.4.20,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
|
|
254
|
+
当前最新版本为0.4.21,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
|
|
253
255
|
|
|
254
256
|
| 平台 | 配置 |
|
|
255
257
|
|------|------|
|