nextrec 0.4.20__py3-none-any.whl → 0.4.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/activation.py +9 -4
  3. nextrec/basic/callback.py +39 -87
  4. nextrec/basic/features.py +149 -28
  5. nextrec/basic/heads.py +4 -1
  6. nextrec/basic/layers.py +375 -94
  7. nextrec/basic/loggers.py +236 -39
  8. nextrec/basic/model.py +209 -316
  9. nextrec/basic/session.py +2 -2
  10. nextrec/basic/summary.py +323 -0
  11. nextrec/cli.py +3 -3
  12. nextrec/data/data_processing.py +45 -1
  13. nextrec/data/dataloader.py +2 -2
  14. nextrec/data/preprocessor.py +2 -2
  15. nextrec/loss/loss_utils.py +5 -30
  16. nextrec/models/multi_task/esmm.py +4 -6
  17. nextrec/models/multi_task/mmoe.py +4 -6
  18. nextrec/models/multi_task/ple.py +6 -8
  19. nextrec/models/multi_task/poso.py +5 -7
  20. nextrec/models/multi_task/share_bottom.py +6 -8
  21. nextrec/models/ranking/afm.py +4 -6
  22. nextrec/models/ranking/autoint.py +4 -6
  23. nextrec/models/ranking/dcn.py +8 -7
  24. nextrec/models/ranking/dcn_v2.py +4 -6
  25. nextrec/models/ranking/deepfm.py +5 -7
  26. nextrec/models/ranking/dien.py +8 -7
  27. nextrec/models/ranking/din.py +8 -7
  28. nextrec/models/ranking/eulernet.py +5 -7
  29. nextrec/models/ranking/ffm.py +5 -7
  30. nextrec/models/ranking/fibinet.py +4 -6
  31. nextrec/models/ranking/fm.py +4 -6
  32. nextrec/models/ranking/lr.py +4 -6
  33. nextrec/models/ranking/masknet.py +8 -9
  34. nextrec/models/ranking/pnn.py +4 -6
  35. nextrec/models/ranking/widedeep.py +5 -7
  36. nextrec/models/ranking/xdeepfm.py +8 -7
  37. nextrec/models/retrieval/dssm.py +4 -10
  38. nextrec/models/retrieval/dssm_v2.py +0 -6
  39. nextrec/models/retrieval/mind.py +4 -10
  40. nextrec/models/retrieval/sdm.py +4 -10
  41. nextrec/models/retrieval/youtube_dnn.py +4 -10
  42. nextrec/models/sequential/hstu.py +1 -3
  43. nextrec/utils/__init__.py +12 -14
  44. nextrec/utils/config.py +15 -5
  45. nextrec/utils/console.py +2 -2
  46. nextrec/utils/feature.py +2 -2
  47. nextrec/utils/torch_utils.py +57 -112
  48. nextrec/utils/types.py +59 -0
  49. {nextrec-0.4.20.dist-info → nextrec-0.4.21.dist-info}/METADATA +7 -5
  50. nextrec-0.4.21.dist-info/RECORD +81 -0
  51. nextrec-0.4.20.dist-info/RECORD +0 -79
  52. {nextrec-0.4.20.dist-info → nextrec-0.4.21.dist-info}/WHEEL +0 -0
  53. {nextrec-0.4.20.dist-info → nextrec-0.4.21.dist-info}/entry_points.txt +0 -0
  54. {nextrec-0.4.20.dist-info → nextrec-0.4.21.dist-info}/licenses/LICENSE +0 -0
@@ -49,12 +49,10 @@ class SDM(BaseMatchModel):
49
49
  num_negative_samples: int = 4,
50
50
  temperature: float = 1.0,
51
51
  similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
52
- device: str = "cpu",
53
- embedding_l1_reg: float = 0.0,
54
- dense_l1_reg: float = 0.0,
55
- embedding_l2_reg: float = 0.0,
56
- dense_l2_reg: float = 0.0,
57
- early_stop_patience: int = 20,
52
+ embedding_l1_reg=0.0,
53
+ dense_l1_reg=0.0,
54
+ embedding_l2_reg=0.0,
55
+ dense_l2_reg=0.0,
58
56
  optimizer: str | torch.optim.Optimizer = "adam",
59
57
  optimizer_params: dict | None = None,
