nextrec 0.4.8__py3-none-any.whl → 0.4.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/callback.py +30 -15
  3. nextrec/basic/features.py +1 -0
  4. nextrec/basic/layers.py +6 -8
  5. nextrec/basic/loggers.py +14 -7
  6. nextrec/basic/metrics.py +6 -76
  7. nextrec/basic/model.py +316 -321
  8. nextrec/cli.py +185 -43
  9. nextrec/data/__init__.py +13 -16
  10. nextrec/data/batch_utils.py +3 -2
  11. nextrec/data/data_processing.py +10 -2
  12. nextrec/data/data_utils.py +9 -14
  13. nextrec/data/dataloader.py +31 -33
  14. nextrec/data/preprocessor.py +328 -255
  15. nextrec/loss/__init__.py +1 -5
  16. nextrec/loss/loss_utils.py +2 -8
  17. nextrec/models/generative/__init__.py +1 -8
  18. nextrec/models/generative/hstu.py +6 -4
  19. nextrec/models/multi_task/esmm.py +2 -2
  20. nextrec/models/multi_task/mmoe.py +2 -2
  21. nextrec/models/multi_task/ple.py +2 -2
  22. nextrec/models/multi_task/poso.py +2 -3
  23. nextrec/models/multi_task/share_bottom.py +2 -2
  24. nextrec/models/ranking/afm.py +2 -2
  25. nextrec/models/ranking/autoint.py +2 -2
  26. nextrec/models/ranking/dcn.py +2 -2
  27. nextrec/models/ranking/dcn_v2.py +2 -2
  28. nextrec/models/ranking/deepfm.py +6 -7
  29. nextrec/models/ranking/dien.py +3 -3
  30. nextrec/models/ranking/din.py +3 -3
  31. nextrec/models/ranking/eulernet.py +365 -0
  32. nextrec/models/ranking/fibinet.py +5 -5
  33. nextrec/models/ranking/fm.py +3 -7
  34. nextrec/models/ranking/lr.py +120 -0
  35. nextrec/models/ranking/masknet.py +2 -2
  36. nextrec/models/ranking/pnn.py +2 -2
  37. nextrec/models/ranking/widedeep.py +2 -2
  38. nextrec/models/ranking/xdeepfm.py +2 -2
  39. nextrec/models/representation/__init__.py +9 -0
  40. nextrec/models/{generative → representation}/rqvae.py +9 -9
  41. nextrec/models/retrieval/__init__.py +0 -0
  42. nextrec/models/{match → retrieval}/dssm.py +8 -3
  43. nextrec/models/{match → retrieval}/dssm_v2.py +8 -3
  44. nextrec/models/{match → retrieval}/mind.py +4 -3
  45. nextrec/models/{match → retrieval}/sdm.py +4 -3
  46. nextrec/models/{match → retrieval}/youtube_dnn.py +8 -3
  47. nextrec/utils/__init__.py +60 -46
  48. nextrec/utils/config.py +8 -7
  49. nextrec/utils/console.py +371 -0
  50. nextrec/utils/{synthetic_data.py → data.py} +102 -15
  51. nextrec/utils/feature.py +15 -0
  52. nextrec/utils/torch_utils.py +411 -0
  53. {nextrec-0.4.8.dist-info → nextrec-0.4.10.dist-info}/METADATA +6 -7
  54. nextrec-0.4.10.dist-info/RECORD +70 -0
  55. nextrec/utils/cli_utils.py +0 -58
  56. nextrec/utils/device.py +0 -78
  57. nextrec/utils/distributed.py +0 -141
  58. nextrec/utils/file.py +0 -92
  59. nextrec/utils/initializer.py +0 -79
  60. nextrec/utils/optimizer.py +0 -75
  61. nextrec/utils/tensor.py +0 -72
  62. nextrec-0.4.8.dist-info/RECORD +0 -71
  63. /nextrec/models/{match/__init__.py → ranking/ffm.py} +0 -0
  64. {nextrec-0.4.8.dist-info → nextrec-0.4.10.dist-info}/WHEEL +0 -0
  65. {nextrec-0.4.8.dist-info → nextrec-0.4.10.dist-info}/entry_points.txt +0 -0
  66. {nextrec-0.4.8.dist-info → nextrec-0.4.10.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,411 @@
1
+ """
2
+ PyTorch-related utilities for NextRec.
3
+
4
+ This module groups device setup, distributed helpers, optimizers/schedulers,
5
+ initialization, and tensor helpers.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import logging
11
+ from typing import Any, Dict, Iterable, Set
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torch.distributed as dist
16
+ import torch.nn as nn
17
+ from torch.utils.data import DataLoader, IterableDataset
18
+ from torch.utils.data.distributed import DistributedSampler
19
+
20
+ from nextrec.basic.loggers import colorize
21
+
22
+ KNOWN_NONLINEARITIES: Set[str] = {
23
+ "linear",
24
+ "conv1d",
25
+ "conv2d",
26
+ "conv3d",
27
+ "conv_transpose1d",
28
+ "conv_transpose2d",
29
+ "conv_transpose3d",
30
+ "sigmoid",
31
+ "tanh",
32
+ "relu",
33
+ "leaky_relu",
34
+ "selu",
35
+ "gelu",
36
+ }
37
+
38
+
39
+ def resolve_nonlinearity(activation: str) -> str:
40
+ if activation in KNOWN_NONLINEARITIES:
41
+ return activation
42
+ return "linear"
43
+
44
+
45
+ def resolve_gain(activation: str, param: Dict[str, Any]) -> float:
46
+ if "gain" in param:
47
+ return param["gain"]
48
+ nonlinearity = resolve_nonlinearity(activation)
49
+ try:
50
+ return nn.init.calculate_gain(nonlinearity, param.get("param")) # type: ignore
51
+ except ValueError:
52
+ return 1.0
53
+
54
+
55
+ def get_initializer(
56
+ init_type: str = "normal",
57
+ activation: str = "linear",
58
+ param: Dict[str, Any] | None = None,
59
+ ):
60
+ param = param or {}
61
+ nonlinearity = resolve_nonlinearity(activation)
62
+ gain = resolve_gain(activation, param)
63
+
64
+ def initializer_fn(tensor):
65
+ if init_type == "xavier_uniform":
66
+ nn.init.xavier_uniform_(tensor, gain=gain)
67
+ elif init_type == "xavier_normal":
68
+ nn.init.xavier_normal_(tensor, gain=gain)
69
+ elif init_type == "kaiming_uniform":
70
+ nn.init.kaiming_uniform_(
71
+ tensor, a=param.get("a", 0), nonlinearity=nonlinearity # type: ignore
72
+ )
73
+ elif init_type == "kaiming_normal":
74
+ nn.init.kaiming_normal_(
75
+ tensor, a=param.get("a", 0), nonlinearity=nonlinearity # type: ignore
76
+ )
77
+ elif init_type == "orthogonal":
78
+ nn.init.orthogonal_(tensor, gain=gain)
79
+ elif init_type == "normal":
80
+ nn.init.normal_(
81
+ tensor, mean=param.get("mean", 0.0), std=param.get("std", 0.0001)
82
+ )
83
+ elif init_type == "uniform":
84
+ nn.init.uniform_(tensor, a=param.get("a", -0.05), b=param.get("b", 0.05))
85
+ else:
86
+ raise ValueError(f"Unknown init_type: {init_type}")
87
+ return tensor
88
+
89
+ return initializer_fn
90
+
91
+
92
+ def resolve_device() -> str:
93
+ if torch.cuda.is_available():
94
+ return "cuda"
95
+ if torch.backends.mps.is_available():
96
+ import platform
97
+
98
+ mac_ver = platform.mac_ver()[0]
99
+ try:
100
+ major, _ = (int(x) for x in mac_ver.split(".")[:2])
101
+ except Exception:
102
+ major, _ = 0, 0
103
+ if major >= 14:
104
+ return "mps"
105
+ return "cpu"
106
+
107
+
108
+ def get_device_info() -> dict:
109
+ info = {
110
+ "cuda_available": torch.cuda.is_available(),
111
+ "cuda_device_count": (
112
+ torch.cuda.device_count() if torch.cuda.is_available() else 0
113
+ ),
114
+ "mps_available": torch.backends.mps.is_available(),
115
+ "current_device": resolve_device(),
116
+ }
117
+
118
+ if torch.cuda.is_available():
119
+ info["cuda_device_name"] = torch.cuda.get_device_name(0)
120
+ info["cuda_capability"] = torch.cuda.get_device_capability(0)
121
+
122
+ return info
123
+
124
+
125
+ def configure_device(
126
+ distributed: bool, local_rank: int, base_device: torch.device | str = "cpu"
127
+ ) -> torch.device:
128
+ try:
129
+ device = torch.device(base_device)
130
+ except Exception:
131
+ logging.warning(
132
+ "[configure_device Warning] Invalid base_device, falling back to CPU."
133
+ )
134
+ return torch.device("cpu")
135
+
136
+ if distributed:
137
+ if device.type == "cuda":
138
+ if not torch.cuda.is_available():
139
+ logging.warning(
140
+ "[Distributed Warning] CUDA requested but unavailable. Falling back to CPU."
141
+ )
142
+ return torch.device("cpu")
143
+ if not (0 <= local_rank < torch.cuda.device_count()):
144
+ logging.warning(
145
+ f"[Distributed Warning] local_rank {local_rank} is invalid for available CUDA devices. Falling back to CPU."
146
+ )
147
+ return torch.device("cpu")
148
+ try:
149
+ torch.cuda.set_device(local_rank)
150
+ return torch.device(f"cuda:{local_rank}")
151
+ except Exception as exc:
152
+ logging.warning(
153
+ f"[Distributed Warning] Failed to set CUDA device for local_rank {local_rank}: {exc}. Falling back to CPU."
154
+ )
155
+ return torch.device("cpu")
156
+ return torch.device("cpu")
157
+ return device
158
+
159
+
160
+ def get_optimizer(
161
+ optimizer: str | torch.optim.Optimizer = "adam",
162
+ params: Iterable[torch.nn.Parameter] | None = None,
163
+ **optimizer_params,
164
+ ):
165
+ if params is None:
166
+ raise ValueError("params cannot be None. Please provide model parameters.")
167
+
168
+ if "lr" not in optimizer_params:
169
+ optimizer_params["lr"] = 1e-3
170
+ if isinstance(optimizer, str):
171
+ opt_name = optimizer.lower()
172
+ if opt_name == "adam":
173
+ opt_class = torch.optim.Adam
174
+ elif opt_name == "sgd":
175
+ opt_class = torch.optim.SGD
176
+ elif opt_name == "adamw":
177
+ opt_class = torch.optim.AdamW
178
+ elif opt_name == "adagrad":
179
+ opt_class = torch.optim.Adagrad
180
+ elif opt_name == "rmsprop":
181
+ opt_class = torch.optim.RMSprop
182
+ else:
183
+ raise NotImplementedError(f"Unsupported optimizer: {optimizer}")
184
+ optimizer_fn = opt_class(params=params, **optimizer_params)
185
+ elif isinstance(optimizer, torch.optim.Optimizer):
186
+ optimizer_fn = optimizer
187
+ else:
188
+ raise TypeError(f"Invalid optimizer type: {type(optimizer)}")
189
+ return optimizer_fn
190
+
191
+
192
+ def get_scheduler(
193
+ scheduler: (
194
+ str
195
+ | torch.optim.lr_scheduler._LRScheduler
196
+ | torch.optim.lr_scheduler.LRScheduler
197
+ | type[torch.optim.lr_scheduler._LRScheduler]
198
+ | type[torch.optim.lr_scheduler.LRScheduler]
199
+ | None
200
+ ),
201
+ optimizer,
202
+ **scheduler_params,
203
+ ):
204
+ if isinstance(scheduler, str):
205
+ if scheduler == "step":
206
+ scheduler_fn = torch.optim.lr_scheduler.StepLR(
207
+ optimizer, **scheduler_params
208
+ )
209
+ elif scheduler == "cosine":
210
+ scheduler_fn = torch.optim.lr_scheduler.CosineAnnealingLR(
211
+ optimizer, **scheduler_params
212
+ )
213
+ else:
214
+ raise NotImplementedError(f"Unsupported scheduler: {scheduler}")
215
+ elif isinstance(
216
+ scheduler,
217
+ (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.LRScheduler),
218
+ ):
219
+ scheduler_fn = scheduler
220
+ else:
221
+ raise TypeError(f"Invalid scheduler type: {type(scheduler)}")
222
+
223
+ return scheduler_fn
224
+
225
+
226
+ def to_tensor(
227
+ value: Any, dtype: torch.dtype, device: torch.device | str | None = None
228
+ ) -> torch.Tensor:
229
+ if value is None:
230
+ raise ValueError("[Tensor Utils Error] Cannot convert None to tensor.")
231
+ tensor = value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
232
+ if tensor.dtype != dtype:
233
+ tensor = tensor.to(dtype=dtype)
234
+
235
+ if device is not None:
236
+ target_device = (
237
+ device if isinstance(device, torch.device) else torch.device(device)
238
+ )
239
+ if tensor.device != target_device:
240
+ tensor = tensor.to(target_device)
241
+ return tensor
242
+
243
+
244
+ def stack_tensors(tensors: list[torch.Tensor], dim: int = 0) -> torch.Tensor:
245
+ if not tensors:
246
+ raise ValueError("[Tensor Utils Error] Cannot stack empty list of tensors.")
247
+ return torch.stack(tensors, dim=dim)
248
+
249
+
250
+ def concat_tensors(tensors: list[torch.Tensor], dim: int = 0) -> torch.Tensor:
251
+ if not tensors:
252
+ raise ValueError(
253
+ "[Tensor Utils Error] Cannot concatenate empty list of tensors."
254
+ )
255
+ return torch.cat(tensors, dim=dim)
256
+
257
+
258
+ def pad_sequence_tensors(
259
+ tensors: list[torch.Tensor],
260
+ max_len: int | None = None,
261
+ padding_value: float = 0.0,
262
+ padding_side: str = "right",
263
+ ) -> torch.Tensor:
264
+ if not tensors:
265
+ raise ValueError("[Tensor Utils Error] Cannot pad empty list of tensors.")
266
+ if max_len is None:
267
+ max_len = max(t.size(0) for t in tensors)
268
+ batch_size = len(tensors)
269
+ padded = torch.full(
270
+ (batch_size, max_len),
271
+ padding_value,
272
+ dtype=tensors[0].dtype,
273
+ device=tensors[0].device,
274
+ )
275
+
276
+ for i, tensor in enumerate(tensors):
277
+ length = min(tensor.size(0), max_len)
278
+ if padding_side == "right":
279
+ padded[i, :length] = tensor[:length]
280
+ elif padding_side == "left":
281
+ padded[i, -length:] = tensor[:length]
282
+ else:
283
+ raise ValueError(
284
+ f"[Tensor Utils Error] padding_side must be 'right' or 'left', got {padding_side}"
285
+ )
286
+ return padded
287
+
288
+
289
+ def init_process_group(
290
+ distributed: bool, rank: int, world_size: int, device_id: int | None = None
291
+ ) -> None:
292
+ """
293
+ initialize distributed process group for multi-GPU training.
294
+
295
+ Args:
296
+ distributed: whether to enable distributed training
297
+ rank: global rank of the current process
298
+ world_size: total number of processes
299
+ """
300
+ if (not distributed) or (not dist.is_available()) or dist.is_initialized():
301
+ return
302
+ backend = "nccl" if device_id is not None else "gloo"
303
+ if backend == "nccl":
304
+ torch.cuda.set_device(device_id)
305
+ dist.init_process_group(
306
+ backend=backend, init_method="env://", rank=rank, world_size=world_size
307
+ )
308
+
309
+
310
+ def gather_numpy(self, array: np.ndarray | None) -> np.ndarray | None:
311
+ """
312
+ Gather numpy arrays (or None) across ranks. Uses all_gather_object to avoid
313
+ shape mismatches and ensures every rank participates even when local data is empty.
314
+ """
315
+ if not (self.distributed and dist.is_available() and dist.is_initialized()):
316
+ return array
317
+
318
+ world_size = dist.get_world_size()
319
+ gathered: list[np.ndarray | None] = [None for _ in range(world_size)]
320
+ dist.all_gather_object(gathered, array)
321
+ pieces: list[np.ndarray] = []
322
+ for item in gathered:
323
+ if item is None:
324
+ continue
325
+ item_np = np.asarray(item)
326
+ if item_np.size > 0:
327
+ pieces.append(item_np)
328
+ if not pieces:
329
+ return None
330
+ return np.concatenate(pieces, axis=0)
331
+
332
+
333
+ def add_distributed_sampler(
334
+ loader: DataLoader,
335
+ distributed: bool,
336
+ world_size: int,
337
+ rank: int,
338
+ shuffle: bool,
339
+ drop_last: bool,
340
+ default_batch_size: int,
341
+ is_main_process: bool = False,
342
+ ) -> tuple[DataLoader, DistributedSampler | None]:
343
+ """
344
+ add distributedsampler to a dataloader, this for distributed training
345
+ when each device has its own dataloader
346
+ """
347
+ # early return if not distributed
348
+ if not (distributed and dist.is_available() and dist.is_initialized()):
349
+ return loader, None
350
+ # return if already has DistributedSampler
351
+ if isinstance(loader.sampler, DistributedSampler):
352
+ return loader, loader.sampler
353
+ dataset = getattr(loader, "dataset", None)
354
+ if dataset is None:
355
+ return loader, None
356
+ if isinstance(dataset, IterableDataset):
357
+ if is_main_process:
358
+ logging.info(
359
+ colorize(
360
+ "[Distributed Info] Iterable/streaming DataLoader provided; DistributedSampler is skipped. Ensure dataset handles sharding per rank.",
361
+ color="yellow",
362
+ )
363
+ )
364
+ return loader, None
365
+ sampler = DistributedSampler(
366
+ dataset,
367
+ num_replicas=world_size,
368
+ rank=rank,
369
+ shuffle=shuffle,
370
+ drop_last=drop_last,
371
+ )
372
+ loader_kwargs = {
373
+ "batch_size": (
374
+ loader.batch_size if loader.batch_size is not None else default_batch_size
375
+ ),
376
+ "shuffle": False,
377
+ "sampler": sampler,
378
+ "num_workers": loader.num_workers,
379
+ "collate_fn": loader.collate_fn,
380
+ "drop_last": drop_last,
381
+ }
382
+ if getattr(loader, "pin_memory", False):
383
+ loader_kwargs["pin_memory"] = True
384
+ pin_memory_device = getattr(loader, "pin_memory_device", None)
385
+ if pin_memory_device:
386
+ loader_kwargs["pin_memory_device"] = pin_memory_device
387
+ timeout = getattr(loader, "timeout", None)
388
+ if timeout:
389
+ loader_kwargs["timeout"] = timeout
390
+ worker_init_fn = getattr(loader, "worker_init_fn", None)
391
+ if worker_init_fn is not None:
392
+ loader_kwargs["worker_init_fn"] = worker_init_fn
393
+ generator = getattr(loader, "generator", None)
394
+ if generator is not None:
395
+ loader_kwargs["generator"] = generator
396
+ if loader.num_workers > 0:
397
+ loader_kwargs["persistent_workers"] = getattr(
398
+ loader, "persistent_workers", False
399
+ )
400
+ prefetch_factor = getattr(loader, "prefetch_factor", None)
401
+ if prefetch_factor is not None:
402
+ loader_kwargs["prefetch_factor"] = prefetch_factor
403
+ distributed_loader = DataLoader(dataset, **loader_kwargs)
404
+ if is_main_process:
405
+ logging.info(
406
+ colorize(
407
+ "[Distributed Info] Attached DistributedSampler to provided DataLoader",
408
+ color="cyan",
409
+ )
410
+ )
411
+ return distributed_loader, sampler
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nextrec
3
- Version: 0.4.8
3
+ Version: 0.4.10
4
4
  Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
5
5
  Project-URL: Homepage, https://github.com/zerolovesea/NextRec
6
6
  Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -33,6 +33,7 @@ Requires-Dist: pyarrow<15.0.0,>=12.0.0; sys_platform == 'win32'
33
33
  Requires-Dist: pyarrow>=12.0.0; sys_platform == 'darwin'
34
34
  Requires-Dist: pyarrow>=16.0.0; sys_platform == 'linux' and python_version >= '3.12'
35
35
  Requires-Dist: pyyaml>=6.0
36
+ Requires-Dist: rich>=13.7.0
36
37
  Requires-Dist: scikit-learn<2.0,>=1.2; sys_platform == 'linux' and python_version < '3.12'
37
38
  Requires-Dist: scikit-learn>=1.3.0; sys_platform == 'darwin'
38
39
  Requires-Dist: scikit-learn>=1.3.0; sys_platform == 'linux' and python_version >= '3.12'
@@ -43,7 +44,6 @@ Requires-Dist: scipy>=1.10.0; sys_platform == 'win32'
43
44
  Requires-Dist: scipy>=1.11.0; sys_platform == 'linux' and python_version >= '3.12'
44
45
  Requires-Dist: torch>=2.0.0
45
46
  Requires-Dist: torchvision>=0.15.0
46
- Requires-Dist: tqdm>=4.65.0
47
47
  Requires-Dist: transformers>=4.38.0
48
48
  Provides-Extra: dev
49
49
  Requires-Dist: jupyter>=1.0.0; extra == 'dev'
@@ -66,7 +66,7 @@ Description-Content-Type: text/markdown
66
66
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
67
67
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
68
68
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
69
- ![Version](https://img.shields.io/badge/Version-0.4.8-orange.svg)
69
+ ![Version](https://img.shields.io/badge/Version-0.4.10-orange.svg)
70
70
 
71
71
  中文文档 | [English Version](README_en.md)
72
72
 
@@ -99,11 +99,10 @@ NextRec是一个基于PyTorch的现代推荐系统框架,旨在为研究工程
99
99
 
100
100
  ## NextRec近期进展
101
101
 
102
- - **12/12/2025** 在v0.4.8中加入了[RQ-VAE](/nextrec/models/generative/rqvae.py)模块。配套的[数据集](/dataset/ecommerce_task.csv)和[代码](tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynby)已经同步在仓库中
102
+ - **12/12/2025** 在v0.4.10中加入了[RQ-VAE](/nextrec/models/representation/rqvae.py)模块。配套的[数据集](/dataset/ecommerce_task.csv)和[代码](tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynb)已经同步在仓库中
103
103
  - **07/12/2025** 发布了NextRec CLI命令行工具,它允许用户根据配置文件进行一键训练和推理,我们提供了相关的[教程](/nextrec_cli_preset/NextRec-CLI_zh.md)和[教学代码](/nextrec_cli_preset)
104
104
  - **03/12/2025** NextRec获得了100颗🌟!感谢大家的支持
105
105
  - **06/12/2025** 在v0.4.1中支持了单机多卡的分布式DDP训练,并且提供了配套的[代码](tutorials/distributed)
106
- - **23/11/2025** 在v0.2.2中对basemodel进行了逻辑上的大幅重构和流程统一,并且对listwise/pairwise/pointwise损失进行了统一
107
106
  - **11/11/2025** NextRec v0.1.0发布,我们提供了10余种Ranking模型,4种多任务模型和4种召回模型,以及统一的训练/日志/指标管理系统
108
107
 
109
108
  ## 架构
@@ -241,11 +240,11 @@ nextrec --mode=train --train_config=path/to/train_config.yaml
241
240
  nextrec --mode=predict --predict_config=path/to/predict_config.yaml
242
241
  ```
243
242
 
244
- > 截止当前版本0.4.8,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
243
+ > 截止当前版本0.4.10,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
245
244
 
246
245
  ## 兼容平台
247
246
 
248
- 当前最新版本为0.4.8,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
247
+ 当前最新版本为0.4.10,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
249
248
 
250
249
  | 平台 | 配置 |
251
250
  |------|------|
@@ -0,0 +1,70 @@
1
+ nextrec/__init__.py,sha256=_M3oUqyuvQ5k8Th_3wId6hQ_caclh7M5ad51XN09m98,235
2
+ nextrec/__version__.py,sha256=N_k8mdXQaZTz0YYxAgWi2g6nf_GP6B5r8Q49Om9EynA,23
3
+ nextrec/cli.py,sha256=PXRNXMRm_a_1u6StnjsHefq0rKqsc6Mzx3mZmc9553g,23803
4
+ nextrec/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
+ nextrec/basic/activation.py,sha256=uzTWfCOtBSkbu_Gk9XBNTj8__s241CaYLJk6l8nGX9I,2885
6
+ nextrec/basic/callback.py,sha256=a6gg7r3x1v0xaOSya9PteLql7I14nepY7gX8tDYtins,14679
7
+ nextrec/basic/features.py,sha256=Wnbzr7UMotgv1Vzeg0o9Po-KKIvYUSYIghoVDfMPx_g,4340
8
+ nextrec/basic/layers.py,sha256=GJH2Tx3IkZrYGb7-ET976iHCC28Ubck_NO9-iyY4mDI,28911
9
+ nextrec/basic/loggers.py,sha256=JnQiFvmsVgZ63gqBLR2ZFWrVPzkxRbzWhTdeoiJKcos,6526
10
+ nextrec/basic/metrics.py,sha256=8RswR_3MGvIBkT_n6fnmON2eYH-hfD7kIKVnyJJjL3o,23131
11
+ nextrec/basic/model.py,sha256=OCcV9nTAZukurRISzPGCQM5yJ0Fpph3vOMKb2CPkI68,98685
12
+ nextrec/basic/session.py,sha256=UOG_-EgCOxvqZwCkiEd8sgNV2G1sm_HbzKYVQw8yYDI,4483
13
+ nextrec/data/__init__.py,sha256=YZQjpty1pDCM7q_YNmiA2sa5kbujUw26ObLHWjMPjKY,1194
14
+ nextrec/data/batch_utils.py,sha256=0bYGVX7RlhnHv_ZBaUngjDIpBNw-igCk98DgOsF7T6o,2879
15
+ nextrec/data/data_processing.py,sha256=lKXDBszrO5fJMAQetgSPr2mSQuzOluuz1eHV4jp0TDU,5538
16
+ nextrec/data/data_utils.py,sha256=0Ls1cnG9lBz0ovtyedw5vwp7WegGK_iF-F8e_3DEddo,880
17
+ nextrec/data/dataloader.py,sha256=xTORNbaQVa20sk2S3kyV0SSngscvq8bNqHr0AmYjFqM,18768
18
+ nextrec/data/preprocessor.py,sha256=wNjivq2N-iDzBropkp3YfSkN0jSA4l4h81C-ECa6k4c,44643
19
+ nextrec/loss/__init__.py,sha256=-sibZK8QXLblVNWqdqjrPPzMCDyIXSq7yd2eZ57p9Nw,810
20
+ nextrec/loss/listwise.py,sha256=UT9vJCOTOQLogVwaeTV7Z5uxIYnngGdxk-p9e97MGkU,5744
21
+ nextrec/loss/loss_utils.py,sha256=Eg_EKm47onSCLhgs2q7IkB7TV9TwV1Dz4QgVR2yh-gc,4610
22
+ nextrec/loss/pairwise.py,sha256=X9yg-8pcPt2IWU0AiUhWAt3_4W_3wIF0uSdDYTdoPFY,3398
23
+ nextrec/loss/pointwise.py,sha256=o9J3OznY0hlbDsUXqn3k-BBzYiuUH5dopz8QBFqS_kQ,7343
24
+ nextrec/models/generative/__init__.py,sha256=QtfMCu0-R7jg_HgjrUTP8ShWzXjQnPCpQjqpGwbMzp4,176
25
+ nextrec/models/generative/hstu.py,sha256=P2Kl7HEL3afwiCApGKQ6UbUNO9eNXXrB10H7iiF8cI0,19735
26
+ nextrec/models/generative/tiger.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
27
+ nextrec/models/multi_task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
+ nextrec/models/multi_task/esmm.py,sha256=3QQePhSkOcM4t472S4E5xU9_KiLiSwHb9CfdkEgmqqk,6491
29
+ nextrec/models/multi_task/mmoe.py,sha256=uFTbc0MiFBDTCIt8mTW6xs0oyOn1EesIHHZo81HR35k,8583
30
+ nextrec/models/multi_task/ple.py,sha256=z32etizNlTLwwR7CYKxy8u9owAbtiRh492Fje_y64hQ,13016
31
+ nextrec/models/multi_task/poso.py,sha256=foH7XDUz0XN0s0zoyHLuTmrcs3QOT8-x4YGxLX1Lxxg,19016
32
+ nextrec/models/multi_task/share_bottom.py,sha256=rmEnsX3LA3pNsLKfG1ir5WDLdkSY-imO_ASiclirJiA,6519
33
+ nextrec/models/ranking/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
34
+ nextrec/models/ranking/afm.py,sha256=96jGUPL4yTWobMIVBjHpOxl9AtAzCAGR8yw7Sy2JmdQ,10125
35
+ nextrec/models/ranking/autoint.py,sha256=S6Cxnp1q2OErSYqmIix5P-b4qLWR-0dY6TMStuU6WLg,8109
36
+ nextrec/models/ranking/dcn.py,sha256=whkjiKEuadl6oSP-NJdSOCOqvWZGX4EsId9oqlfVpa8,7299
37
+ nextrec/models/ranking/dcn_v2.py,sha256=QnqQbJsrtQp4mtvnBXFUVefKyr4dw-gHNWrCbO26oHw,11163
38
+ nextrec/models/ranking/deepfm.py,sha256=aXoK59e2KaaPe5vfyFW4YiHbX4E2iG3gxFCxmWo8RHk,5200
39
+ nextrec/models/ranking/dien.py,sha256=c7Zs85vxhOgKHg5s0QcSLCn1xXCCSD177TMERgM_v8g,18958
40
+ nextrec/models/ranking/din.py,sha256=gdUhuKiKXBNOALbK8fGhlbSeuDT8agcEdNSrC_wveHc,9422
41
+ nextrec/models/ranking/eulernet.py,sha256=mZTrD8rKbGbWMEeWpTl8mVimytLFJTLM5-LS_I3U6cw,13115
42
+ nextrec/models/ranking/ffm.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
43
+ nextrec/models/ranking/fibinet.py,sha256=_eroddVHooJcaGT8MqS4mUrtv5j4pnTmfI3FoAKOZhs,7919
44
+ nextrec/models/ranking/fm.py,sha256=SsrSKK3y4xg5Lv-t3JLnZan55Hzze2AxAiVPuscy0bk,4536
45
+ nextrec/models/ranking/lr.py,sha256=Qf8RozgWlsKjHGVbo-94d2Z_4kMfCXHmvwkYu3WVZjQ,4030
46
+ nextrec/models/ranking/masknet.py,sha256=tY1y2lO0iq82oylPN0SBnL5Bikc8weinFXpURyVT1hE,12373
47
+ nextrec/models/ranking/pnn.py,sha256=FcNIFAw5J0ORGSR6L8ZK7NeXlJPpojwe_SpsxMQqCFw,8174
48
+ nextrec/models/ranking/widedeep.py,sha256=-ghKfe_0puvlI9fBQr8lK3gXkfVvslGwP40AJTGqc7w,5077
49
+ nextrec/models/ranking/xdeepfm.py,sha256=FMtl_zYO1Ty_2d9VWRsz6Jo-Xjw8vikpIQPZCDVavVY,8156
50
+ nextrec/models/representation/__init__.py,sha256=O3QHMMXBszwM-mTl7bA3wawNZvDGet-QIv6Ys5GHGJ8,190
51
+ nextrec/models/representation/rqvae.py,sha256=JyZxVY9CibcdBGk97TxjG5O3WQC10_60tHNcP_qtegs,29290
52
+ nextrec/models/retrieval/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
53
+ nextrec/models/retrieval/dssm.py,sha256=yMTu9msn6WgPqpM_zaJ0Z8R3uAYPzuJY_NS0xjJGSx0,8043
54
+ nextrec/models/retrieval/dssm_v2.py,sha256=IJzCT78Ra1DeSpR85KF1Awqnywu_6kDko3XDPbcup5s,7075
55
+ nextrec/models/retrieval/mind.py,sha256=ktNXwZxhnH1DV08c0gdsbodo4QJQL4UQGeyOqpmGwVM,15090
56
+ nextrec/models/retrieval/sdm.py,sha256=LhkCZSfGhxOxziEkUtjr_hnqcyciJ2qpMoBSFBVW9lQ,10558
57
+ nextrec/models/retrieval/youtube_dnn.py,sha256=xtGPV6_5LeSZBKkrTaU1CmtxlhgYLvZmjpwYaXYIaEA,7403
58
+ nextrec/utils/__init__.py,sha256=5ss2XQq8QZ2Ko5eiQ7oIig5cIZNrYGIptaarYEeO7Fk,2550
59
+ nextrec/utils/config.py,sha256=0HOeMyTlx8g6BZVpXzo2lEOkb-mzNwhbigQuUomsYnY,19934
60
+ nextrec/utils/console.py,sha256=D2Vax9_b7bgvAAOyk-Q2oUhSk1B-OngY5buS9Gb9-I0,11398
61
+ nextrec/utils/data.py,sha256=alruiWZFbmwy3kO12q42VXmtHmXFFjVULpHa43fx_mI,21098
62
+ nextrec/utils/embedding.py,sha256=akAEc062MG2cD7VIOllHaqtwzAirQR2gq5iW7oKpGAU,1449
63
+ nextrec/utils/feature.py,sha256=rsUAv3ELyDpehVw8nPEEsLCCIjuKGTJJZuFaWB_wrPk,633
64
+ nextrec/utils/model.py,sha256=dYl1XfIZt6aVjNyV2AAhcArwFRMcEAKrjG_pr8AVHs0,1163
65
+ nextrec/utils/torch_utils.py,sha256=AKfYbSOJjEw874xsDB5IO3Ote4X7vnqzt_E0jJny0o8,13468
66
+ nextrec-0.4.10.dist-info/METADATA,sha256=b7ILFNk7WRZCg_2ZCx7_SWdU_d3mzN2b5IWTCnB0mbg,19318
67
+ nextrec-0.4.10.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
68
+ nextrec-0.4.10.dist-info/entry_points.txt,sha256=NN-dNSdfMRTv86bNXM7d3ZEPW2BQC6bRi7QP7i9cIps,45
69
+ nextrec-0.4.10.dist-info/licenses/LICENSE,sha256=2fQfVKeafywkni7MYHyClC6RGGC3laLTXCNBx-ubtp0,1064
70
+ nextrec-0.4.10.dist-info/RECORD,,
@@ -1,58 +0,0 @@
1
- """
2
- CLI utilities for NextRec.
3
-
4
- This module provides small helpers used by command-line entrypoints, such as
5
- printing startup banners and resolving the installed package version.
6
- """
7
-
8
- from __future__ import annotations
9
-
10
- import logging
11
- import os
12
- import platform
13
- import sys
14
- from datetime import datetime
15
-
16
-
17
- def get_nextrec_version() -> str:
18
- """
19
- Best-effort version resolver for NextRec.
20
-
21
- Prefer in-repo `nextrec.__version__`, fall back to installed package metadata.
22
- """
23
- try:
24
- from nextrec import __version__ # type: ignore
25
-
26
- if __version__:
27
- return str(__version__)
28
- except Exception:
29
- pass
30
-
31
- try:
32
- from importlib.metadata import version
33
-
34
- return version("nextrec")
35
- except Exception:
36
- return "unknown"
37
-
38
-
39
- def log_startup_info(
40
- logger: logging.Logger, *, mode: str, config_path: str | None
41
- ) -> None:
42
- """Log a short, user-friendly startup banner."""
43
- version = get_nextrec_version()
44
- now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
45
-
46
- lines = [
47
- "NextRec CLI",
48
- f"- Version: {version}",
49
- f"- Time: {now}",
50
- f"- Mode: {mode}",
51
- f"- Config: {config_path or '(not set)'}",
52
- f"- Python: {platform.python_version()} ({sys.executable})",
53
- f"- Platform: {platform.system()} {platform.release()} ({platform.machine()})",
54
- f"- Workdir: {os.getcwd()}",
55
- f"- Command: {' '.join(sys.argv)}",
56
- ]
57
- for line in lines:
58
- logger.info(line)
nextrec/utils/device.py DELETED
@@ -1,78 +0,0 @@
1
- """
2
- Device management utilities for NextRec
3
-
4
- Date: create on 03/12/2025
5
- Checkpoint: edit on 06/12/2025
6
- Author: Yang Zhou, zyaztec@gmail.com
7
- """
8
-
9
- import torch
10
- import platform
11
- import logging
12
-
13
-
14
- def resolve_device() -> str:
15
- if torch.cuda.is_available():
16
- return "cuda"
17
- if torch.backends.mps.is_available():
18
- mac_ver = platform.mac_ver()[0]
19
- try:
20
- major, _ = (int(x) for x in mac_ver.split(".")[:2])
21
- except Exception:
22
- major, _ = 0, 0
23
- if major >= 14:
24
- return "mps"
25
- return "cpu"
26
-
27
-
28
- def get_device_info() -> dict:
29
- info = {
30
- "cuda_available": torch.cuda.is_available(),
31
- "cuda_device_count": (
32
- torch.cuda.device_count() if torch.cuda.is_available() else 0
33
- ),
34
- "mps_available": torch.backends.mps.is_available(),
35
- "current_device": resolve_device(),
36
- }
37
-
38
- if torch.cuda.is_available():
39
- info["cuda_device_name"] = torch.cuda.get_device_name(0)
40
- info["cuda_capability"] = torch.cuda.get_device_capability(0)
41
-
42
- return info
43
-
44
-
45
- def configure_device(
46
- distributed: bool, local_rank: int, base_device: torch.device | str = "cpu"
47
- ) -> torch.device:
48
- try:
49
- device = torch.device(base_device)
50
- except Exception:
51
- logging.warning(
52
- "[configure_device Warning] Invalid base_device, falling back to CPU."
53
- )
54
- return torch.device("cpu")
55
-
56
- if distributed:
57
- if device.type == "cuda":
58
- if not torch.cuda.is_available():
59
- logging.warning(
60
- "[Distributed Warning] CUDA requested but unavailable. Falling back to CPU."
61
- )
62
- return torch.device("cpu")
63
- if not (0 <= local_rank < torch.cuda.device_count()):
64
- logging.warning(
65
- f"[Distributed Warning] local_rank {local_rank} is invalid for available CUDA devices. Falling back to CPU."
66
- )
67
- return torch.device("cpu")
68
- try:
69
- torch.cuda.set_device(local_rank)
70
- return torch.device(f"cuda:{local_rank}")
71
- except Exception as exc:
72
- logging.warning(
73
- f"[Distributed Warning] Failed to set CUDA device for local_rank {local_rank}: {exc}. Falling back to CPU."
74
- )
75
- return torch.device("cpu")
76
- else:
77
- return torch.device("cpu")
78
- return device