nextrec 0.4.20__py3-none-any.whl → 0.4.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/activation.py +9 -4
  3. nextrec/basic/callback.py +39 -87
  4. nextrec/basic/features.py +149 -28
  5. nextrec/basic/heads.py +3 -1
  6. nextrec/basic/layers.py +375 -94
  7. nextrec/basic/loggers.py +236 -39
  8. nextrec/basic/model.py +259 -326
  9. nextrec/basic/session.py +2 -2
  10. nextrec/basic/summary.py +323 -0
  11. nextrec/cli.py +3 -3
  12. nextrec/data/data_processing.py +45 -1
  13. nextrec/data/dataloader.py +2 -2
  14. nextrec/data/preprocessor.py +2 -2
  15. nextrec/loss/__init__.py +0 -4
  16. nextrec/loss/grad_norm.py +3 -3
  17. nextrec/models/multi_task/esmm.py +4 -6
  18. nextrec/models/multi_task/mmoe.py +4 -6
  19. nextrec/models/multi_task/ple.py +6 -8
  20. nextrec/models/multi_task/poso.py +5 -7
  21. nextrec/models/multi_task/share_bottom.py +6 -8
  22. nextrec/models/ranking/afm.py +4 -6
  23. nextrec/models/ranking/autoint.py +4 -6
  24. nextrec/models/ranking/dcn.py +8 -7
  25. nextrec/models/ranking/dcn_v2.py +4 -6
  26. nextrec/models/ranking/deepfm.py +5 -7
  27. nextrec/models/ranking/dien.py +8 -7
  28. nextrec/models/ranking/din.py +8 -7
  29. nextrec/models/ranking/eulernet.py +5 -7
  30. nextrec/models/ranking/ffm.py +5 -7
  31. nextrec/models/ranking/fibinet.py +4 -6
  32. nextrec/models/ranking/fm.py +4 -6
  33. nextrec/models/ranking/lr.py +4 -6
  34. nextrec/models/ranking/masknet.py +8 -9
  35. nextrec/models/ranking/pnn.py +4 -6
  36. nextrec/models/ranking/widedeep.py +5 -7
  37. nextrec/models/ranking/xdeepfm.py +8 -7
  38. nextrec/models/retrieval/dssm.py +4 -10
  39. nextrec/models/retrieval/dssm_v2.py +0 -6
  40. nextrec/models/retrieval/mind.py +4 -10
  41. nextrec/models/retrieval/sdm.py +4 -10
  42. nextrec/models/retrieval/youtube_dnn.py +4 -10
  43. nextrec/models/sequential/hstu.py +1 -3
  44. nextrec/utils/__init__.py +17 -15
  45. nextrec/utils/config.py +15 -5
  46. nextrec/utils/console.py +2 -2
  47. nextrec/utils/feature.py +2 -2
  48. nextrec/{loss/loss_utils.py → utils/loss.py} +21 -36
  49. nextrec/utils/torch_utils.py +57 -112
  50. nextrec/utils/types.py +63 -0
  51. {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/METADATA +8 -6
  52. nextrec-0.4.22.dist-info/RECORD +81 -0
  53. nextrec-0.4.20.dist-info/RECORD +0 -79
  54. {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/WHEEL +0 -0
  55. {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/entry_points.txt +0 -0
  56. {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/licenses/LICENSE +0 -0
@@ -50,12 +50,10 @@ class YoutubeDNN(BaseMatchModel):
50
50
  num_negative_samples: int = 100,
51
51
  temperature: float = 1.0,
52
52
  similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
53
- device: str = "cpu",
54
- embedding_l1_reg: float = 0.0,
55
- dense_l1_reg: float = 0.0,
56
- embedding_l2_reg: float = 0.0,
57
- dense_l2_reg: float = 0.0,
58
- early_stop_patience: int = 20,
53
+ embedding_l1_reg=0.0,
54
+ dense_l1_reg=0.0,
55
+ embedding_l2_reg=0.0,
56
+ dense_l2_reg=0.0,
59
57
  optimizer: str | torch.optim.Optimizer = "adam",
60
58
  optimizer_params: dict | None = None,
61
59
  scheduler: (
@@ -81,12 +79,10 @@ class YoutubeDNN(BaseMatchModel):
81
79
  num_negative_samples=num_negative_samples,
82
80
  temperature=temperature,
83
81
  similarity_metric=similarity_metric,
84
- device=device,
85
82
  embedding_l1_reg=embedding_l1_reg,
86
83
  dense_l1_reg=dense_l1_reg,
87
84
  embedding_l2_reg=embedding_l2_reg,
88
85
  dense_l2_reg=dense_l2_reg,
89
- early_stop_patience=early_stop_patience,
90
86
  **kwargs,
91
87
  )
92
88
 
@@ -169,8 +165,6 @@ class YoutubeDNN(BaseMatchModel):
169
165
  loss_params=loss_params,
170
166
  )
171
167
 
172
- self.to(device)
173
-
174
168
  def user_tower(self, user_input: dict) -> torch.Tensor:
175
169
  """
176
170
  User tower to encode historical behavior sequences and user features.
@@ -332,7 +332,6 @@ class HSTU(BaseModel):
332
332
  dense_l1_reg: float = 0.0,
333
333
  embedding_l2_reg: float = 0.0,
334
334
  dense_l2_reg: float = 0.0,
335
- device: str = "cpu",
336
335
  **kwargs,
337
336
  ):
338
337
  raise NotImplementedError(
@@ -348,7 +347,7 @@ class HSTU(BaseModel):
348
347
  )[0]
349
348
 
350
349
  self.hidden_dim = hidden_dim or max(
351
- int(getattr(self.item_history_feature, "embedding_dim", 0) or 0), 32
350
+ int(self.item_history_feature.embedding_dim or 0), 32
352
351
  )
353
352
  # Make hidden_dim divisible by num_heads
354
353
  if self.hidden_dim % num_heads != 0:
@@ -368,7 +367,6 @@ class HSTU(BaseModel):
368
367
  sequence_features=sequence_features,
369
368
  target=target,
370
369
  task=task or self.default_task,
371
- device=device,
372
370
  embedding_l1_reg=embedding_l1_reg,
373
371
  dense_l1_reg=dense_l1_reg,
374
372
  embedding_l2_reg=embedding_l2_reg,
nextrec/utils/__init__.py CHANGED
@@ -6,7 +6,7 @@ Last update: 19/12/2025
6
6
  Author: Yang Zhou, zyaztec@gmail.com
7
7
  """
8
8
 
9
- from . import console, data, embedding, torch_utils
9
+ from . import console, data, embedding, loss, torch_utils
10
10
  from .config import (
11
11
  build_feature_objects,
12
12
  build_model_instance,
@@ -14,6 +14,7 @@ from .config import (
14
14
  load_model_class,
15
15
  register_processor_features,
16
16
  resolve_path,
17
+ safe_value,
17
18
  select_features,
18
19
  )
19
20
  from .console import (
@@ -35,23 +36,20 @@ from .data import (
35
36
  resolve_file_paths,
36
37
  )
37
38
  from .embedding import get_auto_embedding_dim
38
- from .feature import normalize_to_list
39
+ from .feature import to_list
39
40
  from .model import compute_pair_scores, get_mlp_output_dim, merge_features
41
+ from .loss import normalize_task_loss
40
42
  from .torch_utils import (
41
43
  add_distributed_sampler,
42
- concat_tensors,
43
- configure_device,
44
+ get_device,
44
45
  gather_numpy,
45
- get_device_info,
46
46
  get_initializer,
47
47
  get_optimizer,
48
48
  get_scheduler,
49
49
  init_process_group,
50
- pad_sequence_tensors,
51
- resolve_device,
52
- stack_tensors,
53
50
  to_tensor,
54
51
  )
52
+ from .types import LossName, OptimizerName, SchedulerName, ActivationName
55
53
 
56
54
  __all__ = [
57
55
  # Console utilities
@@ -67,17 +65,12 @@ __all__ = [
67
65
  # Embedding utilities
68
66
  "get_auto_embedding_dim",
69
67
  # Device utilities (torch utils)
70
- "resolve_device",
71
- "get_device_info",
72
- "configure_device",
68
+ "get_device",
73
69
  "init_process_group",
74
70
  "gather_numpy",
75
71
  "add_distributed_sampler",
76
72
  # Tensor utilities
77
73
  "to_tensor",
78
- "stack_tensors",
79
- "concat_tensors",
80
- "pad_sequence_tensors",
81
74
  # Data utilities
82
75
  "resolve_file_paths",
83
76
  "read_table",
@@ -89,10 +82,13 @@ __all__ = [
89
82
  "merge_features",
90
83
  "get_mlp_output_dim",
91
84
  "compute_pair_scores",
85
+ # Loss utilities
86
+ "normalize_task_loss",
92
87
  # Feature utilities
93
- "normalize_to_list",
88
+ "to_list",
94
89
  # Config utilities
95
90
  "resolve_path",
91
+ "safe_value",
96
92
  "register_processor_features",
97
93
  "build_feature_objects",
98
94
  "extract_feature_groups",
@@ -108,5 +104,11 @@ __all__ = [
108
104
  "console",
109
105
  "data",
110
106
  "embedding",
107
+ "loss",
111
108
  "torch_utils",
109
+ # Type aliases
110
+ "OptimizerName",
111
+ "SchedulerName",
112
+ "LossName",
113
+ "ActivationName",
112
114
  ]
nextrec/utils/config.py CHANGED
@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple
21
21
  import pandas as pd
22
22
  import torch
23
23
 
24
- from nextrec.utils.feature import normalize_to_list
24
+ from nextrec.utils.feature import to_list
25
25
 
26
26
  if TYPE_CHECKING:
27
27
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
@@ -52,6 +52,16 @@ def resolve_path(
52
52
  )
53
53
 
54
54
 
55
+ def safe_value(value: Any):
56
+ if isinstance(value, (str, int, float, bool)) or value is None:
57
+ return value
58
+ if isinstance(value, dict):
59
+ return {str(k): safe_value(v) for k, v in value.items()}
60
+ if isinstance(value, (list, tuple)):
61
+ return [safe_value(v) for v in value]
62
+ return str(value)
63
+
64
+
55
65
  def select_features(
56
66
  feature_cfg: Dict[str, Any], df_columns: List[str]
57
67
  ) -> Tuple[List[str], List[str], List[str]]:
@@ -152,9 +162,9 @@ def build_feature_objects(
152
162
  dense_features.append(
153
163
  DenseFeature(
154
164
  name=name,
155
- embedding_dim=embed_cfg.get("embedding_dim"),
165
+ proj_dim=embed_cfg.get("proj_dim"),
156
166
  input_dim=embed_cfg.get("input_dim", 1),
157
- use_embedding=embed_cfg.get("use_embedding", False),
167
+ use_projection=embed_cfg.get("use_projection", False),
158
168
  )
159
169
  )
160
170
 
@@ -239,7 +249,7 @@ def extract_feature_groups(
239
249
  collected: List[str] = []
240
250
 
241
251
  for group_name, names in feature_groups.items():
242
- name_list = normalize_to_list(names)
252
+ name_list = to_list(names)
243
253
  filtered = []
244
254
  missing_defined = [n for n in name_list if n not in defined]
245
255
  missing_cols = [n for n in name_list if n not in available_cols]
@@ -441,7 +451,7 @@ def build_model_instance(
441
451
  direct_features = binding.get("features") or binding.get("feature_names")
442
452
  if direct_features and (accepts(param_name) or accepts_var_kwargs):
443
453
  init_kwargs[param_name] = _select(
444
- normalize_to_list(direct_features),
454
+ to_list(direct_features),
445
455
  feature_pool,
446
456
  f"feature_bindings.{param_name}",
447
457
  )
nextrec/utils/console.py CHANGED
@@ -36,7 +36,7 @@ from rich.progress import (
36
36
  from rich.table import Table
37
37
  from rich.text import Text
38
38
 
39
- from nextrec.utils.feature import as_float, normalize_to_list
39
+ from nextrec.utils.feature import as_float, to_list
40
40
 
41
41
  T = TypeVar("T")
42
42
 
@@ -283,7 +283,7 @@ def display_metrics_table(
283
283
  if not is_main_process:
284
284
  return
285
285
 
286
- target_list = normalize_to_list(target_names)
286
+ target_list = to_list(target_names)
287
287
  task_order, grouped = group_metrics_by_task(metrics, target_names=target_names)
288
288
 
289
289
  if isinstance(loss, np.ndarray) and target_list:
nextrec/utils/feature.py CHANGED
@@ -2,7 +2,7 @@
2
2
  Feature processing utilities for NextRec
3
3
 
4
4
  Date: create on 03/12/2025
5
- Checkpoint: edit on 19/12/2025
5
+ Checkpoint: edit on 27/12/2025
6
6
  Author: Yang Zhou, zyaztec@gmail.com
7
7
  """
8
8
 
@@ -10,7 +10,7 @@ import numbers
10
10
  from typing import Any
11
11
 
12
12
 
13
- def normalize_to_list(value: str | list[str] | None) -> list[str]:
13
+ def to_list(value: str | list[str] | None) -> list[str]:
14
14
  if value is None:
15
15
  return []
16
16
  if isinstance(value, str):
@@ -1,13 +1,13 @@
1
1
  """
2
2
  Loss utilities for NextRec.
3
3
 
4
- Date: create on 27/10/2025
5
- Checkpoint: edit on 19/12/2025
4
+ Date: create on 28/12/2025
6
5
  Author: Yang Zhou, zyaztec@gmail.com
7
6
  """
8
7
 
9
- from typing import Literal
8
+ from __future__ import annotations
10
9
 
10
+ import torch
11
11
  import torch.nn as nn
12
12
 
13
13
  from nextrec.loss.listwise import (
@@ -19,39 +19,21 @@ from nextrec.loss.listwise import (
19
19
  )
20
20
  from nextrec.loss.pairwise import BPRLoss, HingeLoss, TripletLoss
21
21
  from nextrec.loss.pointwise import ClassBalancedFocalLoss, FocalLoss, WeightedBCELoss
22
+ from nextrec.utils.types import LossName
22
23
 
23
- VALID_TASK_TYPES = [
24
- "binary",
25
- "multilabel",
26
- "regression",
27
- ]
28
-
29
- # Define all supported loss types
30
- LossType = Literal[
31
- # Pointwise losses
32
- "bce",
33
- "binary_crossentropy",
34
- "weighted_bce",
35
- "focal",
36
- "focal_loss",
37
- "cb_focal",
38
- "class_balanced_focal",
39
- "crossentropy",
40
- "ce",
41
- "mse",
42
- "mae",
43
- # Pairwise ranking losses
44
- "bpr",
45
- "hinge",
46
- "triplet",
47
- # Listwise ranking losses
48
- "sampled_softmax",
49
- "softmax",
50
- "infonce",
51
- "listnet",
52
- "listmle",
53
- "approx_ndcg",
54
- ]
24
+
25
+ def normalize_task_loss(
26
+ task_loss,
27
+ valid_count,
28
+ total_count,
29
+ eps=1e-8,
30
+ ) -> torch.Tensor:
31
+ if not torch.is_tensor(valid_count):
32
+ valid_count = torch.tensor(float(valid_count), device=task_loss.device)
33
+ if not torch.is_tensor(total_count):
34
+ total_count = torch.tensor(float(total_count), device=task_loss.device)
35
+ scale = valid_count.to(task_loss.dtype) / (total_count.to(task_loss.dtype) + eps)
36
+ return task_loss * scale
55
37
 
56
38
 
57
39
  def build_cb_focal(kw):
@@ -60,7 +42,10 @@ def build_cb_focal(kw):
60
42
  return ClassBalancedFocalLoss(**kw)
61
43
 
62
44
 
63
- def get_loss_fn(loss=None, **kw) -> nn.Module:
45
+ def get_loss_fn(
46
+ loss: LossName | None | nn.Module = None,
47
+ **kw,
48
+ ) -> nn.Module:
64
49
  """
65
50
  Get loss function by name or return the provided loss module.
66
51
 
@@ -3,12 +3,16 @@ PyTorch-related utilities for NextRec.
3
3
 
4
4
  This module groups device setup, distributed helpers, optimizers/schedulers,
5
5
  initialization, and tensor helpers.
6
+
7
+ Date: create on 27/10/2025
8
+ Checkpoint: edit on 27/12/2025
9
+ Author: Yang Zhou, zyaztec@gmail.com
6
10
  """
7
11
 
8
12
  from __future__ import annotations
9
13
 
10
14
  import logging
11
- from typing import Any, Dict, Iterable, Set
15
+ from typing import Any, Dict, Iterable, Literal
12
16
 
13
17
  import numpy as np
14
18
  import torch
@@ -18,26 +22,25 @@ from torch.utils.data import DataLoader, IterableDataset
18
22
  from torch.utils.data.distributed import DistributedSampler
19
23
 
20
24
  from nextrec.basic.loggers import colorize
21
-
22
- KNOWN_NONLINEARITIES: Set[str] = {
23
- "linear",
24
- "conv1d",
25
- "conv2d",
26
- "conv3d",
27
- "conv_transpose1d",
28
- "conv_transpose2d",
29
- "conv_transpose3d",
30
- "sigmoid",
31
- "tanh",
32
- "relu",
33
- "leaky_relu",
34
- "selu",
35
- "gelu",
36
- }
25
+ from nextrec.utils.types import OptimizerName, SchedulerName
37
26
 
38
27
 
39
28
  def resolve_nonlinearity(activation: str) -> str:
40
- if activation in KNOWN_NONLINEARITIES:
29
+ if activation in [
30
+ "linear",
31
+ "conv1d",
32
+ "conv2d",
33
+ "conv3d",
34
+ "conv_transpose1d",
35
+ "conv_transpose2d",
36
+ "conv_transpose3d",
37
+ "sigmoid",
38
+ "tanh",
39
+ "relu",
40
+ "leaky_relu",
41
+ "selu",
42
+ "gelu",
43
+ ]:
41
44
  return activation
42
45
  return "linear"
43
46
 
@@ -53,8 +56,30 @@ def resolve_gain(activation: str, param: Dict[str, Any]) -> float:
53
56
 
54
57
 
55
58
  def get_initializer(
56
- init_type: str = "normal",
57
- activation: str = "linear",
59
+ init_type: Literal[
60
+ "xavier_uniform",
61
+ "xavier_normal",
62
+ "kaiming_uniform",
63
+ "kaiming_normal",
64
+ "orthogonal",
65
+ "normal",
66
+ "uniform",
67
+ ] = "normal",
68
+ activation: Literal[
69
+ "linear",
70
+ "conv1d",
71
+ "conv2d",
72
+ "conv3d",
73
+ "conv_transpose1d",
74
+ "conv_transpose2d",
75
+ "conv_transpose3d",
76
+ "sigmoid",
77
+ "tanh",
78
+ "relu",
79
+ "leaky_relu",
80
+ "selu",
81
+ "gelu",
82
+ ] = "linear",
58
83
  param: Dict[str, Any] | None = None,
59
84
  ):
60
85
  param = param or {}
@@ -89,47 +114,14 @@ def get_initializer(
89
114
  return initializer_fn
90
115
 
91
116
 
92
- def resolve_device() -> str:
93
- if torch.cuda.is_available():
94
- return "cuda"
95
- if torch.backends.mps.is_available():
96
- import platform
97
-
98
- mac_ver = platform.mac_ver()[0]
99
- try:
100
- major, _ = (int(x) for x in mac_ver.split(".")[:2])
101
- except Exception:
102
- major, _ = 0, 0
103
- if major >= 14:
104
- return "mps"
105
- return "cpu"
106
-
107
-
108
- def get_device_info() -> dict:
109
- info = {
110
- "cuda_available": torch.cuda.is_available(),
111
- "cuda_device_count": (
112
- torch.cuda.device_count() if torch.cuda.is_available() else 0
113
- ),
114
- "mps_available": torch.backends.mps.is_available(),
115
- "current_device": resolve_device(),
116
- }
117
-
118
- if torch.cuda.is_available():
119
- info["cuda_device_name"] = torch.cuda.get_device_name(0)
120
- info["cuda_capability"] = torch.cuda.get_device_capability(0)
121
-
122
- return info
123
-
124
-
125
- def configure_device(
117
+ def get_device(
126
118
  distributed: bool, local_rank: int, base_device: torch.device | str = "cpu"
127
119
  ) -> torch.device:
128
120
  try:
129
121
  device = torch.device(base_device)
130
122
  except Exception:
131
123
  logging.warning(
132
- "[configure_device Warning] Invalid base_device, falling back to CPU."
124
+ "[get_device Warning] Invalid base_device, falling back to CPU."
133
125
  )
134
126
  return torch.device("cpu")
135
127
 
@@ -158,7 +150,7 @@ def configure_device(
158
150
 
159
151
 
160
152
  def get_optimizer(
161
- optimizer: str | torch.optim.Optimizer = "adam",
153
+ optimizer: OptimizerName | torch.optim.Optimizer = "adam",
162
154
  params: Iterable[torch.nn.Parameter] | None = None,
163
155
  **optimizer_params,
164
156
  ):
@@ -191,7 +183,7 @@ def get_optimizer(
191
183
 
192
184
  def get_scheduler(
193
185
  scheduler: (
194
- str
186
+ SchedulerName
195
187
  | torch.optim.lr_scheduler._LRScheduler
196
188
  | torch.optim.lr_scheduler.LRScheduler
197
189
  | type[torch.optim.lr_scheduler._LRScheduler]
@@ -241,51 +233,6 @@ def to_tensor(
241
233
  return tensor
242
234
 
243
235
 
244
- def stack_tensors(tensors: list[torch.Tensor], dim: int = 0) -> torch.Tensor:
245
- if not tensors:
246
- raise ValueError("[Tensor Utils Error] Cannot stack empty list of tensors.")
247
- return torch.stack(tensors, dim=dim)
248
-
249
-
250
- def concat_tensors(tensors: list[torch.Tensor], dim: int = 0) -> torch.Tensor:
251
- if not tensors:
252
- raise ValueError(
253
- "[Tensor Utils Error] Cannot concatenate empty list of tensors."
254
- )
255
- return torch.cat(tensors, dim=dim)
256
-
257
-
258
- def pad_sequence_tensors(
259
- tensors: list[torch.Tensor],
260
- max_len: int | None = None,
261
- padding_value: float = 0.0,
262
- padding_side: str = "right",
263
- ) -> torch.Tensor:
264
- if not tensors:
265
- raise ValueError("[Tensor Utils Error] Cannot pad empty list of tensors.")
266
- if max_len is None:
267
- max_len = max(t.size(0) for t in tensors)
268
- batch_size = len(tensors)
269
- padded = torch.full(
270
- (batch_size, max_len),
271
- padding_value,
272
- dtype=tensors[0].dtype,
273
- device=tensors[0].device,
274
- )
275
-
276
- for i, tensor in enumerate(tensors):
277
- length = min(tensor.size(0), max_len)
278
- if padding_side == "right":
279
- padded[i, :length] = tensor[:length]
280
- elif padding_side == "left":
281
- padded[i, -length:] = tensor[:length]
282
- else:
283
- raise ValueError(
284
- f"[Tensor Utils Error] padding_side must be 'right' or 'left', got {padding_side}"
285
- )
286
- return padded
287
-
288
-
289
236
  def init_process_group(
290
237
  distributed: bool, rank: int, world_size: int, device_id: int | None = None
291
238
  ) -> None:
@@ -350,7 +297,7 @@ def add_distributed_sampler(
350
297
  # return if already has DistributedSampler
351
298
  if isinstance(loader.sampler, DistributedSampler):
352
299
  return loader, loader.sampler
353
- dataset = getattr(loader, "dataset", None)
300
+ dataset = loader.dataset
354
301
  if dataset is None:
355
302
  return loader, None
356
303
  if isinstance(dataset, IterableDataset):
@@ -379,25 +326,23 @@ def add_distributed_sampler(
379
326
  "collate_fn": loader.collate_fn,
380
327
  "drop_last": drop_last,
381
328
  }
382
- if getattr(loader, "pin_memory", False):
329
+ if loader.pin_memory:
383
330
  loader_kwargs["pin_memory"] = True
384
- pin_memory_device = getattr(loader, "pin_memory_device", None)
331
+ pin_memory_device = loader.pin_memory_device
385
332
  if pin_memory_device:
386
333
  loader_kwargs["pin_memory_device"] = pin_memory_device
387
- timeout = getattr(loader, "timeout", None)
334
+ timeout = loader.timeout
388
335
  if timeout:
389
336
  loader_kwargs["timeout"] = timeout
390
- worker_init_fn = getattr(loader, "worker_init_fn", None)
337
+ worker_init_fn = loader.worker_init_fn
391
338
  if worker_init_fn is not None:
392
339
  loader_kwargs["worker_init_fn"] = worker_init_fn
393
- generator = getattr(loader, "generator", None)
340
+ generator = loader.generator
394
341
  if generator is not None:
395
342
  loader_kwargs["generator"] = generator
396
343
  if loader.num_workers > 0:
397
- loader_kwargs["persistent_workers"] = getattr(
398
- loader, "persistent_workers", False
399
- )
400
- prefetch_factor = getattr(loader, "prefetch_factor", None)
344
+ loader_kwargs["persistent_workers"] = loader.persistent_workers
345
+ prefetch_factor = loader.prefetch_factor
401
346
  if prefetch_factor is not None:
402
347
  loader_kwargs["prefetch_factor"] = prefetch_factor
403
348
  distributed_loader = DataLoader(dataset, **loader_kwargs)
nextrec/utils/types.py ADDED
@@ -0,0 +1,63 @@
1
+ """
2
+ Shared type aliases for NextRec.
3
+
4
+ Keep Literal-based public string options centralized to avoid drift.
5
+ """
6
+
7
+ from typing import Literal
8
+
9
+ OptimizerName = Literal["adam", "sgd", "adamw", "adagrad", "rmsprop"]
10
+
11
+ SchedulerName = Literal["step", "cosine"]
12
+
13
+ LossName = Literal[
14
+ "bce",
15
+ "binary_crossentropy",
16
+ "weighted_bce",
17
+ "focal",
18
+ "focal_loss",
19
+ "cb_focal",
20
+ "class_balanced_focal",
21
+ "crossentropy",
22
+ "ce",
23
+ "mse",
24
+ "mae",
25
+ "bpr",
26
+ "hinge",
27
+ "triplet",
28
+ "sampled_softmax",
29
+ "softmax",
30
+ "infonce",
31
+ "listnet",
32
+ "listmle",
33
+ "approx_ndcg",
34
+ ]
35
+
36
+ ActivationName = Literal[
37
+ "dice",
38
+ "relu",
39
+ "relu6",
40
+ "elu",
41
+ "selu",
42
+ "leaky_relu",
43
+ "prelu",
44
+ "gelu",
45
+ "sigmoid",
46
+ "tanh",
47
+ "softplus",
48
+ "softsign",
49
+ "hardswish",
50
+ "mish",
51
+ "silu",
52
+ "swish",
53
+ "hardsigmoid",
54
+ "tanhshrink",
55
+ "softshrink",
56
+ "none",
57
+ "linear",
58
+ "identity",
59
+ ]
60
+
61
+ TrainingModeName = Literal["pointwise", "pairwise", "listwise"]
62
+
63
+ TaskTypeName = Literal["binary", "regression"]