60
58
  scheduler: (
@@ -80,12 +78,10 @@ class SDM(BaseMatchModel):
80
78
  num_negative_samples=num_negative_samples,
81
79
  temperature=temperature,
82
80
  similarity_metric=similarity_metric,
83
- device=device,
84
81
  embedding_l1_reg=embedding_l1_reg,
85
82
  dense_l1_reg=dense_l1_reg,
86
83
  embedding_l2_reg=embedding_l2_reg,
87
84
  dense_l2_reg=dense_l2_reg,
88
- early_stop_patience=early_stop_patience,
89
85
  **kwargs,
90
86
  )
91
87
 
@@ -202,8 +198,6 @@ class SDM(BaseMatchModel):
202
198
  loss_params=loss_params,
203
199
  )
204
200
 
205
- self.to(device)
206
-
207
201
  def user_tower(self, user_input: dict) -> torch.Tensor:
208
202
  seq_feature = self.user_sequence_features[0]
209
203
  seq_input = user_input[seq_feature.name]
@@ -50,12 +50,10 @@ class YoutubeDNN(BaseMatchModel):
50
50
  num_negative_samples: int = 100,
51
51
  temperature: float = 1.0,
52
52
  similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
53
- device: str = "cpu",
54
- embedding_l1_reg: float = 0.0,
55
- dense_l1_reg: float = 0.0,
56
- embedding_l2_reg: float = 0.0,
57
- dense_l2_reg: float = 0.0,
58
- early_stop_patience: int = 20,
53
+ embedding_l1_reg=0.0,
54
+ dense_l1_reg=0.0,
55
+ embedding_l2_reg=0.0,
56
+ dense_l2_reg=0.0,
59
57
  optimizer: str | torch.optim.Optimizer = "adam",
60
58
  optimizer_params: dict | None = None,
61
59
  scheduler: (
@@ -81,12 +79,10 @@ class YoutubeDNN(BaseMatchModel):
81
79
  num_negative_samples=num_negative_samples,
82
80
  temperature=temperature,
83
81
  similarity_metric=similarity_metric,
84
- device=device,
85
82
  embedding_l1_reg=embedding_l1_reg,
86
83
  dense_l1_reg=dense_l1_reg,
87
84
  embedding_l2_reg=embedding_l2_reg,
88
85
  dense_l2_reg=dense_l2_reg,
89
- early_stop_patience=early_stop_patience,
90
86
  **kwargs,
91
87
  )
92
88
 
@@ -169,8 +165,6 @@ class YoutubeDNN(BaseMatchModel):
169
165
  loss_params=loss_params,
170
166
  )
171
167
 
172
- self.to(device)
173
-
174
168
  def user_tower(self, user_input: dict) -> torch.Tensor:
