nextrec 0.4.5-py3-none-any.whl → 0.4.7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/callback.py +399 -21
- nextrec/basic/features.py +4 -0
- nextrec/basic/layers.py +103 -24
- nextrec/basic/metrics.py +71 -1
- nextrec/basic/model.py +285 -186
- nextrec/data/data_processing.py +1 -3
- nextrec/loss/loss_utils.py +73 -4
- nextrec/models/generative/__init__.py +16 -0
- nextrec/models/generative/hstu.py +110 -57
- nextrec/models/generative/rqvae.py +826 -0
- nextrec/models/match/dssm.py +5 -4
- nextrec/models/match/dssm_v2.py +4 -3
- nextrec/models/match/mind.py +5 -4
- nextrec/models/match/sdm.py +5 -4
- nextrec/models/match/youtube_dnn.py +5 -4
- nextrec/models/ranking/masknet.py +1 -1
- nextrec/utils/config.py +38 -1
- nextrec/utils/embedding.py +28 -0
- nextrec/utils/initializer.py +4 -4
- nextrec/utils/synthetic_data.py +19 -0
- nextrec-0.4.7.dist-info/METADATA +376 -0
- {nextrec-0.4.5.dist-info → nextrec-0.4.7.dist-info}/RECORD +26 -25
- nextrec-0.4.5.dist-info/METADATA +0 -357
- {nextrec-0.4.5.dist-info → nextrec-0.4.7.dist-info}/WHEEL +0 -0
- {nextrec-0.4.5.dist-info → nextrec-0.4.7.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.5.dist-info → nextrec-0.4.7.dist-info}/licenses/LICENSE +0 -0
nextrec/data/data_processing.py
CHANGED

@@ -25,9 +25,7 @@ def get_column_data(data: dict | pd.DataFrame, name: str):
     raise KeyError(f"Unsupported data type for extracting column {name}")


-def split_dict_random(
-    data_dict: dict, test_size: float = 0.2, random_state: int | None = None
-):
+def split_dict_random(data_dict, test_size=0.2, random_state=None):

     lengths = [len(v) for v in data_dict.values()]
     if len(set(lengths)) != 1:
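For context, a usage sketch of the simplified `split_dict_random` signature; the `(train, test)` return value is an assumption, since this hunk only shows the signature and the length check:

```python
from nextrec.data.data_processing import split_dict_random

data = {
    "user_id": [1, 2, 3, 4, 5],
    "item_id": [10, 20, 30, 40, 50],
}

# All value lists must have the same length, otherwise the length check raises.
# Assumed return value: a (train_dict, test_dict) pair of dicts.
train, test = split_dict_random(data, test_size=0.2, random_state=42)
```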
nextrec/loss/loss_utils.py
CHANGED

@@ -2,10 +2,12 @@
 Loss utilities for NextRec.

 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 17/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """

+from typing import Literal
+
 import torch.nn as nn

 from nextrec.loss.listwise import (

@@ -30,14 +32,81 @@ VALID_TASK_TYPES = [
     "regression",
 ]

+# Define all supported loss types
+LossType = Literal[
+    # Pointwise losses
+    "bce",
+    "binary_crossentropy",
+    "weighted_bce",
+    "focal",
+    "focal_loss",
+    "cb_focal",
+    "class_balanced_focal",
+    "crossentropy",
+    "ce",
+    "mse",
+    "mae",
+    # Pairwise ranking losses
+    "bpr",
+    "hinge",
+    "triplet",
+    # Listwise ranking losses
+    "sampled_softmax",
+    "softmax",
+    "infonce",
+    "listnet",
+    "listmle",
+    "approx_ndcg",
+]
+

-def
+def build_cb_focal(kw):
     if "class_counts" not in kw:
         raise ValueError("class_balanced_focal requires class_counts")
     return ClassBalancedFocalLoss(**kw)


-def get_loss_fn(loss=None, **kw):
+def get_loss_fn(loss: LossType | nn.Module | None = None, **kw) -> nn.Module:
+    """
+    Get loss function by name or return the provided loss module.
+
+    Args:
+        loss: Loss function name or nn.Module instance. Supported options:
+
+            **Pointwise Losses:**
+            - "bce", "binary_crossentropy": Binary Cross-Entropy Loss
+            - "weighted_bce": Weighted Binary Cross-Entropy Loss
+            - "focal", "focal_loss": Focal Loss (for class imbalance)
+            - "cb_focal", "class_balanced_focal": Class-Balanced Focal Loss (requires class_counts parameter)
+            - "crossentropy", "ce": Cross-Entropy Loss for multi-class classification
+            - "mse": Mean Squared Error Loss
+            - "mae": Mean Absolute Error Loss
+
+            **Pairwise Ranking Losses:**
+            - "bpr": Bayesian Personalized Ranking Loss
+            - "hinge": Hinge Loss
+            - "triplet": Triplet Loss
+
+            **Listwise Ranking Losses:**
+            - "sampled_softmax", "softmax": Sampled Softmax Loss
+            - "infonce": InfoNCE Loss
+            - "listnet": ListNet Loss
+            - "listmle": ListMLE Loss
+            - "approx_ndcg": Approximate NDCG Loss
+
+        **kw: Additional keyword arguments passed to the loss function
+
+    Returns:
+        nn.Module: Loss function instance
+
+    Raises:
+        ValueError: If loss is None or unsupported type
+
+    Examples:
+        >>> loss_fn = get_loss_fn("bce")
+        >>> loss_fn = get_loss_fn("focal", alpha=0.25, gamma=2.0)
+        >>> loss_fn = get_loss_fn("cb_focal", class_counts=[100, 50, 200])
+    """
     if isinstance(loss, nn.Module):
         return loss
     if loss is None:

@@ -49,7 +118,7 @@ def get_loss_fn(loss=None, **kw):
     if loss in ["focal", "focal_loss"]:
         return FocalLoss(**kw)
     if loss in ["cb_focal", "class_balanced_focal"]:
-        return
+        return build_cb_focal(kw)
     if loss in ["crossentropy", "ce"]:
         return nn.CrossEntropyLoss(**kw)
     if loss == "mse":
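A short usage sketch based on the docstring examples above; loss names come from the new `LossType` literal, and an existing `nn.Module` is passed through unchanged:

```python
import torch.nn as nn

from nextrec.loss.loss_utils import get_loss_fn

bce = get_loss_fn("bce")
focal = get_loss_fn("focal", alpha=0.25, gamma=2.0)
cb_focal = get_loss_fn("cb_focal", class_counts=[100, 50, 200])

# nn.Module instances are returned as-is.
mse = get_loss_fn(nn.MSELoss())
```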
nextrec/models/generative/__init__.py
ADDED

@@ -0,0 +1,16 @@
+"""
+Generative Recommendation Models
+
+This module contains generative models for recommendation tasks.
+"""
+
+from nextrec.models.generative.hstu import HSTU
+from nextrec.models.generative.rqvae import (
+    RQVAE,
+    RQ,
+    VQEmbedding,
+    BalancedKmeans,
+    kmeans,
+)
+
+__all__ = ["HSTU", "RQVAE", "RQ", "VQEmbedding", "BalancedKmeans", "kmeans"]
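With this `__init__`, the generative models can be imported from the subpackage directly rather than from the individual modules, e.g.:

```python
from nextrec.models.generative import HSTU, RQVAE, BalancedKmeans
```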
nextrec/models/generative/hstu.py
CHANGED

@@ -1,7 +1,7 @@
 """
 [Info: this version is not released yet, i need to more research on source code and paper]
 Date: create on 01/12/2025
-Checkpoint: edit on
+Checkpoint: edit on 11/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 Reference:
 [1] Meta AI. Generative Recommenders (HSTU encoder) — https://github.com/meta-recsys/generative-recommenders
@@ -55,10 +55,13 @@ import torch.nn as nn
 import torch.nn.functional as F

 from nextrec.basic.model import BaseModel
+from nextrec.basic.layers import RMSNorm, EmbeddingLayer
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature

+from nextrec.utils.model import select_features

-
+
+def relative_position_bucket(
     relative_position: torch.Tensor,
     num_buckets: int = 32,
     max_distance: int = 128,
@@ -116,7 +119,7 @@ class RelativePositionBias(nn.Module):
         rel_pos = (
             mem - ctx
         )  # a matrix to describe all relative positions for each [i,j] pair, shape = [seq_len, seq_len]
-        buckets =
+        buckets = relative_position_bucket(
             rel_pos,
             num_buckets=self.num_buckets,
             max_distance=self.max_distance,
@@ -138,39 +141,40 @@ class HSTUPointwiseAttention(nn.Module):

     def __init__(
         self,
-
+        hidden_dim: int,
         num_heads: int,
         dropout: float = 0.1,
         alpha: float | None = None,
+        use_rms_norm: bool = False,
     ):
         super().__init__()
-        if
+        if hidden_dim % num_heads != 0:
             raise ValueError(
-                f"[HSTUPointwiseAttention Error]
+                f"[HSTUPointwiseAttention Error] hidden_dim({hidden_dim}) % num_heads({num_heads}) != 0"
             )

-        self.
+        self.hidden_dim = hidden_dim
         self.num_heads = num_heads
-        self.
-        self.alpha = alpha if alpha is not None else (self.
-        # project input to 4 *
-        self.in_proj = nn.Linear(
-        # project output back to
-        self.out_proj = nn.Linear(
+        self.head_dim = hidden_dim // num_heads
+        self.alpha = alpha if alpha is not None else (self.head_dim**-0.5)
+        # project input to 4 * hidden_dim for U, V, Q, K
+        self.in_proj = nn.Linear(hidden_dim, 4 * hidden_dim, bias=True)
+        # project output back to hidden_dim
+        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
         self.dropout = nn.Dropout(dropout)
-        self.norm = nn.LayerNorm(
+        self.norm = RMSNorm(hidden_dim) if use_rms_norm else nn.LayerNorm(hidden_dim)

-    def
+    def reshape_heads(self, x: torch.Tensor) -> torch.Tensor:
         """
-        [B, T, D] -> [B, H, T,
+        [B, T, D] -> [B, H, T, head_dim]
         """
         B, T, D = x.shape
-        return x.view(B, T, self.num_heads, self.
+        return x.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

     def forward(
         self,
         x: torch.Tensor,
-
+        attention_mask: Optional[torch.Tensor] = None,  # [T, T] with 0 or -inf
         key_padding_mask: Optional[torch.Tensor] = None,  # [B, T], True = pad
         rab: Optional[torch.Tensor] = None,  # [1, H, T, T] or None
     ) -> torch.Tensor:
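The new `use_rms_norm` flag swaps the post-attention `nn.LayerNorm` for the `RMSNorm` now imported from `nextrec.basic.layers`. For readers unfamiliar with it, a minimal reference sketch of RMS normalization; the actual class in `layers.py` may differ in details such as `eps` or weight handling:

```python
import torch
import torch.nn as nn


class RMSNormSketch(nn.Module):
    """Root-mean-square norm: rescale by 1/RMS(x); no mean subtraction, no bias."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight
```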
@@ -180,10 +184,10 @@ class HSTUPointwiseAttention(nn.Module):
         h = F.silu(self.in_proj(x))  # [B, T, 4D]
         U, V, Q, K = h.chunk(4, dim=-1)  # each [B, T, D]

-        Qh = self.
-        Kh = self.
-        Vh = self.
-        Uh = self.
+        Qh = self.reshape_heads(Q)  # [B, H, T, d_head]
+        Kh = self.reshape_heads(K)  # [B, H, T, d_head]
+        Vh = self.reshape_heads(V)  # [B, H, T, d_head]
+        Uh = self.reshape_heads(U)  # [B, H, T, d_head]

         # attention logits: QK^T (without 1/sqrt(d) and softmax)
         logits = torch.matmul(Qh, Kh.transpose(-2, -1)) * self.alpha  # [B, H, T, T]
@@ -197,10 +201,10 @@ class HSTUPointwiseAttention(nn.Module):
         # 1 indicates that the (query i, key j) pair is a valid attention pair; 0 indicates it is masked out
         allowed = torch.ones_like(logits, dtype=torch.float)  # [B, H, T, T]

-        # causal mask:
-        if
-            allowed = allowed * (
-            logits = logits +
+        # causal mask: attention_mask is usually an upper triangular matrix of -inf with shape [T, T]
+        if attention_mask is not None:
+            allowed = allowed * (attention_mask.view(1, 1, T, T) == 0).float()
+            logits = logits + attention_mask.view(1, 1, T, T)

         # padding mask: key_padding_mask is usually [B, T], True = pad
         if key_padding_mask is not None:
@@ -211,12 +215,15 @@ class HSTUPointwiseAttention(nn.Module):
             logits = logits.masked_fill(valid == 0, float("-inf"))

         # Eq.(2): A(X)V(X) = φ2(QK^T + rab) V(X) / N
-
+        # Note: F.silu(-inf) = nan, so we need to handle -inf values carefully
+        # Replace -inf with a very negative value before silu to avoid nan
+        logits_safe = logits.masked_fill(torch.isinf(logits) & (logits < 0), -1e9)
+        attention = F.silu(logits_safe)  # [B, H, T, T]
         denom = allowed.sum(dim=-1, keepdim=True)  # [B, H, T, 1]
         denom = denom.clamp(min=1.0)

-
-        AV = torch.matmul(
+        attention = attention / denom  # [B, H, T, T]
+        AV = torch.matmul(attention, Vh)  # [B, H, T, head_dim]
         AV = AV.transpose(1, 2).contiguous().view(B, T, D)  # reshape back to [B, T, D]
         U_flat = Uh.transpose(1, 2).contiguous().view(B, T, D)
         y = self.out_proj(self.dropout(self.norm(AV) * U_flat))  # [B, T, D]
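The `-1e9` clamp is needed because HSTU applies SiLU directly to the masked logits instead of a softmax: `silu(x) = x * sigmoid(x)`, so `silu(-inf)` evaluates to `-inf * 0 = nan`. A quick check of the failure mode and the fix used above:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([0.5, float("-inf")])
print(F.silu(logits))  # second entry is nan: -inf * sigmoid(-inf) = -inf * 0

safe = logits.masked_fill(torch.isinf(logits) & (logits < 0), -1e9)
print(F.silu(safe))    # second entry is -0.0, i.e. effectively masked out
```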
@@ -230,16 +237,20 @@ class HSTULayer(nn.Module):

     def __init__(
         self,
-
+        hidden_dim: int,
         num_heads: int,
         dropout: float = 0.1,
         use_rab_pos: bool = True,
         rab_num_buckets: int = 32,
         rab_max_distance: int = 128,
+        use_rms_norm: bool = False,
     ):
         super().__init__()
-        self.
-
+        self.attention = HSTUPointwiseAttention(
+            hidden_dim=hidden_dim,
+            num_heads=num_heads,
+            dropout=dropout,
+            use_rms_norm=use_rms_norm,
         )
         self.dropout = nn.Dropout(dropout)
         self.use_rab_pos = use_rab_pos
@@ -256,7 +267,7 @@ class HSTULayer(nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-
+        attention_mask: Optional[torch.Tensor] = None,
         key_padding_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
@@ -267,8 +278,11 @@
         rab = None
         if self.use_rab_pos:
             rab = self.rel_pos_bias(seq_len=T, device=device)  # [1, H, T, T]
-        out = self.
-            x=x,
+        out = self.attention(
+            x=x,
+            attention_mask=attention_mask,
+            key_padding_mask=key_padding_mask,
+            rab=rab,
         )
         return x + self.dropout(out)

@@ -295,7 +309,8 @@ class HSTU(BaseModel):
         sequence_features: list[SequenceFeature],
         dense_features: Optional[list[DenseFeature]] = None,
         sparse_features: Optional[list[SparseFeature]] = None,
-
+        item_history: str = "item_history",
+        hidden_dim: Optional[int] = None,
         num_heads: int = 8,
         num_layers: int = 4,
         max_seq_len: int = 200,
@@ -304,6 +319,8 @@
         use_rab_pos: bool = True,
         rab_num_buckets: int = 32,
         rab_max_distance: int = 128,
+        # Normalization settings
+        use_rms_norm: bool = False,
         tie_embeddings: bool = True,
         target: Optional[list[str] | str] = None,
         task: str | list[str] | None = None,
@@ -324,22 +341,23 @@
                 "[HSTU Error] HSTU requires at least one SequenceFeature (user behavior history)."
             )

-
-
+        self.item_history_feature = select_features(
+            sequence_features, [item_history], "item_history"
+        )[0]

-        hidden_dim =
-            int(getattr(self.
+        self.hidden_dim = hidden_dim or max(
+            int(getattr(self.item_history_feature, "embedding_dim", 0) or 0), 32
         )
         # Make hidden_dim divisible by num_heads
-        if hidden_dim % num_heads != 0:
-            hidden_dim = num_heads * math.ceil(hidden_dim / num_heads)
+        if self.hidden_dim % num_heads != 0:
+            self.hidden_dim = num_heads * math.ceil(self.hidden_dim / num_heads)

         self.padding_idx = (
-            self.
-            if self.
+            self.item_history_feature.padding_idx
+            if self.item_history_feature.padding_idx is not None
             else 0
         )
-        self.vocab_size = self.
+        self.vocab_size = self.item_history_feature.vocab_size
         self.max_seq_len = max_seq_len

         super().__init__(
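In other words, when `hidden_dim` is not passed it defaults to the history feature's `embedding_dim` (floored at 32) and is then rounded up to the next multiple of `num_heads`:

```python
import math

embedding_dim, num_heads = 50, 8
hidden_dim = max(embedding_dim, 32)  # 50
if hidden_dim % num_heads != 0:
    hidden_dim = num_heads * math.ceil(hidden_dim / num_heads)
print(hidden_dim)  # 56, the next multiple of 8 at or above 50
```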
@@ -356,32 +374,51 @@
             **kwargs,
         )

+        # Optional contextual encoders (user/item attributes, real-time context, etc.)
+        self.context_features = [
+            feat
+            for feat in self.all_features
+            if feat.name != self.item_history_feature.name
+        ]
+        self.context_embedding = (
+            EmbeddingLayer(self.context_features) if self.context_features else None
+        )
+        self.context_proj = (
+            nn.Linear(self.context_embedding.output_dim, self.hidden_dim)
+            if self.context_embedding is not None
+            else None
+        )
+        self.context_dropout = nn.Dropout(dropout) if self.context_embedding else None
+
         # token & position embedding (paper usually includes pos embedding / RAB in encoder)
         self.token_embedding = nn.Embedding(
             num_embeddings=self.vocab_size,
-            embedding_dim=hidden_dim,
+            embedding_dim=self.hidden_dim,
             padding_idx=self.padding_idx,
         )
-        self.position_embedding = nn.Embedding(max_seq_len, hidden_dim)
+        self.position_embedding = nn.Embedding(max_seq_len, self.hidden_dim)
         self.input_dropout = nn.Dropout(dropout)

         # HSTU layers
         self.layers = nn.ModuleList(
             [
                 HSTULayer(
-
+                    hidden_dim=self.hidden_dim,
                     num_heads=num_heads,
                     dropout=dropout,
                     use_rab_pos=use_rab_pos,
                     rab_num_buckets=rab_num_buckets,
                     rab_max_distance=rab_max_distance,
+                    use_rms_norm=use_rms_norm,
                 )
                 for _ in range(num_layers)
             ]
         )

-        self.final_norm =
-
+        self.final_norm = (
+            RMSNorm(self.hidden_dim) if use_rms_norm else nn.LayerNorm(self.hidden_dim)
+        )
+        self.lm_head = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
         if tie_embeddings:
             self.lm_head.weight = self.token_embedding.weight

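`tie_embeddings=True` keeps the standard tied-embedding trick: the `lm_head` output projection reuses the token embedding matrix, so both modules train a single parameter tensor. A self-contained illustration of the tying shown above:

```python
import torch.nn as nn

vocab_size, hidden_dim = 1000, 64
token_embedding = nn.Embedding(vocab_size, hidden_dim)
lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)

lm_head.weight = token_embedding.weight  # both weights are [vocab_size, hidden_dim]
assert lm_head.weight is token_embedding.weight  # one shared Parameter object
```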
@@ -403,10 +440,11 @@
             loss_params=loss_params,
         )
         self.register_regularization_weights(
-            embedding_attr="token_embedding",
+            embedding_attr="token_embedding",
+            include_modules=["layers", "lm_head", "context_proj"],
         )

-    def
+    def build_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
         """
         build causal mask of shape [T, T]: upper triangle is -inf, others are 0.
         This will be added to the logits to simulate causal structure.
@@ -417,14 +455,14 @@
         self.causal_mask = mask
         return self.causal_mask[:seq_len, :seq_len]

-    def
+    def trim_sequence(self, seq: torch.Tensor) -> torch.Tensor:
         if seq.size(1) <= self.max_seq_len:
             return seq
         return seq[:, -self.max_seq_len :]

     def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
-        seq = x[self.
-        seq = self.
+        seq = x[self.item_history_feature.name].long()  # [B, T_raw]
+        seq = self.trim_sequence(seq)  # [B, T]

         B, T = seq.shape
         device = seq.device
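The `build_causal_mask` docstring above describes an additive mask with `-inf` strictly above the diagonal and zeros elsewhere; a minimal sketch of such a mask built with `torch.triu` (the actual method body, including its caching, is outside this hunk):

```python
import torch

def causal_mask_sketch(seq_len: int, device: torch.device) -> torch.Tensor:
    # -inf strictly above the diagonal, 0 on and below it
    mask = torch.full((seq_len, seq_len), float("-inf"), device=device)
    return torch.triu(mask, diagonal=1)

print(causal_mask_sketch(3, torch.device("cpu")))
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])
```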
@@ -436,20 +474,35 @@

         # padding mask:True = pad
         padding_mask = seq.eq(self.padding_idx)  # [B, T]
-
+        attention_mask = self.build_causal_mask(seq_len=T, device=device)  # [T, T]

         for layer in self.layers:
             hidden_states = layer(
-                x=hidden_states,
+                x=hidden_states,
+                attention_mask=attention_mask,
+                key_padding_mask=padding_mask,
             )
         hidden_states = self.final_norm(hidden_states)  # [B, T, D]

         valid_lengths = (~padding_mask).sum(dim=1)  # [B]
         last_index = (valid_lengths - 1).clamp(min=0)
+
+        # For sequences with no valid tokens, we use position 0's hidden state
+        # In production, these sequences should be filtered out before inference
         last_hidden = hidden_states[
             torch.arange(B, device=device), last_index
         ]  # [B, D]
+
+        if self.context_embedding is not None and self.context_proj is not None:
+            context_repr = self.context_embedding(
+                x, self.context_features, squeeze_dim=True
+            )  # [B, D_ctx]
+            context_repr = self.context_proj(context_repr)  # [B, D]
+            if self.context_dropout is not None:
+                context_repr = self.context_dropout(context_repr)
+            # fuse contextual signal into the autoregressive token summary
+            last_hidden = last_hidden + context_repr
+
         logits = self.lm_head(last_hidden)  # [B, vocab_size]
         return logits

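The `hidden_states[torch.arange(B, device=device), last_index]` indexing above is the usual gather of the last non-padded position per sequence; a self-contained illustration of that pattern:

```python
import torch

B, T, D = 2, 4, 3
hidden_states = torch.arange(B * T * D, dtype=torch.float).view(B, T, D)
padding_mask = torch.tensor(
    [[False, False, True, True],      # 2 valid tokens
     [False, False, False, False]]    # 4 valid tokens
)

valid_lengths = (~padding_mask).sum(dim=1)     # tensor([2, 4])
last_index = (valid_lengths - 1).clamp(min=0)  # tensor([1, 3])
last_hidden = hidden_states[torch.arange(B), last_index]
print(last_hidden.shape)  # torch.Size([2, 3]); one D-dim vector per sequence
```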