nextrec 0.4.24__py3-none-any.whl → 0.4.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/model.py +175 -58
- nextrec/basic/summary.py +58 -0
- nextrec/cli.py +13 -0
- nextrec/data/data_processing.py +3 -9
- nextrec/data/dataloader.py +25 -2
- nextrec/data/preprocessor.py +283 -36
- nextrec/utils/config.py +2 -0
- nextrec/utils/model.py +14 -70
- nextrec/utils/torch_utils.py +11 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.25.dist-info}/METADATA +4 -4
- {nextrec-0.4.24.dist-info → nextrec-0.4.25.dist-info}/RECORD +15 -15
- {nextrec-0.4.24.dist-info → nextrec-0.4.25.dist-info}/WHEEL +0 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.25.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.25.dist-info}/licenses/LICENSE +0 -0
nextrec/utils/model.py
CHANGED
|
@@ -2,14 +2,13 @@
|
|
|
2
2
|
Model-related utilities for NextRec
|
|
3
3
|
|
|
4
4
|
Date: create on 03/12/2025
|
|
5
|
-
Checkpoint: edit on
|
|
5
|
+
Checkpoint: edit on 31/12/2025
|
|
6
6
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
9
|
from collections import OrderedDict
|
|
10
10
|
|
|
11
11
|
import torch
|
|
12
|
-
from torch import nn
|
|
13
12
|
|
|
14
13
|
from nextrec.loss import (
|
|
15
14
|
ApproxNDCGLoss,
|
|
@@ -20,14 +19,6 @@ from nextrec.loss import (
|
|
|
20
19
|
SampledSoftmaxLoss,
|
|
21
20
|
TripletLoss,
|
|
22
21
|
)
|
|
23
|
-
from nextrec.utils.types import (
|
|
24
|
-
LossName,
|
|
25
|
-
OptimizerName,
|
|
26
|
-
SchedulerName,
|
|
27
|
-
TrainingModeName,
|
|
28
|
-
TaskTypeName,
|
|
29
|
-
MetricsName,
|
|
30
|
-
)
|
|
31
22
|
|
|
32
23
|
|
|
33
24
|
def merge_features(primary, secondary) -> list:
|
|
@@ -81,47 +72,26 @@ def compute_pair_scores(model, data, batch_size: int = 512):
|
|
|
81
72
|
return scores.detach().cpu().numpy()
|
|
82
73
|
|
|
83
74
|
|
|
84
|
-
def get_training_modes(
|
|
85
|
-
training_mode,
|
|
86
|
-
nums_task: int,
|
|
87
|
-
valid_modes: set[str] | None = None,
|
|
88
|
-
) -> list:
|
|
89
|
-
valid_modes = valid_modes or {"pointwise", "pairwise", "listwise"}
|
|
90
|
-
if isinstance(training_mode, list):
|
|
91
|
-
training_modes = list(training_mode)
|
|
92
|
-
if len(training_modes) != nums_task:
|
|
93
|
-
raise ValueError(
|
|
94
|
-
"[BaseModel-init Error] training_mode list length must match number of tasks."
|
|
95
|
-
)
|
|
96
|
-
else:
|
|
97
|
-
training_modes = [training_mode] * nums_task
|
|
98
|
-
if any(mode not in valid_modes for mode in training_modes):
|
|
99
|
-
raise ValueError(
|
|
100
|
-
"[BaseModel-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
|
|
101
|
-
)
|
|
102
|
-
return training_modes
|
|
103
|
-
|
|
104
|
-
|
|
105
75
|
def get_loss_list(
|
|
106
76
|
loss,
|
|
107
77
|
training_modes: list[str],
|
|
108
78
|
nums_task: int,
|
|
109
|
-
default_losses: dict[str, str],
|
|
110
79
|
):
|
|
111
|
-
|
|
112
|
-
|
|
80
|
+
default_losses = {
|
|
81
|
+
"pointwise": "bce",
|
|
82
|
+
"pairwise": "bpr",
|
|
83
|
+
"listwise": "listnet",
|
|
84
|
+
}
|
|
85
|
+
if loss is None:
|
|
113
86
|
loss_list = [default_losses[mode] for mode in training_modes]
|
|
114
|
-
elif isinstance(
|
|
115
|
-
if
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
f"[BaseModel-compile Error] Number of loss functions ({len(effective_loss)}) must match number of tasks ({nums_task})."
|
|
121
|
-
)
|
|
122
|
-
loss_list = list(effective_loss)
|
|
87
|
+
elif isinstance(loss, list):
|
|
88
|
+
if len(loss) != nums_task:
|
|
89
|
+
raise ValueError(
|
|
90
|
+
f"[BaseModel-compile Error] Number of loss functions ({len(loss)}) must match number of tasks ({nums_task})."
|
|
91
|
+
)
|
|
92
|
+
loss_list = loss
|
|
123
93
|
else:
|
|
124
|
-
loss_list = [
|
|
94
|
+
loss_list = [loss] * nums_task
|
|
125
95
|
|
|
126
96
|
for idx, mode in enumerate(training_modes):
|
|
127
97
|
if isinstance(loss_list[idx], str) and loss_list[idx] in {
|
|
@@ -133,32 +103,6 @@ def get_loss_list(
|
|
|
133
103
|
return loss_list
|
|
134
104
|
|
|
135
105
|
|
|
136
|
-
def resolve_loss_weights(loss_weights, nums_task: int):
|
|
137
|
-
if loss_weights is None:
|
|
138
|
-
return None
|
|
139
|
-
if nums_task == 1:
|
|
140
|
-
if isinstance(loss_weights, (list, tuple)):
|
|
141
|
-
if len(loss_weights) != 1:
|
|
142
|
-
raise ValueError(
|
|
143
|
-
"[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup."
|
|
144
|
-
)
|
|
145
|
-
loss_weights = loss_weights[0]
|
|
146
|
-
return [float(loss_weights)]
|
|
147
|
-
if isinstance(loss_weights, (int, float)):
|
|
148
|
-
weights = [float(loss_weights)] * nums_task
|
|
149
|
-
elif isinstance(loss_weights, (list, tuple)):
|
|
150
|
-
weights = [float(w) for w in loss_weights]
|
|
151
|
-
if len(weights) != nums_task:
|
|
152
|
-
raise ValueError(
|
|
153
|
-
f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({nums_task})."
|
|
154
|
-
)
|
|
155
|
-
else:
|
|
156
|
-
raise TypeError(
|
|
157
|
-
f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}"
|
|
158
|
-
)
|
|
159
|
-
return weights
|
|
160
|
-
|
|
161
|
-
|
|
162
106
|
def prepare_ranking_targets(
|
|
163
107
|
y_pred: torch.Tensor, y_true: torch.Tensor
|
|
164
108
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
nextrec/utils/torch_utils.py
CHANGED
|
@@ -204,6 +204,11 @@ def get_scheduler(
|
|
|
204
204
|
)
|
|
205
205
|
else:
|
|
206
206
|
raise NotImplementedError(f"Unsupported scheduler: {scheduler}")
|
|
207
|
+
elif isinstance(scheduler, type) and issubclass(
|
|
208
|
+
scheduler,
|
|
209
|
+
(torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.LRScheduler),
|
|
210
|
+
):
|
|
211
|
+
scheduler_fn = scheduler(optimizer, **scheduler_params)
|
|
207
212
|
elif isinstance(
|
|
208
213
|
scheduler,
|
|
209
214
|
(torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.LRScheduler),
|
|
@@ -215,6 +220,12 @@ def get_scheduler(
|
|
|
215
220
|
return scheduler_fn
|
|
216
221
|
|
|
217
222
|
|
|
223
|
+
def to_numpy(values: Any) -> np.ndarray:
|
|
224
|
+
if isinstance(values, torch.Tensor):
|
|
225
|
+
return values.detach().cpu().numpy()
|
|
226
|
+
return np.asarray(values)
|
|
227
|
+
|
|
228
|
+
|
|
218
229
|
def to_tensor(
|
|
219
230
|
value: Any, dtype: torch.dtype, device: torch.device | str | None = None
|
|
220
231
|
) -> torch.Tensor:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: nextrec
|
|
3
|
-
Version: 0.4.24
|
|
3
|
+
Version: 0.4.25
|
|
4
4
|
Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
|
|
5
5
|
Project-URL: Homepage, https://github.com/zerolovesea/NextRec
|
|
6
6
|
Project-URL: Repository, https://github.com/zerolovesea/NextRec
|
|
@@ -69,7 +69,7 @@ Description-Content-Type: text/markdown
|
|
|
69
69
|

|
|
70
70
|

|
|
71
71
|

|
|
72
|
-

|
|
73
73
|
[](https://deepwiki.com/zerolovesea/NextRec)
|
|
74
74
|
|
|
75
75
|
中文文档 | [English Version](README_en.md)
|
|
@@ -249,11 +249,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
|
|
|
249
249
|
|
|
250
250
|
预测结果固定保存到 `{checkpoint_path}/predictions/{name}.{save_data_format}`。
|
|
251
251
|
|
|
252
|
-
> 截止当前版本0.4.24,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
|
|
252
|
+
> 截止当前版本0.4.25,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
|
|
253
253
|
|
|
254
254
|
## 兼容平台
|
|
255
255
|
|
|
256
|
-
当前最新版本为0.4.24,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
|
|
256
|
+
当前最新版本为0.4.25,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
|
|
257
257
|
|
|
258
258
|
| 平台 | 配置 |
|
|
259
259
|
|------|------|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
nextrec/__init__.py,sha256=_M3oUqyuvQ5k8Th_3wId6hQ_caclh7M5ad51XN09m98,235
|
|
2
|
-
nextrec/__version__.py,sha256=
|
|
3
|
-
nextrec/cli.py,sha256=
|
|
2
|
+
nextrec/__version__.py,sha256=2KhS4HNlDzv_pmrC0ssutBbIG8FVsE-4OUVNs-FKmXw,23
|
|
3
|
+
nextrec/cli.py,sha256=uOaXnlAM-ARrbxKOVWWkTE_rv-54px168kBhFUHtIAg,25073
|
|
4
4
|
nextrec/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
5
5
|
nextrec/basic/activation.py,sha256=uekcJsOy8SiT0_NaDO2VNSStyYFzVikDFVLDk-VrjwQ,2949
|
|
6
6
|
nextrec/basic/callback.py,sha256=7geza5iMMlMojlrIKH5A7nzvCe4IYwgUaMRh_xpblWk,12585
|
|
@@ -9,15 +9,15 @@ nextrec/basic/heads.py,sha256=BshykLxD41KxKuZaBxf4Fmy1Mc52b3ioJliN1BVaGlk,3374
|
|
|
9
9
|
nextrec/basic/layers.py,sha256=tr8XFOcTvUHEZ6T3zJwmtKMA-u_xfzHloIkItGs821U,40084
|
|
10
10
|
nextrec/basic/loggers.py,sha256=KxTPVHtkebAbpxZIYZ4aqncZCu-dccpKtIxmi2bVs6o,13160
|
|
11
11
|
nextrec/basic/metrics.py,sha256=CPzENDcpO6QTDZLBtQlfAGKUYYQc0FT-eaMKJ4MURFo,23396
|
|
12
|
-
nextrec/basic/model.py,sha256=
|
|
12
|
+
nextrec/basic/model.py,sha256=FFJIrMW0dqh89Jq1poXpNN-8XvPfcqIlKdChQpGH6x0,110083
|
|
13
13
|
nextrec/basic/session.py,sha256=mrIsjRJhmvcAfoO1pXX-KB3SK5CCgz89wH8XDoAiGEI,4475
|
|
14
|
-
nextrec/basic/summary.py,sha256=
|
|
14
|
+
nextrec/basic/summary.py,sha256=b6jLo70gqZj_bQ4eb5yb8SXmr2ilZlKNN293EyVnkyc,17759
|
|
15
15
|
nextrec/data/__init__.py,sha256=YZQjpty1pDCM7q_YNmiA2sa5kbujUw26ObLHWjMPjKY,1194
|
|
16
16
|
nextrec/data/batch_utils.py,sha256=0bYGVX7RlhnHv_ZBaUngjDIpBNw-igCk98DgOsF7T6o,2879
|
|
17
|
-
nextrec/data/data_processing.py,sha256=
|
|
17
|
+
nextrec/data/data_processing.py,sha256=lhuwYxWp4Ts2bbuLGDt2LmuPrOy7pNcKczd2uVcQ4ss,6476
|
|
18
18
|
nextrec/data/data_utils.py,sha256=0Ls1cnG9lBz0ovtyedw5vwp7WegGK_iF-F8e_3DEddo,880
|
|
19
|
-
nextrec/data/dataloader.py,sha256=
|
|
20
|
-
nextrec/data/preprocessor.py,sha256=
|
|
19
|
+
nextrec/data/dataloader.py,sha256=gTs4YC5tHHwTq0A9481KYK1XyloeN2dMVOjPAFehF_E,19972
|
|
20
|
+
nextrec/data/preprocessor.py,sha256=AD5bHNbkAZAnI_SbDfJJaAh57CRtRjoOQJ6aIBkgoQs,65251
|
|
21
21
|
nextrec/loss/__init__.py,sha256=rualGsY-IBvmM52q9eOBk0MyKcMkpkazcscOeDXi_SM,774
|
|
22
22
|
nextrec/loss/grad_norm.py,sha256=YoE_XSIN1HOUcNq1dpfkIlWtMaB5Pu-SEWDaNgtRw1M,8316
|
|
23
23
|
nextrec/loss/listwise.py,sha256=mluxXQt9XiuWGvXA1nk4I0miqaKB6_GPVQqxLhAiJKs,5999
|
|
@@ -70,17 +70,17 @@ nextrec/models/retrieval/youtube_dnn.py,sha256=ciD9RyBy19mfQcEoqw1UfydmVBsJvffDw
|
|
|
70
70
|
nextrec/models/sequential/hstu.py,sha256=4-EUOQ4HTRG5MAhTA2b9FOOXXw8oyPxDBaaDFunkT6o,18979
|
|
71
71
|
nextrec/models/sequential/sasrec.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
72
72
|
nextrec/utils/__init__.py,sha256=jD73RLigxcHFP-rXBoPi2VUTKH7kE5vNMQkr4lW8UUY,2655
|
|
73
|
-
nextrec/utils/config.py,sha256=
|
|
73
|
+
nextrec/utils/config.py,sha256=WeGHjoQYA5SoC9B_uS3D6ChzKr9Z5n2fwE-8l6nsuLE,20425
|
|
74
74
|
nextrec/utils/console.py,sha256=RA3ZTjtUQXvueouSmXJNLkRjeUGQZesphwWjFMTbV4I,13577
|
|
75
75
|
nextrec/utils/data.py,sha256=pSL96mWjWfW_RKE-qlUSs9vfiYnFZAaRirzA6r7DB6s,24994
|
|
76
76
|
nextrec/utils/embedding.py,sha256=akAEc062MG2cD7VIOllHaqtwzAirQR2gq5iW7oKpGAU,1449
|
|
77
77
|
nextrec/utils/feature.py,sha256=E3NOFIW8gAoRXVrDhCSonzg8k7nMUZyZzMfCq9k73_A,623
|
|
78
78
|
nextrec/utils/loss.py,sha256=GBWQGpDaYkMJySpdG078XbeUNXUC34PVqFy0AqNS9N0,4578
|
|
79
|
-
nextrec/utils/model.py,sha256=
|
|
80
|
-
nextrec/utils/torch_utils.py,sha256=
|
|
79
|
+
nextrec/utils/model.py,sha256=dcAL2lXNXRzFfCHMfOM_gIDLH68IAezMosSmmOD3FiQ,5624
|
|
80
|
+
nextrec/utils/torch_utils.py,sha256=UQpWS7F3nITYqvx2KRBaQJc9oTowRkIvowhuQLt6NFM,11953
|
|
81
81
|
nextrec/utils/types.py,sha256=VhtLXUVvu0zAZVAUgRUML4FExRC-GH-ZmC1UiVSr3HE,1523
|
|
82
|
-
nextrec-0.4.
|
|
83
|
-
nextrec-0.4.
|
|
84
|
-
nextrec-0.4.
|
|
85
|
-
nextrec-0.4.
|
|
86
|
-
nextrec-0.4.
|
|
82
|
+
nextrec-0.4.25.dist-info/METADATA,sha256=AY7ejbF1WA7Z4EHCbZVKB-Flom6diV50LGwiUVt-auA,21859
|
|
83
|
+
nextrec-0.4.25.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
84
|
+
nextrec-0.4.25.dist-info/entry_points.txt,sha256=NN-dNSdfMRTv86bNXM7d3ZEPW2BQC6bRi7QP7i9cIps,45
|
|
85
|
+
nextrec-0.4.25.dist-info/licenses/LICENSE,sha256=2fQfVKeafywkni7MYHyClC6RGGC3laLTXCNBx-ubtp0,1064
|
|
86
|
+
nextrec-0.4.25.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|