175
169
  """
176
170
  User tower to encode historical behavior sequences and user features.
@@ -332,7 +332,6 @@ class HSTU(BaseModel):
332
332
  dense_l1_reg: float = 0.0,
333
333
  embedding_l2_reg: float = 0.0,
334
334
  dense_l2_reg: float = 0.0,
335
- device: str = "cpu",
336
335
  **kwargs,
337
336
  ):
338
337
  raise NotImplementedError(
@@ -348,7 +347,7 @@ class HSTU(BaseModel):
348
347
  )[0]
349
348
 
350
349
  self.hidden_dim = hidden_dim or max(
351
- int(getattr(self.item_history_feature, "embedding_dim", 0) or 0), 32
350
+ int(self.item_history_feature.embedding_dim or 0), 32
352
351
  )
353
352
  # Make hidden_dim divisible by num_heads
354
353
  if self.hidden_dim % num_heads != 0:
@@ -368,7 +367,6 @@ class HSTU(BaseModel):
368
367
  sequence_features=sequence_features,
369
368
  target=target,
370
369
  task=task or self.default_task,
371
- device=device,
372
370
  embedding_l1_reg=embedding_l1_reg,
373
371
  dense_l1_reg=dense_l1_reg,
374
372
  embedding_l2_reg=embedding_l2_reg,
nextrec/utils/__init__.py CHANGED
@@ -14,6 +14,7 @@ from .config import (
14
14
  load_model_class,
15
15
  register_processor_features,
16
16
  resolve_path,
17
+ safe_value,
17
18
  select_features,
18
19
  )
19
20
  from .console import (
@@ -35,23 +36,19 @@ from .data import (
35
36
  resolve_file_paths,
36
37
  )
37
38
  from .embedding import get_auto_embedding_dim
38
- from .feature import normalize_to_list
39
+ from .feature import to_list
39
40
  from .model import compute_pair_scores, get_mlp_output_dim, merge_features
40
41
  from .torch_utils import (
41
42
  add_distributed_sampler,
42
- concat_tensors,
43
- configure_device,
43
+ get_device,
44
44
  gather_numpy,
45
- get_device_info,
46
45
  get_initializer,
47
46
  get_optimizer,
48
47
  get_scheduler,
49
48
  init_process_group,
50
- pad_sequence_tensors,
51
- resolve_device,
52
- stack_tensors,
53
49
  to_tensor,
54
50
  )
51
+ from .types import LossName, OptimizerName, SchedulerName, ActivationName
55
52
 
56
53
  __all__ = [
57
54
  # Console utilities
@@ -67,17 +64,12 @@ __all__ = [
67
64
  # Embedding utilities
68
65
  "get_auto_embedding_dim",
69
66
  # Device utilities (torch utils)
70
- "resolve_device",
71
- "get_device_info",
72
- "configure_device",
67
+ "get_device",
73
68
  "init_process_group",
74
69
  "gather_numpy",
75
70
  "add_distributed_sampler",
76
71
  # Tensor utilities
77
72
  "to_tensor",
78
- "stack_tensors",
79
- "concat_tensors",
80
- "pad_sequence_tensors",
81
73
  # Data utilities
82
74
  "resolve_file_paths",
83
75
  "read_table",
@@ -90,9 +82,10 @@ __all__ = [
90
82
  "get_mlp_output_dim",
91
83
  "compute_pair_scores",
92
84
  # Feature utilities
93
- "normalize_to_list",
85
+ "to_list",
94
86
  # Config utilities
95
87
  "resolve_path",
88
+ "safe_value",
96
89
  "register_processor_features",
97
90
  "build_feature_objects",
98
91
  "extract_feature_groups",
@@ -109,4 +102,9 @@ __all__ = [
109
102
  "data",
110
103
  "embedding",
111
104
  "torch_utils",
105
+ # Type aliases
106
+ "OptimizerName",
107
+ "SchedulerName",
108
+ "LossName",
109
+ "ActivationName",
112
110
  ]
nextrec/utils/config.py CHANGED
@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple
21
21
  import pandas as pd
22
22
  import torch
23
23
 
24
- from nextrec.utils.feature import normalize_to_list
24
+ from nextrec.utils.feature import to_list
25
25
 
26
26
  if TYPE_CHECKING:
27
27
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
@@ -52,6 +52,16 @@ def resolve_path(
52
52
  )
53
53
 
54
54
 
55
+ def safe_value(value: Any):
56
+ if isinstance(value, (str, int, float, bool)) or value is None:
57
+ return value
58
+ if isinstance(value, dict):
59
+ return {str(k): safe_value(v) for k, v in value.items()}
60
+ if isinstance(value, (list, tuple)):
61
+ return [safe_value(v) for v in value]
62
+ return str(value)
63
+
64
+
55
65
  def select_features(
56
66
  feature_cfg: Dict[str, Any], df_columns: List[str]
57
67
  ) -> Tuple[List[str], List[str], List[str]]:
@@ -152,9 +162,9 @@ def build_feature_objects(
152
162
  dense_features.append(
153
163
  DenseFeature(
154
164
  name=name,
155
- embedding_dim=embed_cfg.get("embedding_dim"),
165
+ proj_dim=embed_cfg.get("proj_dim"),
156
166
  input_dim=embed_cfg.get("input_dim", 1),
157
- use_embedding=embed_cfg.get("use_embedding", False),
167
+ use_projection=embed_cfg.get("use_projection", False),
158
168
  )
159
169
  )
160
170
 
@@ -239,7 +249,7 @@ def extract_feature_groups(
239
249
  collected: List[str] = []
240
250
 
241
251
  for group_name, names in feature_groups.items():
242
- name_list = normalize_to_list(names)
252
+ name_list = to_list(names)
243
253
  filtered = []
244
254
  missing_defined = [n for n in name_list if n not in defined]
245
255
  missing_cols = [n for n in name_list if n not in available_cols]
@@ -441,7 +451,7 @@ def build_model_instance(
441
451
  direct_features = binding.get("features") or binding.get("feature_names")
442
452
  if direct_features and (accepts(param_name) or accepts_var_kwargs):
443
453
  init_kwargs[param_name] = _select(
444
- normalize_to_list(direct_features),
454
+ to_list(direct_features),
445
455
  feature_pool,
446
456
  f"feature_bindings.{param_name}",
447
457
  )
nextrec/utils/console.py CHANGED
@@ -36,7 +36,7 @@ from rich.progress import (
36
36
  from rich.table import Table
37
37
  from rich.text import Text
38
38
 
39
- from nextrec.utils.feature import as_float, normalize_to_list
39
+ from nextrec.utils.feature import as_float, to_list
40
40
 
41
41
  T = TypeVar("T")
42
42
 
@@ -283,7 +283,7 @@ def display_metrics_table(
283
283
  if not is_main_process:
284
284
  return
285
285
 
286
- target_list = normalize_to_list(target_names)
286
+ target_list = to_list(target_names)
287
287
  task_order, grouped = group_metrics_by_task(metrics, target_names=target_names)
288
288
 
289
289
  if isinstance(loss, np.ndarray) and target_list:
nextrec/utils/feature.py CHANGED
@@ -2,7 +2,7 @@
2
2
  Feature processing utilities for NextRec
3
3
 
4
4
  Date: create on 03/12/2025
5
- Checkpoint: edit on 19/12/2025
5
+ Checkpoint: edit on 27/12/2025
6
6
  Author: Yang Zhou, zyaztec@gmail.com
7
7
  """
