nextrec 0.3.6__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. nextrec/__init__.py +1 -1
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -5
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/features.py +30 -22
  6. nextrec/basic/layers.py +244 -113
  7. nextrec/basic/loggers.py +62 -43
  8. nextrec/basic/metrics.py +268 -119
  9. nextrec/basic/model.py +1373 -443
  10. nextrec/basic/session.py +10 -3
  11. nextrec/cli.py +498 -0
  12. nextrec/data/__init__.py +19 -25
  13. nextrec/data/batch_utils.py +11 -3
  14. nextrec/data/data_processing.py +42 -24
  15. nextrec/data/data_utils.py +26 -15
  16. nextrec/data/dataloader.py +303 -96
  17. nextrec/data/preprocessor.py +320 -199
  18. nextrec/loss/listwise.py +17 -9
  19. nextrec/loss/loss_utils.py +7 -8
  20. nextrec/loss/pairwise.py +2 -0
  21. nextrec/loss/pointwise.py +30 -12
  22. nextrec/models/generative/hstu.py +106 -40
  23. nextrec/models/match/dssm.py +82 -69
  24. nextrec/models/match/dssm_v2.py +72 -58
  25. nextrec/models/match/mind.py +175 -108
  26. nextrec/models/match/sdm.py +104 -88
  27. nextrec/models/match/youtube_dnn.py +73 -60
  28. nextrec/models/multi_task/esmm.py +53 -39
  29. nextrec/models/multi_task/mmoe.py +70 -47
  30. nextrec/models/multi_task/ple.py +107 -50
  31. nextrec/models/multi_task/poso.py +121 -41
  32. nextrec/models/multi_task/share_bottom.py +54 -38
  33. nextrec/models/ranking/afm.py +172 -45
  34. nextrec/models/ranking/autoint.py +84 -61
  35. nextrec/models/ranking/dcn.py +59 -42
  36. nextrec/models/ranking/dcn_v2.py +64 -23
  37. nextrec/models/ranking/deepfm.py +36 -26
  38. nextrec/models/ranking/dien.py +158 -102
  39. nextrec/models/ranking/din.py +88 -60
  40. nextrec/models/ranking/fibinet.py +55 -35
  41. nextrec/models/ranking/fm.py +32 -26
  42. nextrec/models/ranking/masknet.py +95 -34
  43. nextrec/models/ranking/pnn.py +34 -31
  44. nextrec/models/ranking/widedeep.py +37 -29
  45. nextrec/models/ranking/xdeepfm.py +63 -41
  46. nextrec/utils/__init__.py +61 -32
  47. nextrec/utils/config.py +490 -0
  48. nextrec/utils/device.py +52 -12
  49. nextrec/utils/distributed.py +141 -0
  50. nextrec/utils/embedding.py +1 -0
  51. nextrec/utils/feature.py +1 -0
  52. nextrec/utils/file.py +32 -11
  53. nextrec/utils/initializer.py +61 -16
  54. nextrec/utils/optimizer.py +25 -9
  55. nextrec/utils/synthetic_data.py +531 -0
  56. nextrec/utils/tensor.py +24 -13
  57. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/METADATA +15 -5
  58. nextrec-0.4.2.dist-info/RECORD +69 -0
  59. nextrec-0.4.2.dist-info/entry_points.txt +2 -0
  60. nextrec-0.3.6.dist-info/RECORD +0 -64
  61. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
  62. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
nextrec/loss/listwise.py CHANGED
@@ -20,10 +20,14 @@ class SampledSoftmaxLoss(nn.Module):
         super().__init__()
         self.reduction = reduction
 
-    def forward(self, pos_logits: torch.Tensor, neg_logits: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, pos_logits: torch.Tensor, neg_logits: torch.Tensor
+    ) -> torch.Tensor:
         pos_logits = pos_logits.unsqueeze(1)
         all_logits = torch.cat([pos_logits, neg_logits], dim=1)
-        targets = torch.zeros(all_logits.size(0), dtype=torch.long, device=all_logits.device)
+        targets = torch.zeros(
+            all_logits.size(0), dtype=torch.long, device=all_logits.device
+        )
         loss = F.cross_entropy(all_logits, targets, reduction=self.reduction)
         return loss
 
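Note: the hunk above is formatting-only; the loss itself is unchanged. The positive logit is placed at index 0 ahead of the sampled negatives, and cross-entropy is taken against a target of 0. A minimal usage sketch (the [B] / [B, N] shapes are inferred from the unsqueeze/cat above, so treat them as assumptions):

    import torch
    from nextrec.loss.listwise import SampledSoftmaxLoss

    loss_fn = SampledSoftmaxLoss(reduction="mean")
    pos_logits = torch.randn(32)       # [B], score of the observed item
    neg_logits = torch.randn(32, 100)  # [B, N], scores of N sampled negatives
    loss = loss_fn(pos_logits, neg_logits)  # cross-entropy with target class 0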
@@ -87,7 +91,11 @@ class ListMLELoss(nn.Module):
     def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
         sorted_labels, sorted_indices = torch.sort(labels, descending=True, dim=1)
         batch_size, list_size = scores.shape
-        batch_indices = torch.arange(batch_size, device=scores.device).unsqueeze(1).expand(-1, list_size)
+        batch_indices = (
+            torch.arange(batch_size, device=scores.device)
+            .unsqueeze(1)
+            .expand(-1, list_size)
+        )
         sorted_scores = scores[batch_indices, sorted_indices]
 
         loss = torch.tensor(0.0, device=scores.device)
@@ -139,19 +147,19 @@ class ApproxNDCGLoss(nn.Module):
         device = scores.device
 
         # diff[b, i, j] = (s_j - s_i) / T
-        scores_i = scores.unsqueeze(2) # [B, L, 1]
-        scores_j = scores.unsqueeze(1) # [B, 1, L]
+        scores_i = scores.unsqueeze(2)  # [B, L, 1]
+        scores_j = scores.unsqueeze(1)  # [B, 1, L]
         diff = (scores_j - scores_i) / self.temperature  # [B, L, L]
 
-        P_ji = torch.sigmoid(diff) # [B, L, L]
+        P_ji = torch.sigmoid(diff)  # [B, L, L]
         eye = torch.eye(list_size, device=device).unsqueeze(0)  # [1, L, L]
         P_ji = P_ji * (1.0 - eye)
 
-        exp_rank = 1.0 + P_ji.sum(dim=-1) # [B, L]
+        exp_rank = 1.0 + P_ji.sum(dim=-1)  # [B, L]
 
         discounts = 1.0 / torch.log2(exp_rank + 1.0)  # [B, L]
 
-        gains = torch.pow(2.0, labels) - 1.0 # [B, L]
+        gains = torch.pow(2.0, labels) - 1.0  # [B, L]
         approx_dcg = torch.sum(gains * discounts, dim=1)  # [B]
 
         ideal_dcg = self._ideal_dcg(labels, k)  # [B]
@@ -163,4 +171,4 @@ class ApproxNDCGLoss(nn.Module):
             return loss.mean()
         if self.reduction == "sum":
             return loss.sum()
-        return loss
+        return loss
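Note: both hunks above are reflows; the ApproxNDCG math is unchanged. Written out from the inline comments (temperature T, scores s, relevance labels y), the soft rank and the DCG surrogate are:

    rank_i ~ 1 + sum_{j != i} sigmoid((s_j - s_i) / T)
    DCG    ~ sum_i (2^{y_i} - 1) / log2(rank_i + 1)

with the loss normalized by the ideal DCG before the mean/sum reduction shown at the end of the file. A lower temperature T makes the sigmoid a sharper approximation of the indicator [s_j > s_i], at the cost of flatter gradients.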
nextrec/loss/loss_utils.py CHANGED
@@ -6,8 +6,6 @@ Checkpoint: edit on 29/11/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
-from typing import Literal
-
 import torch.nn as nn
 
 from nextrec.loss.listwise import (
@@ -20,19 +18,19 @@ from nextrec.loss.listwise import (
 from nextrec.loss.pairwise import BPRLoss, HingeLoss, TripletLoss
 from nextrec.loss.pointwise import (
     ClassBalancedFocalLoss,
-    CosineContrastiveLoss,
     FocalLoss,
     WeightedBCELoss,
 )
 
 
 VALID_TASK_TYPES = [
-    "binary",
-    "multiclass",
-    "multilabel",
-    "regression",
+    "binary",
+    "multiclass",
+    "multilabel",
+    "regression",
 ]
 
+
 def _build_cb_focal(kw):
     if "class_counts" not in kw:
         raise ValueError("class_balanced_focal requires class_counts")
@@ -81,6 +79,7 @@ def get_loss_fn(loss=None, **kw):
 
     raise ValueError(f"[Loss Error] Unsupported loss: {loss}")
 
+
 def get_loss_kwargs(loss_params: dict | list[dict] | None, index: int = 0) -> dict:
     """
     Parse loss_kwargs for each head.
@@ -95,4 +94,4 @@ def get_loss_kwargs(loss_params: dict | list[dict] | None, index: int = 0) -> dict:
         if index < len(loss_params) and loss_params[index] is not None:
             return loss_params[index]
         return {}
-    return loss_params
+    return loss_params
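Note: as the hunk shows, get_loss_kwargs accepts either one shared dict or a per-head list. A quick sketch of the three cases (the import path assumes this file is loss_utils.py, per the file list above; values are hypothetical):

    from nextrec.loss.loss_utils import get_loss_kwargs

    per_head = [{"gamma": 2.0}, None]     # one kwargs dict per task head
    get_loss_kwargs(per_head, index=0)    # -> {"gamma": 2.0}
    get_loss_kwargs(per_head, index=1)    # -> {} (None falls back to empty)
    get_loss_kwargs({"pos_weight": 3.0})  # -> {"pos_weight": 3.0}, shared by all heads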
nextrec/loss/pairwise.py CHANGED
@@ -36,6 +36,7 @@ class BPRLoss(nn.Module):
             return loss.sum()
         return loss
 
+
 class HingeLoss(nn.Module):
     """
     Hinge loss for pairwise ranking.
@@ -59,6 +60,7 @@ class HingeLoss(nn.Module):
             return loss.sum()
         return loss
 
+
 class TripletLoss(nn.Module):
     """
     Triplet margin loss with cosine or euclidean distance.
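Note: both hunks are blank-line (PEP 8) fixes only. For orientation, the textbook forms of the objectives these classes name are (the forward bodies are not shown in this diff, so the exact margins and reductions may differ):

    BPR:     L = -log sigmoid(s_pos - s_neg)
    Hinge:   L = max(0, m - (s_pos - s_neg))
    Triplet: L = max(0, d(anchor, pos) - d(anchor, neg) + m), d cosine or euclidean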
nextrec/loss/pointwise.py CHANGED
@@ -46,6 +46,7 @@ class WeightedBCELoss(nn.Module):
     If `auto_balance=True` and `pos_weight` is None, the positive weight is
     computed from the batch as (#neg / #pos) for stable imbalance handling.
     """
+
     def __init__(
         self,
         pos_weight: float | torch.Tensor | None = None,
@@ -59,11 +60,14 @@ class WeightedBCELoss(nn.Module):
         self.auto_balance = auto_balance
 
         if pos_weight is not None:
-            self.register_buffer("pos_weight", torch.as_tensor(pos_weight, dtype=torch.float32),)
+            self.register_buffer(
+                "pos_weight",
+                torch.as_tensor(pos_weight, dtype=torch.float32),
+            )
         else:
             self.pos_weight = None
 
-    def _resolve_pos_weight(self, labels: torch.Tensor) -> torch.Tensor:
+    def resolve_pos_weight(self, labels: torch.Tensor) -> torch.Tensor:
         if self.pos_weight is not None:
             return self.pos_weight.to(device=labels.device)
 
@@ -77,7 +81,7 @@ class WeightedBCELoss(nn.Module):
 
     def forward(self, inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
         labels = labels.float()
-        current_pos_weight = self._resolve_pos_weight(labels)
+        current_pos_weight = self.resolve_pos_weight(labels)
         current_pos_weight = current_pos_weight.to(inputs.dtype)
 
         if self.logits:
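Note: the renamed resolver feeds the auto_balance rule quoted in the docstring: with pos_weight unset, the weight is derived from the batch as (#neg / #pos). The resolver body sits outside this hunk, so the following is a sketch of that rule only, not necessarily the exact implementation:

    import torch

    def batch_pos_weight(labels: torch.Tensor) -> torch.Tensor:
        pos = labels.sum().clamp(min=1.0)  # guard against all-negative batches
        neg = labels.numel() - labels.sum()
        return neg / pos                   # (#neg / #pos), per the docstring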
@@ -120,23 +124,27 @@ class FocalLoss(nn.Module):
         if inputs.dim() > 1 and inputs.size(1) > 1:
             log_probs = F.log_softmax(inputs, dim=1)
             probs = log_probs.exp()
-            targets_one_hot = F.one_hot(targets.long(), num_classes=inputs.size(1)).float()
+            targets_one_hot = F.one_hot(
+                targets.long(), num_classes=inputs.size(1)
+            ).float()
 
-            alpha = self._get_alpha(inputs)
+            alpha = self.get_alpha(inputs)
             alpha_factor = targets_one_hot * alpha
             focal_weight = (1.0 - probs) ** self.gamma
             loss = torch.sum(alpha_factor * focal_weight * (-log_probs), dim=1)
         else:
             targets = targets.float()
             if self.logits:
-                ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+                ce_loss = F.binary_cross_entropy_with_logits(
+                    inputs, targets, reduction="none"
+                )
                 probs = torch.sigmoid(inputs)
             else:
                 ce_loss = F.binary_cross_entropy(inputs, targets, reduction="none")
                 probs = torch.clamp(inputs, min=1e-6, max=1 - 1e-6)
 
             p_t = probs * targets + (1 - probs) * (1 - targets)
-            alpha_factor = self._get_binary_alpha(targets, inputs.device)
+            alpha_factor = self.get_binary_alpha(targets, inputs.device)
             focal_weight = (1.0 - p_t) ** self.gamma
             loss = alpha_factor * focal_weight * ce_loss
         if self.reduction == "mean":
@@ -145,27 +153,35 @@ class FocalLoss(nn.Module):
             return loss.mean()
         if self.reduction == "sum":
             return loss.sum()
         return loss
 
-    def _get_alpha(self, inputs: torch.Tensor) -> torch.Tensor:
+    def get_alpha(self, inputs: torch.Tensor) -> torch.Tensor:
         if self.alpha is None:
             return torch.ones_like(inputs)
         if isinstance(self.alpha, torch.Tensor):
             return self.alpha.to(inputs.device)
-        alpha_tensor = torch.tensor(self.alpha, device=inputs.device, dtype=inputs.dtype)
+        alpha_tensor = torch.tensor(
+            self.alpha, device=inputs.device, dtype=inputs.dtype
+        )
         return alpha_tensor
 
-    def _get_binary_alpha(self, targets: torch.Tensor, device: torch.device) -> torch.Tensor:
+    def get_binary_alpha(
+        self, targets: torch.Tensor, device: torch.device
+    ) -> torch.Tensor:
         if self.alpha is None:
             return torch.ones_like(targets)
         if isinstance(self.alpha, (float, int)):
-            return torch.where(targets == 1, self.alpha, 1 - float(self.alpha)).to(device)
+            return torch.where(targets == 1, self.alpha, 1 - float(self.alpha)).to(
+                device
+            )
         alpha_tensor = torch.tensor(self.alpha, device=device, dtype=targets.dtype)
         return torch.where(targets == 1, alpha_tensor, 1 - alpha_tensor)
 
+
 class ClassBalancedFocalLoss(nn.Module):
     """
     Focal loss weighted by effective number of samples per class.
     Reference: "Class-Balanced Loss Based on Effective Number of Samples"
     """
+
     def __init__(
         self,
         class_counts: Sequence[int] | torch.Tensor,
@@ -183,7 +199,9 @@ class ClassBalancedFocalLoss(nn.Module):
         self.register_buffer("class_weights", weights)
 
     def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
-        focal = FocalLoss(gamma=self.gamma, alpha=self.class_weights, reduction="none", logits=True)
+        focal = FocalLoss(
+            gamma=self.gamma, alpha=self.class_weights, reduction="none", logits=True
+        )
         loss = focal(inputs, targets)
         if self.reduction == "mean":
             return loss.mean()
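Note: ClassBalancedFocalLoss builds its per-class alpha from Cui et al.'s effective number of samples (the reference in the docstring). The weight construction sits above this hunk, so the following is the paper's formula as a sketch, with the final normalization convention an assumption:

    import torch

    def effective_number_weights(class_counts: torch.Tensor, beta: float = 0.999) -> torch.Tensor:
        # effective number per class: E_c = (1 - beta^n_c) / (1 - beta)
        effective_num = 1.0 - torch.pow(beta, class_counts.float())
        weights = (1.0 - beta) / effective_num
        # common convention: rescale so the weights sum to the number of classes
        return weights * class_counts.numel() / weights.sum()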
nextrec/models/generative/hstu.py CHANGED
@@ -1,14 +1,14 @@
 """
 [Info: this version is not released yet, i need to more research on source code and paper]
 Date: create on 01/12/2025
-Checkpoint: edit on 01/12/2025
+Checkpoint: edit on 01/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 Reference:
 [1] Meta AI. Generative Recommenders (HSTU encoder) — https://github.com/meta-recsys/generative-recommenders
 [2] Ma W, Li P, Chen C, et al. Actions speak louder than words: Trillion-parameter sequential transducers for generative recommendations. arXiv:2402.17152.
 
-Hierarchical Sequential Transduction Unit (HSTU) is the core encoder behind
-Meta’s Generative Recommenders. It replaces softmax attention with lightweight
+Hierarchical Sequential Transduction Unit (HSTU) is the core encoder behind
+Meta’s Generative Recommenders. It replaces softmax attention with lightweight
 pointwise activations, enabling extremely deep stacks on long behavior sequences.
@@ -16,8 +16,8 @@ In each HSTU layer:
 (2) Softmax-free interactions combine QK^T with Relative Attention Bias (RAB) to encode distance
 (3) Aggregated context is modulated by U-gating and mapped back through an output projection
 
-Stacking layers yields an efficient causal encoder for next-item
-generation. With a tied-embedding LM head, HSTU forms
+Stacking layers yields an efficient causal encoder for next-item
+generation. With a tied-embedding LM head, HSTU forms
 a full generative recommendation model.
 
 Key Advantages:
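Note: the docstring's three numbered steps map directly onto HSTUPointwiseAttention in the hunks below. As a compact single-head paraphrase of one layer's data flow (this is a sketch, not the package's exact code: the multi-head split is collapsed, and the normalization here is a stand-in for the valid-key count used later in the diff):

    import torch
    import torch.nn.functional as F

    def hstu_attention_sketch(x, in_proj, out_proj, norm, rab, alpha, causal_mask):
        # (1) one projection yields gate U, values V, queries Q, keys K
        U, V, Q, K = in_proj(x).chunk(4, dim=-1)    # each [B, T, D]
        # (2) softmax-free interaction: SiLU(alpha * QK^T + RAB), causally masked
        logits = alpha * torch.einsum("btd,bsd->bts", Q, K) + rab  # rab: [T, T]
        attn = F.silu(logits) * (causal_mask == 0)  # zero out future positions
        attn = attn / attn.size(-1)                 # stand-in normalization
        ctx = torch.einsum("bts,bsd->btd", attn, V)
        # (3) U-gating on the normalized context, output projection, residual add
        return x + out_proj(norm(ctx) * U)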
@@ -75,7 +75,16 @@ def _relative_position_bucket(
     is_small = n < max_exact
 
     # when the distance is too far, do log scaling
-    large_val = max_exact + ((torch.log(n.float() / max_exact + 1e-6) / math.log(max_distance / max_exact)) * (num_buckets - max_exact)).long()
+    large_val = (
+        max_exact
+        + (
+            (
+                torch.log(n.float() / max_exact + 1e-6)
+                / math.log(max_distance / max_exact)
+            )
+            * (num_buckets - max_exact)
+        ).long()
+    )
     large_val = torch.clamp(large_val, max=num_buckets - 1)
 
     buckets = torch.where(is_small, n.long(), large_val)
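Note: this is T5-style relative-position bucketing: distances below max_exact each keep their own bucket, while larger distances share logarithmically spaced buckets and are clamped at num_buckets - 1. A worked example under the defaults used further down (num_buckets=32, max_distance=128), assuming max_exact = 16 since its definition falls outside this hunk:

    distance 5   -> bucket 5                                     (exact range)
    distance 64  -> 16 + floor(log(64/16) / log(128/16) * 16) = 26
    distance 200 -> clamped to bucket 31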
@@ -104,10 +113,19 @@ class RelativePositionBias(nn.Module):
         # positions: [T]
         ctx = torch.arange(seq_len, device=device)[:, None]
         mem = torch.arange(seq_len, device=device)[None, :]
-        rel_pos = mem - ctx # a matrix to describe all relative positions for each [i,j] pair, shape = [seq_len, seq_len]
-        buckets = _relative_position_bucket(rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance,) # map to buckets
-        values = self.embedding(buckets) # embedding vector for each [i,j] pair, shape = [seq_len, seq_len, embedding_dim=num_heads]
-        return values.permute(2, 0, 1).unsqueeze(0) # [1, num_heads, seq_len, seq_len]
+        rel_pos = (
+            mem - ctx
+        )  # a matrix to describe all relative positions for each [i,j] pair, shape = [seq_len, seq_len]
+        buckets = _relative_position_bucket(
+            rel_pos,
+            num_buckets=self.num_buckets,
+            max_distance=self.max_distance,
+        )  # map to buckets
+        values = self.embedding(
+            buckets
+        )  # embedding vector for each [i,j] pair, shape = [seq_len, seq_len, embedding_dim=num_heads]
+        return values.permute(2, 0, 1).unsqueeze(0)  # [1, num_heads, seq_len, seq_len]
+
 
 class HSTUPointwiseAttention(nn.Module):
     """
@@ -123,16 +141,18 @@ class HSTUPointwiseAttention(nn.Module):
         d_model: int,
         num_heads: int,
         dropout: float = 0.1,
-        alpha: float | None = None
+        alpha: float | None = None,
     ):
         super().__init__()
         if d_model % num_heads != 0:
-            raise ValueError(f"[HSTUPointwiseAttention Error] d_model({d_model}) % num_heads({num_heads}) != 0")
+            raise ValueError(
+                f"[HSTUPointwiseAttention Error] d_model({d_model}) % num_heads({num_heads}) != 0"
+            )
 
         self.d_model = d_model
         self.num_heads = num_heads
         self.d_head = d_model // num_heads
-        self.alpha = alpha if alpha is not None else (self.d_head ** -0.5)
+        self.alpha = alpha if alpha is not None else (self.d_head**-0.5)
         # project input to 4 * d_model for U, V, Q, K
         self.in_proj = nn.Linear(d_model, 4 * d_model, bias=True)
         # project output back to d_model
@@ -150,9 +170,9 @@
     def forward(
         self,
         x: torch.Tensor,
-        attn_mask: Optional[torch.Tensor] = None, # [T, T] with 0 or -inf
+        attn_mask: Optional[torch.Tensor] = None,  # [T, T] with 0 or -inf
         key_padding_mask: Optional[torch.Tensor] = None,  # [B, T], True = pad
-        rab: Optional[torch.Tensor] = None, # [1, H, T, T] or None
+        rab: Optional[torch.Tensor] = None,  # [1, H, T, T] or None
     ) -> torch.Tensor:
         B, T, D = x.shape
 
@@ -185,8 +205,8 @@
         # padding mask: key_padding_mask is usually [B, T], True = pad
         if key_padding_mask is not None:
             # valid: 1 for non-pad, 0 for pad
-            valid = (~key_padding_mask).float() # [B, T]
-            valid = valid.view(B, 1, 1, T) # [B, 1, 1, T]
+            valid = (~key_padding_mask).float()  # [B, T]
+            valid = valid.view(B, 1, 1, T)  # [B, 1, 1, T]
             allowed = allowed * valid
             logits = logits.masked_fill(valid == 0, float("-inf"))
 
@@ -197,7 +217,7 @@
 
         attn = attn / denom  # [B, H, T, T]
         AV = torch.matmul(attn, Vh)  # [B, H, T, d_head]
-        AV = AV.transpose(1, 2).contiguous().view(B, T, D) # reshape back to [B, T, D]
+        AV = AV.transpose(1, 2).contiguous().view(B, T, D)  # reshape back to [B, T, D]
         U_flat = Uh.transpose(1, 2).contiguous().view(B, T, D)
         y = self.out_proj(self.dropout(self.norm(AV) * U_flat))  # [B, T, D]
         return y
@@ -218,10 +238,20 @@ class HSTULayer(nn.Module):
         rab_max_distance: int = 128,
     ):
         super().__init__()
-        self.attn = HSTUPointwiseAttention(d_model=d_model, num_heads=num_heads, dropout=dropout)
+        self.attn = HSTUPointwiseAttention(
+            d_model=d_model, num_heads=num_heads, dropout=dropout
+        )
         self.dropout = nn.Dropout(dropout)
         self.use_rab_pos = use_rab_pos
-        self.rel_pos_bias = (RelativePositionBias(num_heads=num_heads, num_buckets=rab_num_buckets, max_distance=rab_max_distance) if use_rab_pos else None)
+        self.rel_pos_bias = (
+            RelativePositionBias(
+                num_heads=num_heads,
+                num_buckets=rab_num_buckets,
+                max_distance=rab_max_distance,
+            )
+            if use_rab_pos
+            else None
+        )
 
     def forward(
         self,
@@ -236,8 +266,10 @@
         device = x.device
         rab = None
         if self.use_rab_pos:
-            rab = self.rel_pos_bias(seq_len=T, device=device) # [1, H, T, T]
-        out = self.attn(x=x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, rab=rab)
+            rab = self.rel_pos_bias(seq_len=T, device=device)  # [1, H, T, T]
+        out = self.attn(
+            x=x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, rab=rab
+        )
         return x + self.dropout(out)
 
 
@@ -255,7 +287,7 @@ class HSTU(BaseModel):
         return "HSTU"
 
     @property
-    def task_type(self) -> str:
+    def default_task(self) -> str:
         return "multiclass"
 
     def __init__(
@@ -272,9 +304,9 @@
         use_rab_pos: bool = True,
         rab_num_buckets: int = 32,
         rab_max_distance: int = 128,
-
         tie_embeddings: bool = True,
         target: Optional[list[str] | str] = None,
+        task: str | list[str] | None = None,
         optimizer: str = "adam",
         optimizer_params: Optional[dict] = None,
         scheduler: Optional[str] = None,
@@ -288,17 +320,25 @@
         **kwargs,
     ):
         if not sequence_features:
-            raise ValueError("[HSTU Error] HSTU requires at least one SequenceFeature (user behavior history).")
+            raise ValueError(
+                "[HSTU Error] HSTU requires at least one SequenceFeature (user behavior history)."
+            )
 
         # demo version: use the first SequenceFeature as the main sequence
         self.history_feature = sequence_features[0]
 
-        hidden_dim = d_model or max(int(getattr(self.history_feature, "embedding_dim", 0) or 0), 32)
+        hidden_dim = d_model or max(
+            int(getattr(self.history_feature, "embedding_dim", 0) or 0), 32
+        )
         # Make hidden_dim divisible by num_heads
         if hidden_dim % num_heads != 0:
             hidden_dim = num_heads * math.ceil(hidden_dim / num_heads)
 
-        self.padding_idx = self.history_feature.padding_idx if self.history_feature.padding_idx is not None else 0
+        self.padding_idx = (
+            self.history_feature.padding_idx
+            if self.history_feature.padding_idx is not None
+            else 0
+        )
         self.vocab_size = self.history_feature.vocab_size
         self.max_seq_len = max_seq_len
 
@@ -307,7 +347,7 @@
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=self.task_type,
+            task=task or self.default_task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
@@ -326,8 +366,19 @@
         self.input_dropout = nn.Dropout(dropout)
 
         # HSTU layers
-        self.layers = nn.ModuleList([HSTULayer(d_model=hidden_dim, num_heads=num_heads, dropout=dropout, use_rab_pos=use_rab_pos,
-                                               rab_num_buckets=rab_num_buckets, rab_max_distance=rab_max_distance) for _ in range(num_layers)])
+        self.layers = nn.ModuleList(
+            [
+                HSTULayer(
+                    d_model=hidden_dim,
+                    num_heads=num_heads,
+                    dropout=dropout,
+                    use_rab_pos=use_rab_pos,
+                    rab_num_buckets=rab_num_buckets,
+                    rab_max_distance=rab_max_distance,
+                )
+                for _ in range(num_layers)
+            ]
+        )
 
         self.final_norm = nn.LayerNorm(hidden_dim)
         self.lm_head = nn.Linear(hidden_dim, self.vocab_size, bias=False)
@@ -343,8 +394,17 @@
         loss_params = loss_params or {}
         loss_params.setdefault("ignore_index", self.ignore_index)
 
-        self.compile(optimizer=optimizer, optimizer_params=optimizer_params, scheduler=scheduler, scheduler_params=scheduler_params, loss="crossentropy", loss_params=loss_params)
-        self.register_regularization_weights(embedding_attr="token_embedding", include_modules=["layers", "lm_head"])
+        self.compile(
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            scheduler=scheduler,
+            scheduler_params=scheduler_params,
+            loss="crossentropy",
+            loss_params=loss_params,
+        )
+        self.register_regularization_weights(
+            embedding_attr="token_embedding", include_modules=["layers", "lm_head"]
+        )
 
     def _build_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
         """
@@ -353,7 +413,7 @@
         """
        if self.causal_mask.numel() == 0 or self.causal_mask.size(0) < seq_len:
             mask = torch.full((seq_len, seq_len), float("-inf"), device=device)
-            mask = torch.triu(mask, diagonal=1)
+            mask = torch.triu(mask, diagonal=1)
             self.causal_mask = mask
         return self.causal_mask[:seq_len, :seq_len]
 
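Note: the cached mask is the standard additive causal mask: 0 on and below the diagonal, -inf strictly above it, so token i can only attend to positions <= i. For seq_len = 4:

    import torch

    mask = torch.triu(torch.full((4, 4), float("-inf")), diagonal=1)
    # [[0., -inf, -inf, -inf],
    #  [0.,   0., -inf, -inf],
    #  [0.,   0.,   0., -inf],
    #  [0.,   0.,   0.,   0.]]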
@@ -364,27 +424,31 @@
 
     def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
         seq = x[self.history_feature.name].long()  # [B, T_raw]
-        seq = self._trim_sequence(seq) # [B, T]
+        seq = self._trim_sequence(seq)  # [B, T]
 
         B, T = seq.shape
         device = seq.device
         # position ids: [B, T]
         pos_ids = torch.arange(T, device=device).unsqueeze(0).expand(B, -1)
-        token_emb = self.token_embedding(seq) # [B, T, D]
-        pos_emb = self.position_embedding(pos_ids) # [B, T, D]
+        token_emb = self.token_embedding(seq)  # [B, T, D]
+        pos_emb = self.position_embedding(pos_ids)  # [B, T, D]
         hidden_states = self.input_dropout(token_emb + pos_emb)
 
         # padding mask: True = pad
-        padding_mask = seq.eq(self.padding_idx) # [B, T]
+        padding_mask = seq.eq(self.padding_idx)  # [B, T]
         attn_mask = self._build_causal_mask(seq_len=T, device=device)  # [T, T]
 
         for layer in self.layers:
-            hidden_states = layer(x=hidden_states, attn_mask=attn_mask, key_padding_mask=padding_mask)
+            hidden_states = layer(
+                x=hidden_states, attn_mask=attn_mask, key_padding_mask=padding_mask
+            )
         hidden_states = self.final_norm(hidden_states)  # [B, T, D]
 
         valid_lengths = (~padding_mask).sum(dim=1)  # [B]
         last_index = (valid_lengths - 1).clamp(min=0)
-        last_hidden = hidden_states[torch.arange(B, device=device), last_index] # [B, D]
+        last_hidden = hidden_states[
+            torch.arange(B, device=device), last_index
+        ]  # [B, D]
 
         logits = self.lm_head(last_hidden)  # [B, vocab_size]
         return logits
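Note: the tail of forward gathers each sequence's final non-pad hidden state, which assumes right-padded inputs. The same idiom in isolation (hypothetical tensors, padding id 0):

    import torch

    hidden = torch.randn(2, 5, 8)                      # [B, T, D]
    seq = torch.tensor([[3, 7, 2, 0, 0],
                        [5, 1, 4, 9, 6]])              # 0 = padding
    valid_lengths = (seq != 0).sum(dim=1)              # tensor([3, 5])
    last_index = (valid_lengths - 1).clamp(min=0)      # tensor([2, 4])
    last_hidden = hidden[torch.arange(2), last_index]  # [2, 8]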
@@ -394,6 +458,8 @@
         y_true: [B] or [B, 1], the id of the next item.
         """
         if y_true is None:
-            raise ValueError("[HSTU-compute_loss] Training requires y_true (next item id).")
+            raise ValueError(
+                "[HSTU-compute_loss] Training requires y_true (next item id)."
+            )
         labels = y_true.view(-1).long()
         return self.loss_fn[0](y_pred, labels)