nextrec 0.4.24__py3-none-any.whl → 0.4.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/utils/model.py CHANGED
@@ -2,14 +2,13 @@
2
2
  Model-related utilities for NextRec
3
3
 
4
4
  Date: create on 03/12/2025
5
- Checkpoint: edit on 29/12/2025
5
+ Checkpoint: edit on 31/12/2025
6
6
  Author: Yang Zhou, zyaztec@gmail.com
7
7
  """
8
8
 
9
9
  from collections import OrderedDict
10
10
 
11
11
  import torch
12
- from torch import nn
13
12
 
14
13
  from nextrec.loss import (
15
14
  ApproxNDCGLoss,
@@ -20,14 +19,6 @@ from nextrec.loss import (
20
19
  SampledSoftmaxLoss,
21
20
  TripletLoss,
22
21
  )
23
- from nextrec.utils.types import (
24
- LossName,
25
- OptimizerName,
26
- SchedulerName,
27
- TrainingModeName,
28
- TaskTypeName,
29
- MetricsName,
30
- )
31
22
 
32
23
 
33
24
  def merge_features(primary, secondary) -> list:
@@ -81,47 +72,26 @@ def compute_pair_scores(model, data, batch_size: int = 512):
81
72
  return scores.detach().cpu().numpy()
82
73
 
83
74
 
84
- def get_training_modes(
85
- training_mode,
86
- nums_task: int,
87
- valid_modes: set[str] | None = None,
88
- ) -> list:
89
- valid_modes = valid_modes or {"pointwise", "pairwise", "listwise"}
90
- if isinstance(training_mode, list):
91
- training_modes = list(training_mode)
92
- if len(training_modes) != nums_task:
93
- raise ValueError(
94
- "[BaseModel-init Error] training_mode list length must match number of tasks."
95
- )
96
- else:
97
- training_modes = [training_mode] * nums_task
98
- if any(mode not in valid_modes for mode in training_modes):
99
- raise ValueError(
100
- "[BaseModel-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
101
- )
102
- return training_modes
103
-
104
-
105
75
  def get_loss_list(
106
76
  loss,
107
77
  training_modes: list[str],
108
78
  nums_task: int,
109
- default_losses: dict[str, str],
110
79
  ):
111
- effective_loss = loss
112
- if effective_loss is None:
80
+ default_losses = {
81
+ "pointwise": "bce",
82
+ "pairwise": "bpr",
83
+ "listwise": "listnet",
84
+ }
85
+ if loss is None:
113
86
  loss_list = [default_losses[mode] for mode in training_modes]
114
- elif isinstance(effective_loss, list):
115
- if not effective_loss:
116
- loss_list = [default_losses[mode] for mode in training_modes]
117
- else:
118
- if len(effective_loss) != nums_task:
119
- raise ValueError(
120
- f"[BaseModel-compile Error] Number of loss functions ({len(effective_loss)}) must match number of tasks ({nums_task})."
121
- )
122
- loss_list = list(effective_loss)
87
+ elif isinstance(loss, list):
88
+ if len(loss) != nums_task:
89
+ raise ValueError(
90
+ f"[BaseModel-compile Error] Number of loss functions ({len(loss)}) must match number of tasks ({nums_task})."
91
+ )
92
+ loss_list = loss
123
93
  else:
124
- loss_list = [effective_loss] * nums_task
94
+ loss_list = [loss] * nums_task
125
95
 
126
96
  for idx, mode in enumerate(training_modes):
127
97
  if isinstance(loss_list[idx], str) and loss_list[idx] in {
@@ -133,32 +103,6 @@ def get_loss_list(
133
103
  return loss_list
134
104
 
135
105
 
136
- def resolve_loss_weights(loss_weights, nums_task: int):
137
- if loss_weights is None:
138
- return None
139
- if nums_task == 1:
140
- if isinstance(loss_weights, (list, tuple)):
141
- if len(loss_weights) != 1:
142
- raise ValueError(
143
- "[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup."
144
- )
145
- loss_weights = loss_weights[0]
146
- return [float(loss_weights)]
147
- if isinstance(loss_weights, (int, float)):
148
- weights = [float(loss_weights)] * nums_task
149
- elif isinstance(loss_weights, (list, tuple)):
150
- weights = [float(w) for w in loss_weights]
151
- if len(weights) != nums_task:
152
- raise ValueError(
153
- f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({nums_task})."
154
- )
155
- else:
156
- raise TypeError(
157
- f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}"
158
- )
159
- return weights
160
-
161
-
162
106
  def prepare_ranking_targets(
163
107
  y_pred: torch.Tensor, y_true: torch.Tensor
164
108
  ) -> tuple[torch.Tensor, torch.Tensor]:
@@ -204,6 +204,11 @@ def get_scheduler(
204
204
  )
205
205
  else:
206
206
  raise NotImplementedError(f"Unsupported scheduler: {scheduler}")
207
+ elif isinstance(scheduler, type) and issubclass(
208
+ scheduler,
209
+ (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.LRScheduler),
210
+ ):
211
+ scheduler_fn = scheduler(optimizer, **scheduler_params)
207
212
  elif isinstance(
208
213
  scheduler,
209
214
  (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.LRScheduler),
@@ -215,6 +220,12 @@ def get_scheduler(
215
220
  return scheduler_fn
216
221
 
217
222
 
223
+ def to_numpy(values: Any) -> np.ndarray:
224
+ if isinstance(values, torch.Tensor):
225
+ return values.detach().cpu().numpy()
226
+ return np.asarray(values)
227
+
228
+
218
229
  def to_tensor(
219
230
  value: Any, dtype: torch.dtype, device: torch.device | str | None = None
220
231
  ) -> torch.Tensor:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nextrec
3
- Version: 0.4.24
3
+ Version: 0.4.25
4
4
  Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
5
5
  Project-URL: Homepage, https://github.com/zerolovesea/NextRec
6
6
  Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -69,7 +69,7 @@ Description-Content-Type: text/markdown
69
69
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
70
70
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
71
71
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
72
- ![Version](https://img.shields.io/badge/Version-0.4.24-orange.svg)
72
+ ![Version](https://img.shields.io/badge/Version-0.4.25-orange.svg)
73
73
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/zerolovesea/NextRec)
74
74
 
75
75
  中文文档 | [English Version](README_en.md)
@@ -249,11 +249,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
249
249
 
250
250
  预测结果固定保存到 `{checkpoint_path}/predictions/{name}.{save_data_format}`。
251
251
 
252
- > 截止当前版本0.4.24,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
252
+ > 截止当前版本0.4.25,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
253
253
 
254
254
  ## 兼容平台
255
255
 
256
- 当前最新版本为0.4.24,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
256
+ 当前最新版本为0.4.25,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
257
257
 
258
258
  | 平台 | 配置 |
259
259
  |------|------|
@@ -1,6 +1,6 @@
1
1
  nextrec/__init__.py,sha256=_M3oUqyuvQ5k8Th_3wId6hQ_caclh7M5ad51XN09m98,235
2
- nextrec/__version__.py,sha256=WWrTOK_Nz_e97GQCElAGK_CtxVsM2uOWZphO5msHKOs,23
3
- nextrec/cli.py,sha256=Vm1XCFVw1vFh9NFw3PYZ_fYbh07tf45fl3RtPycooUI,24317
2
+ nextrec/__version__.py,sha256=2KhS4HNlDzv_pmrC0ssutBbIG8FVsE-4OUVNs-FKmXw,23
3
+ nextrec/cli.py,sha256=uOaXnlAM-ARrbxKOVWWkTE_rv-54px168kBhFUHtIAg,25073
4
4
  nextrec/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
5
  nextrec/basic/activation.py,sha256=uekcJsOy8SiT0_NaDO2VNSStyYFzVikDFVLDk-VrjwQ,2949
6
6
  nextrec/basic/callback.py,sha256=7geza5iMMlMojlrIKH5A7nzvCe4IYwgUaMRh_xpblWk,12585
@@ -9,15 +9,15 @@ nextrec/basic/heads.py,sha256=BshykLxD41KxKuZaBxf4Fmy1Mc52b3ioJliN1BVaGlk,3374
9
9
  nextrec/basic/layers.py,sha256=tr8XFOcTvUHEZ6T3zJwmtKMA-u_xfzHloIkItGs821U,40084
10
10
  nextrec/basic/loggers.py,sha256=KxTPVHtkebAbpxZIYZ4aqncZCu-dccpKtIxmi2bVs6o,13160
11
11
  nextrec/basic/metrics.py,sha256=CPzENDcpO6QTDZLBtQlfAGKUYYQc0FT-eaMKJ4MURFo,23396
12
- nextrec/basic/model.py,sha256=2dXMpYC8KV-prpUW7ex5hLq_NvKbPFbCNB9ncmCmBAE,104416
12
+ nextrec/basic/model.py,sha256=FFJIrMW0dqh89Jq1poXpNN-8XvPfcqIlKdChQpGH6x0,110083
13
13
  nextrec/basic/session.py,sha256=mrIsjRJhmvcAfoO1pXX-KB3SK5CCgz89wH8XDoAiGEI,4475
14
- nextrec/basic/summary.py,sha256=9xDtDbtMCPSQuEVLx23-SLL6qDRl1MfM19YMBG3Wtow,15372
14
+ nextrec/basic/summary.py,sha256=b6jLo70gqZj_bQ4eb5yb8SXmr2ilZlKNN293EyVnkyc,17759
15
15
  nextrec/data/__init__.py,sha256=YZQjpty1pDCM7q_YNmiA2sa5kbujUw26ObLHWjMPjKY,1194
16
16
  nextrec/data/batch_utils.py,sha256=0bYGVX7RlhnHv_ZBaUngjDIpBNw-igCk98DgOsF7T6o,2879
17
- nextrec/data/data_processing.py,sha256=ZDZMSTBvxjPppl872is4M49o4WAkZXw2vUFOsNr0q3w,6658
17
+ nextrec/data/data_processing.py,sha256=lhuwYxWp4Ts2bbuLGDt2LmuPrOy7pNcKczd2uVcQ4ss,6476
18
18
  nextrec/data/data_utils.py,sha256=0Ls1cnG9lBz0ovtyedw5vwp7WegGK_iF-F8e_3DEddo,880
19
- nextrec/data/dataloader.py,sha256=hjp9gf9tgREozZE0tBVBhtNDb2Ss1bpOVo6Bw0WWsrk,19091
20
- nextrec/data/preprocessor.py,sha256=4mVhQ6W2M9nmTeQjArx_cndWwnk2i29U2iXSNgg5gXM,52917
19
+ nextrec/data/dataloader.py,sha256=gTs4YC5tHHwTq0A9481KYK1XyloeN2dMVOjPAFehF_E,19972
20
+ nextrec/data/preprocessor.py,sha256=AD5bHNbkAZAnI_SbDfJJaAh57CRtRjoOQJ6aIBkgoQs,65251
21
21
  nextrec/loss/__init__.py,sha256=rualGsY-IBvmM52q9eOBk0MyKcMkpkazcscOeDXi_SM,774
22
22
  nextrec/loss/grad_norm.py,sha256=YoE_XSIN1HOUcNq1dpfkIlWtMaB5Pu-SEWDaNgtRw1M,8316
23
23
  nextrec/loss/listwise.py,sha256=mluxXQt9XiuWGvXA1nk4I0miqaKB6_GPVQqxLhAiJKs,5999
@@ -70,17 +70,17 @@ nextrec/models/retrieval/youtube_dnn.py,sha256=ciD9RyBy19mfQcEoqw1UfydmVBsJvffDw
70
70
  nextrec/models/sequential/hstu.py,sha256=4-EUOQ4HTRG5MAhTA2b9FOOXXw8oyPxDBaaDFunkT6o,18979
71
71
  nextrec/models/sequential/sasrec.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
72
72
  nextrec/utils/__init__.py,sha256=jD73RLigxcHFP-rXBoPi2VUTKH7kE5vNMQkr4lW8UUY,2655
73
- nextrec/utils/config.py,sha256=UIi4zntP2g4IJaeMQYoa6kMQlU_23Hq4N1ZugMgnB5A,20331
73
+ nextrec/utils/config.py,sha256=WeGHjoQYA5SoC9B_uS3D6ChzKr9Z5n2fwE-8l6nsuLE,20425
74
74
  nextrec/utils/console.py,sha256=RA3ZTjtUQXvueouSmXJNLkRjeUGQZesphwWjFMTbV4I,13577
75
75
  nextrec/utils/data.py,sha256=pSL96mWjWfW_RKE-qlUSs9vfiYnFZAaRirzA6r7DB6s,24994
76
76
  nextrec/utils/embedding.py,sha256=akAEc062MG2cD7VIOllHaqtwzAirQR2gq5iW7oKpGAU,1449
77
77
  nextrec/utils/feature.py,sha256=E3NOFIW8gAoRXVrDhCSonzg8k7nMUZyZzMfCq9k73_A,623
78
78
  nextrec/utils/loss.py,sha256=GBWQGpDaYkMJySpdG078XbeUNXUC34PVqFy0AqNS9N0,4578
79
- nextrec/utils/model.py,sha256=M9ToX2sOw5t07a6lG2DagSjPJtUULopANOZ1EW_Wcds,7752
80
- nextrec/utils/torch_utils.py,sha256=1lvZ7BG-rGLIAlumQIoeq5T9dO9hx2p8sa2_DC_bTZU,11564
79
+ nextrec/utils/model.py,sha256=dcAL2lXNXRzFfCHMfOM_gIDLH68IAezMosSmmOD3FiQ,5624
80
+ nextrec/utils/torch_utils.py,sha256=UQpWS7F3nITYqvx2KRBaQJc9oTowRkIvowhuQLt6NFM,11953
81
81
  nextrec/utils/types.py,sha256=VhtLXUVvu0zAZVAUgRUML4FExRC-GH-ZmC1UiVSr3HE,1523
82
- nextrec-0.4.24.dist-info/METADATA,sha256=hH313iUy8qYMnSD05xW4m_E6LYh7x_NblMrxu6f34U4,21859
83
- nextrec-0.4.24.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
84
- nextrec-0.4.24.dist-info/entry_points.txt,sha256=NN-dNSdfMRTv86bNXM7d3ZEPW2BQC6bRi7QP7i9cIps,45
85
- nextrec-0.4.24.dist-info/licenses/LICENSE,sha256=2fQfVKeafywkni7MYHyClC6RGGC3laLTXCNBx-ubtp0,1064
86
- nextrec-0.4.24.dist-info/RECORD,,
82
+ nextrec-0.4.25.dist-info/METADATA,sha256=AY7ejbF1WA7Z4EHCbZVKB-Flom6diV50LGwiUVt-auA,21859
83
+ nextrec-0.4.25.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
84
+ nextrec-0.4.25.dist-info/entry_points.txt,sha256=NN-dNSdfMRTv86bNXM7d3ZEPW2BQC6bRi7QP7i9cIps,45
85
+ nextrec-0.4.25.dist-info/licenses/LICENSE,sha256=2fQfVKeafywkni7MYHyClC6RGGC3laLTXCNBx-ubtp0,1064
86
+ nextrec-0.4.25.dist-info/RECORD,,