8
8
 
@@ -10,7 +10,7 @@ import numbers
10
10
  from typing import Any
11
11
 
12
12
 
13
- def normalize_to_list(value: str | list[str] | None) -> list[str]:
13
+ def to_list(value: str | list[str] | None) -> list[str]:
14
14
  if value is None:
15
15
  return []
16
16
  if isinstance(value, str):
@@ -3,12 +3,16 @@ PyTorch-related utilities for NextRec.
3
3
 
4
4
  This module groups device setup, distributed helpers, optimizers/schedulers,
5
5
  initialization, and tensor helpers.
6
+
7
+ Date: create on 27/10/2025
8
+ Checkpoint: edit on 27/12/2025
9
+ Author: Yang Zhou, zyaztec@gmail.com
6
10
  """
7
11
 
8
12
  from __future__ import annotations
9
13
 
10
14
  import logging
11
- from typing import Any, Dict, Iterable, Set
15
+ from typing import Any, Dict, Iterable, Literal
12
16
 
13
17
  import numpy as np
14
18
  import torch
@@ -18,26 +22,25 @@ from torch.utils.data import DataLoader, IterableDataset
18
22
  from torch.utils.data.distributed import DistributedSampler
19
23
 
20
24
  from nextrec.basic.loggers import colorize
21
-
22
- KNOWN_NONLINEARITIES: Set[str] = {
23
- "linear",
24
- "conv1d",
25
- "conv2d",
26
- "conv3d",
27
- "conv_transpose1d",
28
- "conv_transpose2d",
29
- "conv_transpose3d",
30
- "sigmoid",
31
- "tanh",
32
- "relu",
33
- "leaky_relu",
34
- "selu",
35
- "gelu",
36
- }
25
+ from nextrec.utils.types import OptimizerName, SchedulerName
37
26
 
38
27
 
39
28
  def resolve_nonlinearity(activation: str) -> str:
40
- if activation in KNOWN_NONLINEARITIES:
29
+ if activation in [
30
+ "linear",
31
+ "conv1d",
32
+ "conv2d",
33
+ "conv3d",
34
+ "conv_transpose1d",
35
+ "conv_transpose2d",
36
+ "conv_transpose3d",
37
+ "sigmoid",
38
+ "tanh",
39
+ "relu",
40
+ "leaky_relu",
41
+ "selu",
42
+ "gelu",
43
+ ]:
41
44
  return activation
42
45
  return "linear"
43
46
 
@@ -53,8 +56,30 @@ def resolve_gain(activation: str, param: Dict[str, Any]) -> float:
53
56
 
54
57
 
55
58
  def get_initializer(
56
- init_type: str = "normal",
57
- activation: str = "linear",
59
+ init_type: Literal[
60
+ "xavier_uniform",
61
+ "xavier_normal",
62
+ "kaiming_uniform",
63
+ "kaiming_normal",
64
+ "orthogonal",
65
+ "normal",
66
+ "uniform",
67
+ ] = "normal",
68
+ activation: Literal[
69
+ "linear",
70
+ "conv1d",
71
+ "conv2d",
72
+ "conv3d",
73
+ "conv_transpose1d",
74
+ "conv_transpose2d",
75
+ "conv_transpose3d",
76
+ "sigmoid",
77
+ "tanh",
78
+ "relu",
79
+ "leaky_relu",
80
+ "selu",
81
+ "gelu",
82
+ ] = "linear",
58
83
  param: Dict[str, Any] | None = None,
59
84
  ):
60
85
  param = param or {}
@@ -89,47 +114,14 @@ def get_initializer(
89
114
  return initializer_fn
90
115
 
91
116
 
92
- def resolve_device() -> str:
93
- if torch.cuda.is_available():
94
- return "cuda"
95
- if torch.backends.mps.is_available():
96
- import platform
97
-
98
- mac_ver = platform.mac_ver()[0]
99
- try:
100
- major, _ = (int(x) for x in mac_ver.split(".")[:2])
101
- except Exception:
102
- major, _ = 0, 0
103
- if major >= 14:
104
- return "mps"
105
- return "cpu"
106
-
107
-
108
- def get_device_info() -> dict:
109
- info = {
110
- "cuda_available": torch.cuda.is_available(),
111
- "cuda_device_count": (
112
- torch.cuda.device_count() if torch.cuda.is_available() else 0
113
- ),
114
- "mps_available": torch.backends.mps.is_available(),
115
- "current_device": resolve_device(),
116
- }
117
-
118
- if torch.cuda.is_available():
119
- info["cuda_device_name"] = torch.cuda.get_device_name(0)
120
- info["cuda_capability"] = torch.cuda.get_device_capability(0)
121
-
122
- return info
123
-
124
-
125
- def configure_device(
117
+ def get_device(
126
118
  distributed: bool, local_rank: int, base_device: torch.device | str = "cpu"
127
119
  ) -> torch.device:
128
120
  try:
129
121
  device = torch.device(base_device)
130
122
  except Exception:
131
123
  logging.warning(
132
- "[configure_device Warning] Invalid base_device, falling back to CPU."
124
+ "[get_device Warning] Invalid base_device, falling back to CPU."
133
125
  )
134
126
  return torch.device("cpu")
135
127
 
@@ -158,7 +150,7 @@ def configure_device(
158
150
 
159
151
 
160
152
  def get_optimizer(
161
- optimizer: str | torch.optim.Optimizer = "adam",
153
+ optimizer: OptimizerName | torch.optim.Optimizer = "adam",
162
154
  params: Iterable[torch.nn.Parameter] | None = None,
163
155
  **optimizer_params,
164
156
  ):
@@ -191,7 +183,7 @@ def get_optimizer(
191
183
 
192
184
  def get_scheduler(
193
185
  scheduler: (
194
- str
186
+ SchedulerName
195
187
  | torch.optim.lr_scheduler._LRScheduler
196
188
  | torch.optim.lr_scheduler.LRScheduler
197
189
  | type[torch.optim.lr_scheduler._LRScheduler]
@@ -241,51 +233,6 @@ def to_tensor(
241
233
  return tensor
242
234
 
243
235
 
244
- def stack_tensors(tensors: list[torch.Tensor], dim: int = 0) -> torch.Tensor:
245
- if not tensors:
246
- raise ValueError("[Tensor Utils Error] Cannot stack empty list of tensors.")
247
- return torch.stack(tensors, dim=dim)
248
-
249
-
250
- def concat_tensors(tensors: list[torch.Tensor], dim: int = 0) -> torch.Tensor:
251
- if not tensors:
252
- raise ValueError(
253
- "[Tensor Utils Error] Cannot concatenate empty list of tensors."
254
- )
255
- return torch.cat(tensors, dim=dim)
256
-
257
-
258
- def pad_sequence_tensors(
259
- tensors: list[torch.Tensor],
260
- max_len: int | None = None,
261
- padding_value: float = 0.0,
262
- padding_side: str = "right",
263
- ) -> torch.Tensor:
264
- if not tensors:
265
- raise ValueError("[Tensor Utils Error] Cannot pad empty list of tensors.")
266
- if max_len is None:
267
- max_len = max(t.size(0) for t in tensors)
268
- batch_size = len(tensors)
269
- padded = torch.full(
270
- (batch_size, max_len),
271
- padding_value,
272
- dtype=tensors[0].dtype,
273
- device=tensors[0].device,
274
- )
275
-
276
- for i, tensor in enumerate(tensors):
277
- length = min(tensor.size(0), max_len)
278
- if padding_side == "right":
279
- padded[i, :length] = tensor[:length]
280
- elif padding_side == "left":
281
- padded[i, -length:] = tensor[:length]
282
- else:
283
- raise ValueError(
284
- f"[Tensor Utils Error] padding_side must be 'right' or 'left', got {padding_side}"
285
- )
286
- return padded
287
-
288
-
289
236
  def init_process_group(
290
237
  distributed: bool, rank: int, world_size: int, device_id: int | None = None
291
238
  ) -> None:
@@ -350,7 +297,7 @@ def add_distributed_sampler(
350
297
  # return if already has DistributedSampler
351
298
  if isinstance(loader.sampler, DistributedSampler):
352
299
  return loader, loader.sampler
353
- dataset = getattr(loader, "dataset", None)
300
+ dataset = loader.dataset
354
301
  if dataset is None:
355
302
  return loader, None
356
303
  if isinstance(dataset, IterableDataset):
@@ -379,25 +326,23 @@ def add_distributed_sampler(
379
326
  "collate_fn": loader.collate_fn,
380
327
  "drop_last": drop_last,
381
328
  }
382
- if getattr(loader, "pin_memory", False):
329
+ if loader.pin_memory:
383
330
  loader_kwargs["pin_memory"] = True
384
- pin_memory_device = getattr(loader, "pin_memory_device", None)
331
+ pin_memory_device = loader.pin_memory_device
385
332
  if pin_memory_device:
386
333
  loader_kwargs["pin_memory_device"] = pin_memory_device
387
- timeout = getattr(loader, "timeout", None)
334
+ timeout = loader.timeout
388
335
  if timeout:
389
336
  loader_kwargs["timeout"] = timeout
390
- worker_init_fn = getattr(loader, "worker_init_fn", None)
337
+ worker_init_fn = loader.worker_init_fn
391
338
  if worker_init_fn is not None:
392
339
  loader_kwargs["worker_init_fn"] = worker_init_fn
393
- generator = getattr(loader, "generator", None)
340
+ generator = loader.generator
394
341
  if generator is not None:
395
342
  loader_kwargs["generator"] = generator
396
343
  if loader.num_workers > 0:
397
- loader_kwargs["persistent_workers"] = getattr(
398
- loader, "persistent_workers", False
399
- )
400
- prefetch_factor = getattr(loader, "prefetch_factor", None)
344
+ loader_kwargs["persistent_workers"] = loader.persistent_workers
345
+ prefetch_factor = loader.prefetch_factor
401
346
  if prefetch_factor is not None:
402
347
  loader_kwargs["prefetch_factor"] = prefetch_factor
403
348
  distributed_loader = DataLoader(dataset, **loader_kwargs)
nextrec/utils/types.py ADDED
@@ -0,0 +1,59 @@
1
+ """
2
+ Shared type aliases for NextRec.
3
+
4
+ Keep Literal-based public string options centralized to avoid drift.
5
+ """
6
+
7
+ from typing import Literal
8
+
9
+ OptimizerName = Literal["adam", "sgd", "adamw", "adagrad", "rmsprop"]
10
+
11
+ SchedulerName = Literal["step", "cosine"]
12
+
13
+ LossName = Literal[
14
+ "bce",
15
+ "binary_crossentropy",
16
+ "weighted_bce",
17
+ "focal",
18
+ "focal_loss",
19
+ "cb_focal",
20
+ "class_balanced_focal",
21
+ "crossentropy",
22
+ "ce",
23
+ "mse",
24
+ "mae",
25
+ "bpr",
26
+ "hinge",
27
+ "triplet",
28
+ "sampled_softmax",
29
+ "softmax",
30
+ "infonce",
31
+ "listnet",
32
+ "listmle",
33
+ "approx_ndcg",
34
+ ]
35
+
36
+ ActivationName = Literal[
37
+ "dice",
38
+ "relu",
39
+ "relu6",
40
+ "elu",
41
+ "selu",
42
+ "leaky_relu",
43
+ "prelu",
44
+ "gelu",
45
+ "sigmoid",
46
+ "tanh",
47
+ "softplus",
48
+ "softsign",
49
+ "hardswish",
50
+ "mish",
51
+ "silu",
52
+ "swish",
53
+ "hardsigmoid",
54
+ "tanhshrink",
55
+ "softshrink",
56
+ "none",
57
+ "linear",
58
+ "identity",
59
+ ]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nextrec
3
- Version: 0.4.20
3
+ Version: 0.4.21
4
4
  Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
5
5
  Project-URL: Homepage, https://github.com/zerolovesea/NextRec
6
6
  Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -42,9 +42,11 @@ Requires-Dist: scipy<1.12,>=1.8; sys_platform == 'linux' and python_version < '3
42
42
  Requires-Dist: scipy>=1.10.0; sys_platform == 'darwin'
43
43
  Requires-Dist: scipy>=1.10.0; sys_platform == 'win32'
44
44
  Requires-Dist: scipy>=1.11.0; sys_platform == 'linux' and python_version >= '3.12'
45
+ Requires-Dist: swanlab>=0.7.2
45
46
  Requires-Dist: torch>=2.0.0
46
47
  Requires-Dist: torchvision>=0.15.0
47
48
  Requires-Dist: transformers>=4.38.0
49
+ Requires-Dist: wandb>=0.23.1
48
50
  Provides-Extra: dev
49
51
  Requires-Dist: jupyter>=1.0.0; extra == 'dev'
50
52
  Requires-Dist: matplotlib>=3.7.0; extra == 'dev'
@@ -67,7 +69,7 @@ Description-Content-Type: text/markdown
67
69
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
68
70
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
69
71
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
70
- ![Version](https://img.shields.io/badge/Version-0.4.20-orange.svg)
72
+ ![Version](https://img.shields.io/badge/Version-0.4.21-orange.svg)
71
73
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/zerolovesea/NextRec)
72
74
 
73
75
  中文文档 | [English Version](README_en.md)
@@ -100,7 +102,7 @@ NextRec是一个基于PyTorch的现代推荐系统框架,旨在为研究工程
100
102
  - **高效训练与评估**:内置多种优化器、学习率调度、早停、模型检查点与详细的日志管理,开箱即用。
101
103
 
102
104
  ## NextRec近期进展
103
-
105
+ - **28/12/2025** 在v0.4.21中加入了对SwanLab和Wandb的支持,通过model的`fit`方法进行配置:`use_swanlab=True, swanlab_kwargs={"project": "NextRec", "name": "tutorial_movielens_deepfm"}`
104
106
  - **21/12/2025** 在v0.4.16中加入了对[GradNorm](/nextrec/loss/grad_norm.py)的支持,通过compile的`loss_weight='grad_norm'`进行配置
105
107
  - **12/12/2025** 在v0.4.9中加入了[RQ-VAE](/nextrec/models/representation/rqvae.py)模块。配套的[数据集](/dataset/ecommerce_task.csv)和[代码](tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynb)已经同步在仓库中
106
108
  - **07/12/2025** 发布了NextRec CLI命令行工具,它允许用户根据配置文件进行一键训练和推理,我们提供了相关的[教程](/nextrec_cli_preset/NextRec-CLI_zh.md)和[教学代码](/nextrec_cli_preset)
@@ -245,11 +247,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
245
247
 
246
248
  预测结果固定保存到 `{checkpoint_path}/predictions/{name}.{save_data_format}`。
247
249
 
248
- > 截止当前版本0.4.20,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
250
+ > 截止当前版本0.4.21,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
249
251
 
250
252
  ## 兼容平台
251
253
 
252
- 当前最新版本为0.4.20,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
254
+ 当前最新版本为0.4.21,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
253
255
 
254
256
  | 平台 | 配置 |
255
257
  |------|------|