nextrec 0.3.6__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +1 -1
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -5
- nextrec/basic/callback.py +1 -0
- nextrec/basic/features.py +30 -22
- nextrec/basic/layers.py +244 -113
- nextrec/basic/loggers.py +62 -43
- nextrec/basic/metrics.py +268 -119
- nextrec/basic/model.py +1373 -443
- nextrec/basic/session.py +10 -3
- nextrec/cli.py +498 -0
- nextrec/data/__init__.py +19 -25
- nextrec/data/batch_utils.py +11 -3
- nextrec/data/data_processing.py +42 -24
- nextrec/data/data_utils.py +26 -15
- nextrec/data/dataloader.py +303 -96
- nextrec/data/preprocessor.py +320 -199
- nextrec/loss/listwise.py +17 -9
- nextrec/loss/loss_utils.py +7 -8
- nextrec/loss/pairwise.py +2 -0
- nextrec/loss/pointwise.py +30 -12
- nextrec/models/generative/hstu.py +106 -40
- nextrec/models/match/dssm.py +82 -69
- nextrec/models/match/dssm_v2.py +72 -58
- nextrec/models/match/mind.py +175 -108
- nextrec/models/match/sdm.py +104 -88
- nextrec/models/match/youtube_dnn.py +73 -60
- nextrec/models/multi_task/esmm.py +53 -39
- nextrec/models/multi_task/mmoe.py +70 -47
- nextrec/models/multi_task/ple.py +107 -50
- nextrec/models/multi_task/poso.py +121 -41
- nextrec/models/multi_task/share_bottom.py +54 -38
- nextrec/models/ranking/afm.py +172 -45
- nextrec/models/ranking/autoint.py +84 -61
- nextrec/models/ranking/dcn.py +59 -42
- nextrec/models/ranking/dcn_v2.py +64 -23
- nextrec/models/ranking/deepfm.py +36 -26
- nextrec/models/ranking/dien.py +158 -102
- nextrec/models/ranking/din.py +88 -60
- nextrec/models/ranking/fibinet.py +55 -35
- nextrec/models/ranking/fm.py +32 -26
- nextrec/models/ranking/masknet.py +95 -34
- nextrec/models/ranking/pnn.py +34 -31
- nextrec/models/ranking/widedeep.py +37 -29
- nextrec/models/ranking/xdeepfm.py +63 -41
- nextrec/utils/__init__.py +61 -32
- nextrec/utils/config.py +490 -0
- nextrec/utils/device.py +52 -12
- nextrec/utils/distributed.py +141 -0
- nextrec/utils/embedding.py +1 -0
- nextrec/utils/feature.py +1 -0
- nextrec/utils/file.py +32 -11
- nextrec/utils/initializer.py +61 -16
- nextrec/utils/optimizer.py +25 -9
- nextrec/utils/synthetic_data.py +531 -0
- nextrec/utils/tensor.py +24 -13
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/METADATA +15 -5
- nextrec-0.4.2.dist-info/RECORD +69 -0
- nextrec-0.4.2.dist-info/entry_points.txt +2 -0
- nextrec-0.3.6.dist-info/RECORD +0 -64
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
nextrec/loss/listwise.py
CHANGED

@@ -20,10 +20,14 @@ class SampledSoftmaxLoss(nn.Module):
         super().__init__()
         self.reduction = reduction
 
-    def forward(self, pos_logits: torch.Tensor, neg_logits: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, pos_logits: torch.Tensor, neg_logits: torch.Tensor
+    ) -> torch.Tensor:
         pos_logits = pos_logits.unsqueeze(1)
         all_logits = torch.cat([pos_logits, neg_logits], dim=1)
-        targets = torch.zeros(all_logits.size(0), dtype=torch.long, device=all_logits.device)
+        targets = torch.zeros(
+            all_logits.size(0), dtype=torch.long, device=all_logits.device
+        )
         loss = F.cross_entropy(all_logits, targets, reduction=self.reduction)
         return loss
 
@@ -87,7 +91,11 @@ class ListMLELoss(nn.Module):
     def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
         sorted_labels, sorted_indices = torch.sort(labels, descending=True, dim=1)
         batch_size, list_size = scores.shape
-        batch_indices = torch.arange(batch_size, device=scores.device).unsqueeze(1).expand(-1, list_size)
+        batch_indices = (
+            torch.arange(batch_size, device=scores.device)
+            .unsqueeze(1)
+            .expand(-1, list_size)
+        )
         sorted_scores = scores[batch_indices, sorted_indices]
 
         loss = torch.tensor(0.0, device=scores.device)
@@ -139,19 +147,19 @@ class ApproxNDCGLoss(nn.Module):
         device = scores.device
 
         # diff[b, i, j] = (s_j - s_i) / T
-        scores_i = scores.unsqueeze(2)
-        scores_j = scores.unsqueeze(1)
+        scores_i = scores.unsqueeze(2)  # [B, L, 1]
+        scores_j = scores.unsqueeze(1)  # [B, 1, L]
         diff = (scores_j - scores_i) / self.temperature  # [B, L, L]
 
-        P_ji = torch.sigmoid(diff)
+        P_ji = torch.sigmoid(diff)  # [B, L, L]
         eye = torch.eye(list_size, device=device).unsqueeze(0)  # [1, L, L]
         P_ji = P_ji * (1.0 - eye)
 
-        exp_rank = 1.0 + P_ji.sum(dim=-1)
+        exp_rank = 1.0 + P_ji.sum(dim=-1)  # [B, L]
 
         discounts = 1.0 / torch.log2(exp_rank + 1.0)  # [B, L]
 
-        gains = torch.pow(2.0, labels) - 1.0
+        gains = torch.pow(2.0, labels) - 1.0  # [B, L]
         approx_dcg = torch.sum(gains * discounts, dim=1)  # [B]
 
         ideal_dcg = self._ideal_dcg(labels, k)  # [B]
@@ -163,4 +171,4 @@ class ApproxNDCGLoss(nn.Module):
             return loss.mean()
         if self.reduction == "sum":
             return loss.sum()
-        return loss
\ No newline at end of file
+        return loss
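
Note: SampledSoftmaxLoss above reduces sampled softmax to an ordinary cross-entropy by concatenating the positive logit at column 0 and using class 0 as the target for every row. A standalone sketch of that construction (tensor shapes are illustrative, not taken from the package):

import torch
import torch.nn.functional as F

pos_logits = torch.randn(4)      # [B], one positive per example
neg_logits = torch.randn(4, 3)   # [B, num_neg] sampled negatives

# The positive goes to column 0, so the target class is 0 for every row.
all_logits = torch.cat([pos_logits.unsqueeze(1), neg_logits], dim=1)  # [B, 1 + num_neg]
targets = torch.zeros(all_logits.size(0), dtype=torch.long, device=all_logits.device)
loss = F.cross_entropy(all_logits, targets, reduction="mean")
print(loss.item())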
nextrec/loss/loss_utils.py
CHANGED

@@ -6,8 +6,6 @@ Checkpoint: edit on 29/11/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
-from typing import Literal
-
 import torch.nn as nn
 
 from nextrec.loss.listwise import (
@@ -20,19 +18,19 @@ from nextrec.loss.listwise import (
 from nextrec.loss.pairwise import BPRLoss, HingeLoss, TripletLoss
 from nextrec.loss.pointwise import (
     ClassBalancedFocalLoss,
-    CosineContrastiveLoss,
     FocalLoss,
     WeightedBCELoss,
 )
 
 
 VALID_TASK_TYPES = [
-    "binary",
-    "multiclass",
-    "multilabel",
-    "regression",
+    "binary",
+    "multiclass",
+    "multilabel",
+    "regression",
 ]
 
+
 def _build_cb_focal(kw):
     if "class_counts" not in kw:
         raise ValueError("class_balanced_focal requires class_counts")
@@ -81,6 +79,7 @@ def get_loss_fn(loss=None, **kw):
 
     raise ValueError(f"[Loss Error] Unsupported loss: {loss}")
 
+
 def get_loss_kwargs(loss_params: dict | list[dict] | None, index: int = 0) -> dict:
     """
     Parse loss_kwargs for each head.
@@ -95,4 +94,4 @@ def get_loss_kwargs(loss_params: dict | list[dict] | None, index: int = 0) -> dict:
         if index < len(loss_params) and loss_params[index] is not None:
             return loss_params[index]
         return {}
-    return loss_params
\ No newline at end of file
+    return loss_params
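
Note: get_loss_kwargs dispatches per-head loss arguments for multi-head models: a single dict is shared across heads, a list is indexed per head, and missing entries fall back to {}. A minimal re-implementation of that contract for illustration (the None branch is an assumption consistent with the dict | list[dict] | None signature):

def get_loss_kwargs(loss_params, index=0):
    if loss_params is None:
        return {}
    if isinstance(loss_params, list):
        if index < len(loss_params) and loss_params[index] is not None:
            return loss_params[index]
        return {}
    return loss_params  # a single dict is shared by every head

print(get_loss_kwargs({"gamma": 2.0}, index=3))          # {'gamma': 2.0}
print(get_loss_kwargs([{"gamma": 2.0}, None], index=1))  # {}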
nextrec/loss/pairwise.py
CHANGED

@@ -36,6 +36,7 @@ class BPRLoss(nn.Module):
             return loss.sum()
         return loss
 
+
 class HingeLoss(nn.Module):
     """
     Hinge loss for pairwise ranking.
@@ -59,6 +60,7 @@ class HingeLoss(nn.Module):
             return loss.sum()
         return loss
 
+
 class TripletLoss(nn.Module):
     """
     Triplet margin loss with cosine or euclidean distance.
nextrec/loss/pointwise.py
CHANGED

@@ -46,6 +46,7 @@ class WeightedBCELoss(nn.Module):
     If `auto_balance=True` and `pos_weight` is None, the positive weight is
     computed from the batch as (#neg / #pos) for stable imbalance handling.
     """
+
     def __init__(
         self,
         pos_weight: float | torch.Tensor | None = None,
@@ -59,11 +60,14 @@ class WeightedBCELoss(nn.Module):
         self.auto_balance = auto_balance
 
         if pos_weight is not None:
-            self.register_buffer("pos_weight", torch.as_tensor(pos_weight, dtype=torch.float32))
+            self.register_buffer(
+                "pos_weight",
+                torch.as_tensor(pos_weight, dtype=torch.float32),
+            )
         else:
             self.pos_weight = None
 
-    def
+    def resolve_pos_weight(self, labels: torch.Tensor) -> torch.Tensor:
         if self.pos_weight is not None:
             return self.pos_weight.to(device=labels.device)
 
@@ -77,7 +81,7 @@ class WeightedBCELoss(nn.Module):
 
     def forward(self, inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
         labels = labels.float()
-        current_pos_weight = self.
+        current_pos_weight = self.resolve_pos_weight(labels)
         current_pos_weight = current_pos_weight.to(inputs.dtype)
 
         if self.logits:
@@ -120,23 +124,27 @@ class FocalLoss(nn.Module):
         if inputs.dim() > 1 and inputs.size(1) > 1:
             log_probs = F.log_softmax(inputs, dim=1)
             probs = log_probs.exp()
-            targets_one_hot = F.one_hot(targets.long(), num_classes=inputs.size(1)).float()
+            targets_one_hot = F.one_hot(
+                targets.long(), num_classes=inputs.size(1)
+            ).float()
 
-            alpha = self.
+            alpha = self.get_alpha(inputs)
             alpha_factor = targets_one_hot * alpha
             focal_weight = (1.0 - probs) ** self.gamma
             loss = torch.sum(alpha_factor * focal_weight * (-log_probs), dim=1)
         else:
             targets = targets.float()
             if self.logits:
-                ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+                ce_loss = F.binary_cross_entropy_with_logits(
+                    inputs, targets, reduction="none"
+                )
                 probs = torch.sigmoid(inputs)
             else:
                 ce_loss = F.binary_cross_entropy(inputs, targets, reduction="none")
                 probs = torch.clamp(inputs, min=1e-6, max=1 - 1e-6)
 
             p_t = probs * targets + (1 - probs) * (1 - targets)
-            alpha_factor = self.
+            alpha_factor = self.get_binary_alpha(targets, inputs.device)
             focal_weight = (1.0 - p_t) ** self.gamma
             loss = alpha_factor * focal_weight * ce_loss
         if self.reduction == "mean":
@@ -145,27 +153,35 @@ class FocalLoss(nn.Module):
             return loss.sum()
         return loss
 
-    def
+    def get_alpha(self, inputs: torch.Tensor) -> torch.Tensor:
         if self.alpha is None:
             return torch.ones_like(inputs)
         if isinstance(self.alpha, torch.Tensor):
             return self.alpha.to(inputs.device)
-        alpha_tensor = torch.tensor(self.alpha, device=inputs.device, dtype=inputs.dtype)
+        alpha_tensor = torch.tensor(
+            self.alpha, device=inputs.device, dtype=inputs.dtype
+        )
         return alpha_tensor
 
-    def
+    def get_binary_alpha(
+        self, targets: torch.Tensor, device: torch.device
+    ) -> torch.Tensor:
        if self.alpha is None:
             return torch.ones_like(targets)
         if isinstance(self.alpha, (float, int)):
-            return torch.where(targets == 1, self.alpha, 1 - float(self.alpha)).to(device)
+            return torch.where(targets == 1, self.alpha, 1 - float(self.alpha)).to(
+                device
+            )
         alpha_tensor = torch.tensor(self.alpha, device=device, dtype=targets.dtype)
         return torch.where(targets == 1, alpha_tensor, 1 - alpha_tensor)
 
+
 class ClassBalancedFocalLoss(nn.Module):
     """
     Focal loss weighted by effective number of samples per class.
     Reference: "Class-Balanced Loss Based on Effective Number of Samples"
     """
+
     def __init__(
         self,
         class_counts: Sequence[int] | torch.Tensor,
@@ -183,7 +199,9 @@ class ClassBalancedFocalLoss(nn.Module):
         self.register_buffer("class_weights", weights)
 
     def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
-        focal = FocalLoss(gamma=self.gamma, alpha=self.class_weights, reduction="none", logits=True)
+        focal = FocalLoss(
+            gamma=self.gamma, alpha=self.class_weights, reduction="none", logits=True
+        )
         loss = focal(inputs, targets)
         if self.reduction == "mean":
             return loss.mean()
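
Note: the WeightedBCELoss docstring above pins down the auto_balance rule: when pos_weight is None it is derived from the batch as (#neg / #pos). A standalone sketch of that rule (the clamp guard against all-positive batches is an added assumption, not necessarily the package's handling):

import torch
import torch.nn.functional as F

labels = torch.tensor([1., 1., 0., 0., 0., 0., 0., 0.])  # 2 pos, 6 neg
logits = torch.randn(8)

pos_weight = (labels == 0).sum() / (labels == 1).sum().clamp(min=1)  # 6/2 = 3.0
loss = F.binary_cross_entropy_with_logits(logits, labels, pos_weight=pos_weight)
print(pos_weight.item(), loss.item())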
nextrec/models/generative/hstu.py
CHANGED

@@ -1,14 +1,14 @@
 """
 [Info: this version is not released yet, i need to more research on source code and paper]
 Date: create on 01/12/2025
-Checkpoint: edit on 01/12/2025
+Checkpoint: edit on 01/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 Reference:
 [1] Meta AI. Generative Recommenders (HSTU encoder) — https://github.com/meta-recsys/generative-recommenders
 [2] Ma W, Li P, Chen C, et al. Actions speak louder than words: Trillion-parameter sequential transducers for generative recommendations. arXiv:2402.17152.
 
-Hierarchical Sequential Transduction Unit (HSTU) is the core encoder behind
-Meta’s Generative Recommenders. It replaces softmax attention with lightweight
+Hierarchical Sequential Transduction Unit (HSTU) is the core encoder behind
+Meta’s Generative Recommenders. It replaces softmax attention with lightweight
 pointwise activations, enabling extremely deep stacks on long behavior sequences.
 
 In each HSTU layer:
@@ -16,8 +16,8 @@ In each HSTU layer:
 (2) Softmax-free interactions combine QK^T with Relative Attention Bias (RAB) to encode distance
 (3) Aggregated context is modulated by U-gating and mapped back through an output projection
 
-Stacking layers yields an efficient causal encoder for next-item
-generation. With a tied-embedding LM head, HSTU forms
+Stacking layers yields an efficient causal encoder for next-item
+generation. With a tied-embedding LM head, HSTU forms
 a full generative recommendation model.
 
 Key Advantages:
@@ -75,7 +75,16 @@ def _relative_position_bucket(
     is_small = n < max_exact
 
     # when the distance is too far, do log scaling
-    large_val = max_exact + ((torch.log(n.float() / max_exact + 1e-6) / math.log(max_distance / max_exact)) * (num_buckets - max_exact)).long()
+    large_val = (
+        max_exact
+        + (
+            (
+                torch.log(n.float() / max_exact + 1e-6)
+                / math.log(max_distance / max_exact)
+            )
+            * (num_buckets - max_exact)
+        ).long()
+    )
     large_val = torch.clamp(large_val, max=num_buckets - 1)
 
     buckets = torch.where(is_small, n.long(), large_val)
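
Note: this is the T5-style bucketing scheme: distances below max_exact keep their own bucket, larger ones are log-spaced up to max_distance and then saturate. A standalone sketch of the same formula (exposing max_exact as a parameter is a simplification here; the package's full signature is not visible in this hunk):

import math
import torch

def relative_position_bucket(n, num_buckets=32, max_exact=16, max_distance=128):
    # n: non-negative distances. Exact buckets for small distances,
    # log-spaced buckets for large ones, clamped to the last bucket.
    is_small = n < max_exact
    large_val = max_exact + (
        (torch.log(n.float() / max_exact + 1e-6) / math.log(max_distance / max_exact))
        * (num_buckets - max_exact)
    ).long()
    large_val = torch.clamp(large_val, max=num_buckets - 1)
    return torch.where(is_small, n.long(), large_val)

print(relative_position_bucket(torch.tensor([0, 8, 16, 64, 500])))
# tensor([ 0,  8, 16, 26, 31]) -- nearby distances keep exact buckets,
# distant ones compress logarithmically and saturate at the last bucket.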
@@ -104,10 +113,19 @@ class RelativePositionBias(nn.Module):
         # positions: [T]
         ctx = torch.arange(seq_len, device=device)[:, None]
         mem = torch.arange(seq_len, device=device)[None, :]
-        rel_pos = mem - ctx
-        buckets = _relative_position_bucket(rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance)
-        values = self.embedding(buckets)
-        return values.permute(2, 0, 1).unsqueeze(0)
+        rel_pos = (
+            mem - ctx
+        )  # a matrix to describe all relative positions for each [i,j] pair, shape = [seq_len, seq_len]
+        buckets = _relative_position_bucket(
+            rel_pos,
+            num_buckets=self.num_buckets,
+            max_distance=self.max_distance,
+        )  # map to buckets
+        values = self.embedding(
+            buckets
+        )  # embedding vector for each [i,j] pair, shape = [seq_len, seq_len, embedding_dim=num_heads]
+        return values.permute(2, 0, 1).unsqueeze(0)  # [1, num_heads, seq_len, seq_len]
+
 
 class HSTUPointwiseAttention(nn.Module):
     """
@@ -123,16 +141,18 @@ class HSTUPointwiseAttention(nn.Module):
         d_model: int,
         num_heads: int,
         dropout: float = 0.1,
-        alpha: float | None = None
+        alpha: float | None = None,
     ):
         super().__init__()
         if d_model % num_heads != 0:
-            raise ValueError(f"[HSTUPointwiseAttention Error] d_model({d_model}) % num_heads({num_heads}) != 0")
+            raise ValueError(
+                f"[HSTUPointwiseAttention Error] d_model({d_model}) % num_heads({num_heads}) != 0"
+            )
 
         self.d_model = d_model
         self.num_heads = num_heads
         self.d_head = d_model // num_heads
-        self.alpha = alpha if alpha is not None else (self.d_head ** -0.5)
+        self.alpha = alpha if alpha is not None else (self.d_head**-0.5)
         # project input to 4 * d_model for U, V, Q, K
         self.in_proj = nn.Linear(d_model, 4 * d_model, bias=True)
         # project output back to d_model
@@ -150,9 +170,9 @@ class HSTUPointwiseAttention(nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        attn_mask: Optional[torch.Tensor] = None,
+        attn_mask: Optional[torch.Tensor] = None,  # [T, T] with 0 or -inf
         key_padding_mask: Optional[torch.Tensor] = None,  # [B, T], True = pad
-        rab: Optional[torch.Tensor] = None,
+        rab: Optional[torch.Tensor] = None,  # [1, H, T, T] or None
     ) -> torch.Tensor:
         B, T, D = x.shape
 
@@ -185,8 +205,8 @@ class HSTUPointwiseAttention(nn.Module):
         # padding mask: key_padding_mask is usually [B, T], True = pad
         if key_padding_mask is not None:
             # valid: 1 for non-pad, 0 for pad
-            valid = (~key_padding_mask).float()
-            valid = valid.view(B, 1, 1, T)
+            valid = (~key_padding_mask).float()  # [B, T]
+            valid = valid.view(B, 1, 1, T)  # [B, 1, 1, T]
             allowed = allowed * valid
             logits = logits.masked_fill(valid == 0, float("-inf"))
 
@@ -197,7 +217,7 @@ class HSTUPointwiseAttention(nn.Module):
 
         attn = attn / denom  # [B, H, T, T]
         AV = torch.matmul(attn, Vh)  # [B, H, T, d_head]
-        AV = AV.transpose(1, 2).contiguous().view(B, T, D)
+        AV = AV.transpose(1, 2).contiguous().view(B, T, D)  # reshape back to [B, T, D]
         U_flat = Uh.transpose(1, 2).contiguous().view(B, T, D)
         y = self.out_proj(self.dropout(self.norm(AV) * U_flat))  # [B, T, D]
         return y
@@ -218,10 +238,20 @@ class HSTULayer(nn.Module):
         rab_max_distance: int = 128,
     ):
         super().__init__()
-        self.attn = HSTUPointwiseAttention(d_model=d_model, num_heads=num_heads, dropout=dropout)
+        self.attn = HSTUPointwiseAttention(
+            d_model=d_model, num_heads=num_heads, dropout=dropout
+        )
         self.dropout = nn.Dropout(dropout)
         self.use_rab_pos = use_rab_pos
-        self.rel_pos_bias = (RelativePositionBias(num_heads=num_heads, num_buckets=rab_num_buckets, max_distance=rab_max_distance) if use_rab_pos else None)
+        self.rel_pos_bias = (
+            RelativePositionBias(
+                num_heads=num_heads,
+                num_buckets=rab_num_buckets,
+                max_distance=rab_max_distance,
+            )
+            if use_rab_pos
+            else None
+        )
 
     def forward(
         self,
@@ -236,8 +266,10 @@ class HSTULayer(nn.Module):
         device = x.device
         rab = None
         if self.use_rab_pos:
-            rab = self.rel_pos_bias(seq_len=T, device=device)
-        out = self.attn(x=x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, rab=rab)
+            rab = self.rel_pos_bias(seq_len=T, device=device)  # [1, H, T, T]
+        out = self.attn(
+            x=x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, rab=rab
+        )
         return x + self.dropout(out)
 
 
@@ -255,7 +287,7 @@ class HSTU(BaseModel):
         return "HSTU"
 
     @property
-    def
+    def default_task(self) -> str:
         return "multiclass"
 
     def __init__(
@@ -272,9 +304,9 @@ class HSTU(BaseModel):
         use_rab_pos: bool = True,
         rab_num_buckets: int = 32,
         rab_max_distance: int = 128,
-
         tie_embeddings: bool = True,
         target: Optional[list[str] | str] = None,
+        task: str | list[str] | None = None,
         optimizer: str = "adam",
         optimizer_params: Optional[dict] = None,
         scheduler: Optional[str] = None,
@@ -288,17 +320,25 @@ class HSTU(BaseModel):
         **kwargs,
     ):
         if not sequence_features:
-            raise ValueError("[HSTU Error] HSTU requires at least one SequenceFeature (user behavior history).")
+            raise ValueError(
+                "[HSTU Error] HSTU requires at least one SequenceFeature (user behavior history)."
+            )
 
         # demo version: use the first SequenceFeature as the main sequence
         self.history_feature = sequence_features[0]
 
-        hidden_dim = d_model or max(int(getattr(self.history_feature, "embedding_dim", 0) or 0), 32)
+        hidden_dim = d_model or max(
+            int(getattr(self.history_feature, "embedding_dim", 0) or 0), 32
+        )
         # Make hidden_dim divisible by num_heads
         if hidden_dim % num_heads != 0:
             hidden_dim = num_heads * math.ceil(hidden_dim / num_heads)
 
-        self.padding_idx = self.history_feature.padding_idx if self.history_feature.padding_idx is not None else 0
+        self.padding_idx = (
+            self.history_feature.padding_idx
+            if self.history_feature.padding_idx is not None
+            else 0
+        )
         self.vocab_size = self.history_feature.vocab_size
         self.max_seq_len = max_seq_len
 
@@ -307,7 +347,7 @@ class HSTU(BaseModel):
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=self.
+            task=task or self.default_task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
@@ -326,8 +366,19 @@ class HSTU(BaseModel):
         self.input_dropout = nn.Dropout(dropout)
 
         # HSTU layers
-        self.layers = nn.ModuleList(
-
+        self.layers = nn.ModuleList(
+            [
+                HSTULayer(
+                    d_model=hidden_dim,
+                    num_heads=num_heads,
+                    dropout=dropout,
+                    use_rab_pos=use_rab_pos,
+                    rab_num_buckets=rab_num_buckets,
+                    rab_max_distance=rab_max_distance,
+                )
+                for _ in range(num_layers)
+            ]
+        )
 
         self.final_norm = nn.LayerNorm(hidden_dim)
         self.lm_head = nn.Linear(hidden_dim, self.vocab_size, bias=False)
@@ -343,8 +394,17 @@ class HSTU(BaseModel):
         loss_params = loss_params or {}
         loss_params.setdefault("ignore_index", self.ignore_index)
 
-        self.compile(
-
+        self.compile(
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            scheduler=scheduler,
+            scheduler_params=scheduler_params,
+            loss="crossentropy",
+            loss_params=loss_params,
+        )
+        self.register_regularization_weights(
+            embedding_attr="token_embedding", include_modules=["layers", "lm_head"]
+        )
 
     def _build_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
         """
|
|
|
353
413
|
"""
|
|
354
414
|
if self.causal_mask.numel() == 0 or self.causal_mask.size(0) < seq_len:
|
|
355
415
|
mask = torch.full((seq_len, seq_len), float("-inf"), device=device)
|
|
356
|
-
mask = torch.triu(mask, diagonal=1)
|
|
416
|
+
mask = torch.triu(mask, diagonal=1)
|
|
357
417
|
self.causal_mask = mask
|
|
358
418
|
return self.causal_mask[:seq_len, :seq_len]
|
|
359
419
|
|
|
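
Note: _build_causal_mask caches a strictly-upper-triangular -inf mask at the largest seq_len seen so far and re-slices it for shorter sequences. What the cached mask contains for T = 4:

import torch

T = 4
mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])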
@@ -364,27 +424,31 @@ class HSTU(BaseModel):
 
     def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
         seq = x[self.history_feature.name].long()  # [B, T_raw]
-        seq = self._trim_sequence(seq)
+        seq = self._trim_sequence(seq)  # [B, T]
 
         B, T = seq.shape
         device = seq.device
         # position ids: [B, T]
         pos_ids = torch.arange(T, device=device).unsqueeze(0).expand(B, -1)
-        token_emb = self.token_embedding(seq)
-        pos_emb = self.position_embedding(pos_ids)
+        token_emb = self.token_embedding(seq)  # [B, T, D]
+        pos_emb = self.position_embedding(pos_ids)  # [B, T, D]
         hidden_states = self.input_dropout(token_emb + pos_emb)
 
         # padding mask:True = pad
-        padding_mask = seq.eq(self.padding_idx)
+        padding_mask = seq.eq(self.padding_idx)  # [B, T]
         attn_mask = self._build_causal_mask(seq_len=T, device=device)  # [T, T]
 
         for layer in self.layers:
-            hidden_states = layer(x=hidden_states, attn_mask=attn_mask, key_padding_mask=padding_mask)
+            hidden_states = layer(
+                x=hidden_states, attn_mask=attn_mask, key_padding_mask=padding_mask
+            )
         hidden_states = self.final_norm(hidden_states)  # [B, T, D]
 
         valid_lengths = (~padding_mask).sum(dim=1)  # [B]
         last_index = (valid_lengths - 1).clamp(min=0)
-        last_hidden = hidden_states[torch.arange(B, device=device), last_index]
+        last_hidden = hidden_states[
+            torch.arange(B, device=device), last_index
+        ]  # [B, D]
 
         logits = self.lm_head(last_hidden)  # [B, vocab_size]
         return logits
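
Note: the end of forward above gathers each sequence's final non-padding hidden state with advanced indexing. The same pattern in isolation (toy values, pad id 0):

import torch

B, T, D = 2, 4, 3
seq = torch.tensor([[5, 7, 2, 0],      # 3 valid tokens
                    [9, 0, 0, 0]])     # 1 valid token
hidden = torch.randn(B, T, D)

padding_mask = seq.eq(0)                           # [B, T], True = pad
valid_lengths = (~padding_mask).sum(dim=1)         # tensor([3, 1])
last_index = (valid_lengths - 1).clamp(min=0)      # tensor([2, 0])
last_hidden = hidden[torch.arange(B), last_index]  # [B, D]
print(last_hidden.shape)                           # torch.Size([2, 3])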
@@ -394,6 +458,8 @@ class HSTU(BaseModel):
         y_true: [B] or [B, 1], the id of the next item.
         """
         if y_true is None:
-            raise ValueError("[HSTU-compute_loss] Training requires y_true (next item id).")
+            raise ValueError(
+                "[HSTU-compute_loss] Training requires y_true (next item id)."
+            )
         labels = y_true.view(-1).long()
         return self.loss_fn[0](y_pred, labels)