nextrec 0.4.25-py3-none-any.whl → 0.4.28-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/asserts.py +72 -0
  3. nextrec/basic/loggers.py +18 -1
  4. nextrec/basic/model.py +54 -51
  5. nextrec/data/batch_utils.py +23 -3
  6. nextrec/data/dataloader.py +3 -8
  7. nextrec/models/multi_task/[pre]aitm.py +173 -0
  8. nextrec/models/multi_task/[pre]snr_trans.py +232 -0
  9. nextrec/models/multi_task/[pre]star.py +192 -0
  10. nextrec/models/multi_task/apg.py +330 -0
  11. nextrec/models/multi_task/cross_stitch.py +229 -0
  12. nextrec/models/multi_task/escm.py +290 -0
  13. nextrec/models/multi_task/esmm.py +8 -21
  14. nextrec/models/multi_task/hmoe.py +203 -0
  15. nextrec/models/multi_task/mmoe.py +20 -28
  16. nextrec/models/multi_task/pepnet.py +81 -76
  17. nextrec/models/multi_task/ple.py +30 -44
  18. nextrec/models/multi_task/poso.py +13 -22
  19. nextrec/models/multi_task/share_bottom.py +14 -25
  20. nextrec/models/ranking/afm.py +2 -2
  21. nextrec/models/ranking/autoint.py +2 -4
  22. nextrec/models/ranking/dcn.py +2 -3
  23. nextrec/models/ranking/dcn_v2.py +2 -3
  24. nextrec/models/ranking/deepfm.py +2 -3
  25. nextrec/models/ranking/dien.py +7 -9
  26. nextrec/models/ranking/din.py +8 -10
  27. nextrec/models/ranking/eulernet.py +1 -2
  28. nextrec/models/ranking/ffm.py +1 -2
  29. nextrec/models/ranking/fibinet.py +2 -3
  30. nextrec/models/ranking/fm.py +1 -1
  31. nextrec/models/ranking/lr.py +1 -1
  32. nextrec/models/ranking/masknet.py +1 -2
  33. nextrec/models/ranking/pnn.py +1 -2
  34. nextrec/models/ranking/widedeep.py +2 -3
  35. nextrec/models/ranking/xdeepfm.py +2 -4
  36. nextrec/models/representation/rqvae.py +4 -4
  37. nextrec/models/retrieval/dssm.py +18 -26
  38. nextrec/models/retrieval/dssm_v2.py +15 -22
  39. nextrec/models/retrieval/mind.py +9 -15
  40. nextrec/models/retrieval/sdm.py +36 -33
  41. nextrec/models/retrieval/youtube_dnn.py +16 -24
  42. nextrec/models/sequential/hstu.py +2 -2
  43. nextrec/utils/__init__.py +5 -1
  44. nextrec/utils/model.py +9 -14
  45. {nextrec-0.4.25.dist-info → nextrec-0.4.28.dist-info}/METADATA +72 -62
  46. nextrec-0.4.28.dist-info/RECORD +90 -0
  47. nextrec/models/multi_task/aitm.py +0 -0
  48. nextrec/models/multi_task/snr_trans.py +0 -0
  49. nextrec-0.4.25.dist-info/RECORD +0 -86
  50. {nextrec-0.4.25.dist-info → nextrec-0.4.28.dist-info}/WHEEL +0 -0
  51. {nextrec-0.4.25.dist-info → nextrec-0.4.28.dist-info}/entry_points.txt +0 -0
  52. {nextrec-0.4.25.dist-info → nextrec-0.4.28.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py CHANGED
@@ -1 +1 @@
- __version__ = "0.4.25"
+ __version__ = "0.4.28"
nextrec/basic/asserts.py ADDED
@@ -0,0 +1,72 @@
+ """
+ Assert function definitions for NextRec models.
+
+ Date: create on 01/01/2026
+ Checkpoint: edit on 01/01/2026
+ Author: Yang Zhou, zyaztec@gmail.com
+ """
+
+ from __future__ import annotations
+
+ from nextrec.utils.types import TaskTypeName, TrainingModeName
+
+
+ def assert_task(
+     task: list[TaskTypeName] | TaskTypeName | None,
+     nums_task: int,
+     *,
+     model_name: str,
+ ) -> None:
+     if task is None:
+         raise ValueError(f"{model_name} requires task to be specified.")
+
+     # case 1: task is str
+     if isinstance(task, str):
+         if nums_task != 1:
+             raise ValueError(
+                 f"{model_name} received task='{task}' but nums_task={nums_task}. "
+                 "String task is only allowed for single-task models."
+             )
+         return  # single-task, valid
+
+     # case 2: task is list
+     if not isinstance(task, list):
+         raise TypeError(
+             f"{model_name} requires task to be a string or a list of strings."
+         )
+
+     # list but length == 1
+     if len(task) == 1:
+         if nums_task != 1:
+             raise ValueError(
+                 f"{model_name} received task list of length 1 but nums_task={nums_task}. "
+                 "Length-1 task list is only allowed for single-task models."
+             )
+         return  # single-task, valid
+
+     # multi-task: length must match nums_task
+     if len(task) != nums_task:
+         raise ValueError(
+             f"{model_name} requires task length {nums_task}, got {len(task)}."
+         )
+
+
+ def assert_training_mode(
+     training_mode: TrainingModeName | list[TrainingModeName],
+     nums_task: int,
+     *,
+     model_name: str,
+ ) -> None:
+     valid_modes = {"pointwise", "pairwise", "listwise"}
+     if not isinstance(training_mode, list):
+         raise TypeError(
+             f"[{model_name}-init Error] training_mode must be a list with length {nums_task}."
+         )
+     if len(training_mode) != nums_task:
+         raise ValueError(
+             f"[{model_name}-init Error] training_mode list length must match number of tasks."
+         )
+     if any(mode not in valid_modes for mode in training_mode):
+         raise ValueError(
+             f"[{model_name}-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
+         )
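For a quick sense of how these new validators behave, here is a minimal sketch; the model name and task values are illustrative, assuming nextrec 0.4.28 is installed:

```python
# Sketch of calling the new validators; "DemoModel" and the task values
# below are illustrative placeholders, not values from the package.
from nextrec.basic.asserts import assert_task, assert_training_mode

assert_task("binary", nums_task=1, model_name="DemoModel")                   # ok
assert_task(["binary", "regression"], nums_task=2, model_name="DemoModel")  # ok

try:
    assert_task(["binary"], nums_task=2, model_name="DemoModel")
except ValueError as err:
    print(err)  # length-1 task list is only allowed for single-task models

assert_training_mode(["pointwise", "pairwise"], nums_task=2, model_name="DemoModel")
```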
nextrec/basic/loggers.py CHANGED
@@ -2,7 +2,7 @@
  NextRec Basic Loggers

  Date: create on 27/10/2025
- Checkpoint: edit on 27/12/2025
+ Checkpoint: edit on 01/01/2026
  Author: Yang Zhou, zyaztec@gmail.com
  """

@@ -190,6 +190,19 @@ class BasicLogger:
      def close(self) -> None:
          for backend in self.backends:
              backend.close()
+         for backend in self.backends:
+             if isinstance(backend, SwanLabLogger):
+                 swanlab = backend.swanlab
+                 if not backend.enabled or swanlab is None:
+                     continue
+                 finish_fn = getattr(swanlab, "finish", None)
+                 if finish_fn is None:
+                     continue
+                 try:
+                     finish_fn()
+                 except TypeError:
+                     finish_fn()
+                 break


  class TensorBoardLogger(MetricsLoggerBackend):
@@ -369,10 +382,14 @@ class TrainingLogger(BasicLogger):
          wandb_kwargs = dict(wandb_kwargs or {})
          wandb_kwargs.setdefault("config", {})
          wandb_kwargs["config"].update(config)
+         if "notes" in wandb_kwargs:
+             wandb_kwargs["config"].pop("note", None)

          swanlab_kwargs = dict(swanlab_kwargs or {})
          swanlab_kwargs.setdefault("config", {})
          swanlab_kwargs["config"].update(config)
+         if "description" in swanlab_kwargs:
+             swanlab_kwargs["config"].pop("note", None)

          self.wandb_logger = None
          if use_wandb:
nextrec/basic/model.py CHANGED
@@ -2,7 +2,7 @@
  Base Model & Base Match Model Class

  Date: create on 27/10/2025
- Checkpoint: edit on 31/12/2025
+ Checkpoint: edit on 01/01/2026
  Author: Yang Zhou,zyaztec@gmail.com
  """

@@ -36,6 +36,7 @@ from torch.utils.data import DataLoader
  from torch.utils.data.distributed import DistributedSampler

  from nextrec import __version__
+ from nextrec.basic.asserts import assert_task
  from nextrec.basic.callback import (
      CallbackList,
      CheckpointSaver,
@@ -101,6 +102,7 @@ from nextrec.utils.types import (

  from nextrec.utils.data import FILE_FORMAT_CONFIG

+
  class BaseModel(SummarySet, FeatureSet, nn.Module):
      @property
      def model_name(self) -> str:
@@ -110,30 +112,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
      def default_task(self) -> TaskTypeName | list[TaskTypeName]:
          raise NotImplementedError

-     @property
-     def training_mode(self) -> TrainingModeName | list[TrainingModeName]:
-         if self.nums_task > 1:
-             return self.training_modes
-         return self.training_modes[0] if self.training_modes else "pointwise"
-
-
-     @training_mode.setter
-     def training_mode(self, training_mode: TrainingModeName | list[TrainingModeName]):
-         valid_modes = {"pointwise", "pairwise", "listwise"}
-         if isinstance(training_mode, list):
-             training_modes = list(training_mode)
-             if len(training_modes) != self.nums_task:
-                 raise ValueError(
-                     "[BaseModel-init Error] training_mode list length must match number of tasks."
-                 )
-         else:
-             training_modes = [training_mode] * self.nums_task
-         if any(mode not in valid_modes for mode in training_modes):
-             raise ValueError(
-                 "[BaseModel-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
-             )
-         self.training_modes = list(training_modes)
-
      def __init__(
          self,
          dense_features: list[DenseFeature] | None = None,
@@ -142,7 +120,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          target: list[str] | str | None = None,
          id_columns: list[str] | str | None = None,
          task: TaskTypeName | list[TaskTypeName] | None = None,
-         training_mode: TrainingModeName | list[TrainingModeName] = "pointwise",
+         training_mode: TrainingModeName | list[TrainingModeName] | None = None,
          embedding_l1_reg: float = 0.0,
          dense_l1_reg: float = 0.0,
          embedding_l2_reg: float = 0.0,
@@ -162,10 +140,10 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              dense_features: DenseFeature definitions.
              sparse_features: SparseFeature definitions.
              sequence_features: SequenceFeature definitions.
-             target: Target column name. e.g., 'label' or ['label1', 'label2'].
+             target: Target column name. e.g., 'label_ctr' or ['label_ctr', 'label_cvr'].
              id_columns: Identifier column name, only need to specify if GAUC is required. e.g., 'user_id'.
              task: Task types, e.g., 'binary', 'regression', or ['binary', 'regression']. If None, falls back to self.default_task.
-             training_mode: Training mode for ranking tasks; a single mode or a list per task.
+             training_mode: Training mode for different tasks. e.g., 'pointwise', ['pointwise', 'pairwise'].

              embedding_l1_reg: L1 regularization strength for embedding params. e.g., 1e-6.
              dense_l1_reg: L1 regularization strength for dense params. e.g., 1e-5.
@@ -218,7 +196,11 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          self.task = task or self.default_task
          self.nums_task = len(self.task) if isinstance(self.task, list) else 1

-         self.training_mode = training_mode
+         training_mode = training_mode or "pointwise"
+         if isinstance(training_mode, list):
+             self.training_modes = list(training_mode)
+         else:
+             self.training_modes = [training_mode] * self.nums_task

          self.embedding_l1_reg = embedding_l1_reg
          self.dense_l1_reg = dense_l1_reg
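The removed property/setter pair above had mixed validation into attribute access; `__init__` now only normalizes the argument, while validation lives in `assert_training_mode`. A standalone sketch of the normalization rule:

```python
# Standalone rendition of the normalization now done in BaseModel.__init__.
def normalize_training_modes(training_mode, nums_task):
    training_mode = training_mode or "pointwise"  # None falls back to "pointwise"
    if isinstance(training_mode, list):
        return list(training_mode)                # one mode per task, as given
    return [training_mode] * nums_task            # broadcast a single mode

print(normalize_training_modes(None, 2))                       # ['pointwise', 'pointwise']
print(normalize_training_modes(["pointwise", "pairwise"], 2))  # ['pointwise', 'pairwise']
```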
@@ -328,13 +310,13 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
      def get_input(self, input_data: dict, require_labels: bool = True):
          """
          Prepare unified input features and labels from the given input data.
-
+

          Args:
              input_data: Input data dictionary containing 'features' and optionally 'labels', e.g., {'features': {'feat1': [...], 'feat2': [...]}, 'labels': {'label': [...]}}.
              require_labels: Whether labels are required in the input data. Default is True: for training and evaluation with labels.
-
-         Note:
+
+         Note:
              target tensor shape will always be (batch_size, num_targets)
          """
          feature_source = input_data.get("features", {})
@@ -491,9 +473,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              ignore_label: Label value to ignore when computing loss. Use this to skip gradients for unknown labels.
          """
          self.ignore_label = ignore_label
-         loss_list = get_loss_list(
-             loss, self.training_modes, self.nums_task
-         )
+
+         # get loss list
+         loss_list = get_loss_list(loss, self.training_modes, self.nums_task)

          self.loss_params = {} if loss_params is None else loss_params
          self.optimizer_params = optimizer_params or {}
@@ -546,7 +528,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              raise ValueError(
                  "[BaseModel-compile Error] GradNorm requires multi-task setup."
              )
-         grad_norm_params = dict(loss_weights) if isinstance(loss_weights, dict) else {}
+         grad_norm_params = (
+             dict(loss_weights) if isinstance(loss_weights, dict) else {}
+         )
          grad_norm_params.pop("method", None)
          self.grad_norm = GradNormLossWeighting(
              nums_task=self.nums_task, device=self.device, **grad_norm_params
@@ -594,7 +578,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          y_true = y_true.view(-1, 1)

          loss_fn = self.loss_fn[0]
-
+
          # mask ignored labels
          # we don't suggest using ignore_label for single task training
          if self.ignore_label is not None:
@@ -685,6 +669,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          batch_size: int = 32,
          shuffle: bool = True,
          num_workers: int = 0,
+         prefetch_factor: int | None = None,
          sampler=None,
          return_dataset: bool = False,
      ):
@@ -696,6 +681,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              batch_size: Batch size.
              shuffle: Whether to shuffle the data (ignored when a sampler is provided).
              num_workers: Number of DataLoader workers.
+             prefetch_factor: Number of batches loaded in advance by each worker.
              sampler: Optional sampler for DataLoader.
              return_dataset: Whether to return the tensor dataset along with the DataLoader, used for valid data
          Returns:
@@ -715,6 +701,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                  "[BaseModel-prepare_data_loader Error] No data available to create DataLoader."
              )
          dataset = TensorDictDataset(tensors)
+         loader_kwargs = {}
+         if num_workers > 0 and prefetch_factor is not None:
+             loader_kwargs["prefetch_factor"] = prefetch_factor
          loader = DataLoader(
              dataset,
              batch_size=batch_size,
@@ -724,6 +713,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              num_workers=num_workers,
              pin_memory=self.device.type == "cuda",
              persistent_workers=num_workers > 0,
+             **loader_kwargs,
          )
          return (loader, dataset) if return_dataset else loader
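The guard above exists because PyTorch's `DataLoader` only accepts `prefetch_factor` when `num_workers > 0`; passing it for in-process loading raises an error. A minimal sketch of the same pattern outside nextrec:

```python
# Same guarded-kwargs pattern with a plain PyTorch DataLoader.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(8, dtype=torch.float32))
num_workers, prefetch_factor = 2, 4

loader_kwargs = {}
if num_workers > 0 and prefetch_factor is not None:
    loader_kwargs["prefetch_factor"] = prefetch_factor  # each worker preloads 4 batches

loader = DataLoader(dataset, batch_size=2, num_workers=num_workers, **loader_kwargs)
```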
 
@@ -798,6 +788,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          )
          self.to(self.device)

+         assert_task(self.task, len(self.target_columns), model_name=self.model_name)
+
          if not self.compiled:
              self.compile(
                  optimizer="adam",
@@ -902,6 +894,14 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          else:
              swanlab.login(api_key=swanlab_api)

+         if use_wandb and self.note:
+             wandb_kwargs = dict(wandb_kwargs or {})
+             wandb_kwargs.setdefault("notes", self.note)
+
+         if use_swanlab and self.note:
+             swanlab_kwargs = dict(swanlab_kwargs or {})
+             swanlab_kwargs.setdefault("description", self.note)
+
          self.training_logger = (
              TrainingLogger(
                  session=self.session,
@@ -1649,7 +1649,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              stream_chunk_size: Number of rows per chunk when using streaming mode for large datasets.
              num_workers: DataLoader worker count.

-         Note:
+         Note:
              predict does not support distributed mode currently, consider it as a single-process operation.
          """
          self.eval()
@@ -1837,7 +1837,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
      ):
          """
          Make predictions on the given data using streaming mode for large datasets.
-
+
          Args:
              data: Input data for prediction (file path, dict, DataFrame, or DataLoader).
              batch_size: Batch size for prediction.
@@ -2279,9 +2279,10 @@ class BaseMatchModel(BaseModel):
          self.num_negative_samples = num_negative_samples
          self.temperature = temperature
          self.similarity_metric = similarity_metric
-         if self.training_mode not in self.support_training_modes:
+         primary_mode = self.training_modes[0] if self.training_modes else "pointwise"
+         if primary_mode not in self.support_training_modes:
              raise ValueError(
-                 f"{self.model_name.upper()} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
+                 f"{self.model_name.upper()} does not support training_mode='{primary_mode}'. Supported modes: {self.support_training_modes}"
              )
          self.user_features_all = (
              self.user_dense_features
@@ -2298,7 +2299,7 @@ class BaseMatchModel(BaseModel):
          self.head = RetrievalHead(
              similarity_metric=self.similarity_metric,
              temperature=self.temperature,
-             training_mode=self.training_mode,
+             training_mode=primary_mode,
              apply_sigmoid=True,
          )

@@ -2338,26 +2339,27 @@ class BaseMatchModel(BaseModel):
          }

          effective_loss = loss
+         primary_mode = self.training_modes[0] if self.training_modes else "pointwise"
          if effective_loss is None:
-             effective_loss = default_loss_by_mode[self.training_mode]
+             effective_loss = default_loss_by_mode[primary_mode]
          elif isinstance(effective_loss, str):
-             if self.training_mode in {"pairwise", "listwise"} and effective_loss in {
+             if primary_mode in {"pairwise", "listwise"} and effective_loss in {
                  "bce",
                  "binary_crossentropy",
              }:
-                 effective_loss = default_loss_by_mode[self.training_mode]
+                 effective_loss = default_loss_by_mode[primary_mode]
          elif isinstance(effective_loss, list):
              if not effective_loss:
-                 effective_loss = [default_loss_by_mode[self.training_mode]]
+                 effective_loss = [default_loss_by_mode[primary_mode]]
              else:
                  first = effective_loss[0]
                  if (
-                     self.training_mode in {"pairwise", "listwise"}
+                     primary_mode in {"pairwise", "listwise"}
                      and isinstance(first, str)
                      and first in {"bce", "binary_crossentropy"}
                  ):
                      effective_loss = [
-                         default_loss_by_mode[self.training_mode],
+                         default_loss_by_mode[primary_mode],
                          *effective_loss[1:],
                      ]
          return super().compile(
@@ -2435,11 +2437,12 @@ class BaseMatchModel(BaseModel):
          return self.head(user_emb, item_emb, similarity_fn=self.compute_similarity)

      def compute_loss(self, y_pred, y_true):
-         if self.training_mode == "pointwise":
+         primary_mode = self.training_modes[0] if self.training_modes else "pointwise"
+         if primary_mode == "pointwise":
              return super().compute_loss(y_pred, y_true)

          # pairwise / listwise using inbatch neg
-         elif self.training_mode in ["pairwise", "listwise"]:
+         elif primary_mode in ["pairwise", "listwise"]:
              if not isinstance(y_pred, (tuple, list)) or len(y_pred) != 2:
                  raise ValueError(
                      "For pairwise/listwise training, forward should return (user_emb, item_emb). Please check BaseMatchModel.forward implementation."
@@ -2482,7 +2485,7 @@ class BaseMatchModel(BaseModel):
                  loss *= float(self.loss_weights[0])
              return loss
          else:
-             raise ValueError(f"Unknown training mode: {self.training_mode}")
+             raise ValueError(f"Unknown training mode: {primary_mode}")

      def prepare_feature_data(
          self,
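Across these `BaseMatchModel` hunks the pattern is the same: the removed `training_mode` property is replaced by reading `self.training_modes[0]` as the primary mode. The loss fallback in `compile` can be sketched as below; the real `default_loss_by_mode` dict is defined outside the hunks shown here, so its values are placeholders:

```python
# Sketch of the mode-dependent loss fallback. The package's actual
# default_loss_by_mode values are not shown in this diff; these are placeholders.
default_loss_by_mode = {
    "pointwise": "bce",
    "pairwise": "pairwise_default",   # placeholder
    "listwise": "listwise_default",   # placeholder
}

training_modes = ["pairwise"]
primary_mode = training_modes[0] if training_modes else "pointwise"

effective_loss = "bce"
if primary_mode in {"pairwise", "listwise"} and effective_loss in {"bce", "binary_crossentropy"}:
    effective_loss = default_loss_by_mode[primary_mode]  # BCE is upgraded for pair/listwise

print(effective_loss)  # 'pairwise_default'
```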
nextrec/data/batch_utils.py CHANGED
@@ -5,13 +5,27 @@ Date: create on 03/12/2025
  Author: Yang Zhou, zyaztec@gmail.com
  """

- from typing import Any, Mapping
+ from typing import Any, Mapping, Literal

  import numpy as np
  import torch


- def stack_section(batch: list[dict], section: str):
+ def stack_section(batch: list[dict], section: Literal["features", "labels", "ids"]):
+     """
+     input example:
+     batch = [
+         {"features": {"f1": tensor1, "f2": tensor2}, "labels": {"label": tensor3}},
+         {"features": {"f1": tensor4, "f2": tensor5}, "labels": {"label": tensor6}},
+         ...
+     ]
+     output example:
+     {
+         "f1": torch.stack([tensor1, tensor4], dim=0),
+         "f2": torch.stack([tensor2, tensor5], dim=0),
+     }
+
+     """
      entries = [item.get(section) for item in batch if item.get(section) is not None]
      if not entries:
          return None
@@ -22,7 +36,13 @@ def stack_section(batch: list[dict], section: str):
              for item in batch
              if item.get(section) is not None and name in item[section]
          ]
-         merged[name] = torch.stack(tensors, dim=0)
+         tensor_sample = tensors[0]
+         if isinstance(tensor_sample, torch.Tensor):
+             merged[name] = torch.stack(tensors, dim=0)
+         elif isinstance(tensor_sample, np.ndarray):
+             merged[name] = np.stack(tensors, axis=0)
+         else:
+             merged[name] = tensors
      return merged

nextrec/data/dataloader.py CHANGED
@@ -2,7 +2,7 @@
  Dataloader definitions

  Date: create on 27/10/2025
- Checkpoint: edit on 24/12/2025
+ Checkpoint: edit on 01/01/2026
  Author: Yang Zhou,zyaztec@gmail.com
  """

@@ -523,13 +523,8 @@ def build_tensors_from_data(
              raise KeyError(
                  f"[RecDataLoader Error] ID column '{id_col}' not found in provided data."
              )
-         try:
-             id_arr = np.asarray(column, dtype=np.int64)
-         except Exception as exc:
-             raise TypeError(
-                 f"[RecDataLoader Error] ID column '{id_col}' must contain numeric values. Received dtype={np.asarray(column).dtype}, error: {exc}"
-             ) from exc
-         id_tensors[id_col] = to_tensor(id_arr, dtype=torch.long)
+         # Normalize all id columns to strings for consistent downstream handling.
+         id_tensors[id_col] = np.asarray(column, dtype=str)
      if not feature_tensors:
          return None
      return {"features": feature_tensors, "labels": label_tensors, "ids": id_tensors}
nextrec/models/multi_task/[pre]aitm.py ADDED
@@ -0,0 +1,173 @@
+ """
+ Date: create on 01/01/2026 - prerelease version: need to overwrite compute_loss later
+ Checkpoint: edit on 01/01/2026
+ Author: Yang Zhou, zyaztec@gmail.com
+ Reference:
+ - [1] Xi D, Chen Z, Yan P, Zhang Y, Zhu Y, Zhuang F, Chen Y. Modeling the Sequential Dependence among Audience Multi-step Conversions with Multi-task Learning in Targeted Display Advertising. Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery & Data Mining (KDD ’21), 2021, pp. 3745–3755.
+ URL: https://arxiv.org/abs/2105.08489
+ - [2] MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation: https://github.com/alipay/MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation/
+
+ """
+
+ from __future__ import annotations
+
+ import math
+ import torch
+ import torch.nn as nn
+
+ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
+ from nextrec.basic.layers import MLP, EmbeddingLayer
+ from nextrec.basic.heads import TaskHead
+ from nextrec.basic.model import BaseModel
+ from nextrec.utils.model import get_mlp_output_dim
+ from nextrec.utils.types import TaskTypeName
+
+
+ class AITMTransfer(nn.Module):
+     """Attentive information transfer from previous task to current task."""
+
+     def __init__(self, input_dim: int):
+         super().__init__()
+         self.input_dim = input_dim
+         self.prev_proj = nn.Linear(input_dim, input_dim)
+         self.value = nn.Linear(input_dim, input_dim)
+         self.key = nn.Linear(input_dim, input_dim)
+         self.query = nn.Linear(input_dim, input_dim)
+
+     def forward(self, prev_feat: torch.Tensor, curr_feat: torch.Tensor) -> torch.Tensor:
+         prev = self.prev_proj(prev_feat).unsqueeze(1)
+         curr = curr_feat.unsqueeze(1)
+         stacked = torch.cat([prev, curr], dim=1)
+         value = self.value(stacked)
+         key = self.key(stacked)
+         query = self.query(stacked)
+         attn_scores = torch.sum(key * query, dim=2, keepdim=True) / math.sqrt(
+             self.input_dim
+         )
+         attn = torch.softmax(attn_scores, dim=1)
+         return torch.sum(attn * value, dim=1)
+
+
+ class AITM(BaseModel):
+     """
+     Attentive Information Transfer Multi-Task model.
+
+     AITM learns task-specific representations and transfers information from
+     task i-1 to task i via attention, enabling sequential task dependency modeling.
+     """
+
+     @property
+     def model_name(self):
+         return "AITM"
+
+     @property
+     def default_task(self):
+         nums_task = getattr(self, "nums_task", None)
+         if nums_task is not None and nums_task > 0:
+             return ["binary"] * nums_task
+         return ["binary"]
+
+     def __init__(
+         self,
+         dense_features: list[DenseFeature] | None = None,
+         sparse_features: list[SparseFeature] | None = None,
+         sequence_features: list[SequenceFeature] | None = None,
+         bottom_mlp_params: dict | list[dict] | None = None,
+         tower_mlp_params_list: list[dict] | None = None,
+         calibrator_alpha: float = 0.1,
+         target: list[str] | str | None = None,
+         task: list[TaskTypeName] | None = None,
+         **kwargs,
+     ):
+         dense_features = dense_features or []
+         sparse_features = sparse_features or []
+         sequence_features = sequence_features or []
+         bottom_mlp_params = bottom_mlp_params or {}
+         tower_mlp_params_list = tower_mlp_params_list or []
+         self.calibrator_alpha = calibrator_alpha
+
+         if target is None:
+             raise ValueError("AITM requires target names for all tasks.")
+         if isinstance(target, str):
+             target = [target]
+
+         self.nums_task = len(target)
+         if self.nums_task < 2:
+             raise ValueError("AITM requires at least 2 tasks.")
+
+         super(AITM, self).__init__(
+             dense_features=dense_features,
+             sparse_features=sparse_features,
+             sequence_features=sequence_features,
+             target=target,
+             task=task,
+             **kwargs,
+         )
+
+         if len(tower_mlp_params_list) != self.nums_task:
+             raise ValueError(
+                 "Number of tower mlp params "
+                 f"({len(tower_mlp_params_list)}) must match number of tasks ({self.nums_task})."
+             )
+
+         bottom_mlp_params_list: list[dict]
+         if isinstance(bottom_mlp_params, list):
+             if len(bottom_mlp_params) != self.nums_task:
+                 raise ValueError(
+                     "Number of bottom mlp params "
+                     f"({len(bottom_mlp_params)}) must match number of tasks ({self.nums_task})."
+                 )
+             bottom_mlp_params_list = [params.copy() for params in bottom_mlp_params]
+         else:
+             bottom_mlp_params_list = [
+                 bottom_mlp_params.copy() for _ in range(self.nums_task)
+             ]
+
+         self.embedding = EmbeddingLayer(features=self.all_features)
+         input_dim = self.embedding.input_dim
+
+         self.bottoms = nn.ModuleList(
+             [
+                 MLP(input_dim=input_dim, output_dim=None, **params)
+                 for params in bottom_mlp_params_list
+             ]
+         )
+         bottom_dims = [
+             get_mlp_output_dim(params, input_dim) for params in bottom_mlp_params_list
+         ]
+         if len(set(bottom_dims)) != 1:
+             raise ValueError(f"All bottom output dims must match, got {bottom_dims}.")
+         bottom_output_dim = bottom_dims[0]
+
+         self.transfers = nn.ModuleList(
+             [AITMTransfer(bottom_output_dim) for _ in range(self.nums_task - 1)]
+         )
+         self.grad_norm_shared_modules = ["embedding", "transfers"]
+
+         self.towers = nn.ModuleList(
+             [
+                 MLP(input_dim=bottom_output_dim, output_dim=1, **params)
+                 for params in tower_mlp_params_list
+             ]
+         )
+         self.prediction_layer = TaskHead(
+             task_type=self.task, task_dims=[1] * self.nums_task
+         )
+
+         self.register_regularization_weights(
+             embedding_attr="embedding",
+             include_modules=["bottoms", "transfers", "towers"],
+         )
+
+     def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
+         input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
+         task_feats = [bottom(input_flat) for bottom in self.bottoms]
+
+         for idx in range(1, self.nums_task):
+             task_feats[idx] = self.transfers[idx - 1](
+                 task_feats[idx - 1], task_feats[idx]
+             )
+
+         task_outputs = [tower(task_feats[idx]) for idx, tower in enumerate(self.towers)]
+         logits = torch.cat(task_outputs, dim=1)
+         return self.prediction_layer(logits)
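Since the `[pre]aitm.py` file name is not an importable module path, here is a self-contained sketch of the sequential transfer wiring from `AITM.forward`, with toy dimensions:

```python
import math
import torch
import torch.nn as nn

# Re-statement of AITMTransfer from the diff above, with toy dimensions.
class Transfer(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.prev_proj = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.query = nn.Linear(dim, dim)

    def forward(self, prev_feat, curr_feat):
        stacked = torch.cat(
            [self.prev_proj(prev_feat).unsqueeze(1), curr_feat.unsqueeze(1)], dim=1
        )
        v, k, q = self.value(stacked), self.key(stacked), self.query(stacked)
        attn = torch.softmax((k * q).sum(dim=2, keepdim=True) / math.sqrt(self.dim), dim=1)
        return (attn * v).sum(dim=1)

# Information flows task 1 -> task 2 -> task 3, as in AITM.forward.
feats = [torch.randn(4, 8) for _ in range(3)]
transfers = nn.ModuleList([Transfer(8) for _ in range(2)])
for i in range(1, 3):
    feats[i] = transfers[i - 1](feats[i - 1], feats[i])
print([tuple(f.shape) for f in feats])  # [(4, 8), (4, 8), (4, 8)]
```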