torch-rechub 0.0.6__py3-none-any.whl → 0.2.0__py3-none-any.whl
- torch_rechub/basic/layers.py +228 -159
- torch_rechub/basic/loss_func.py +62 -47
- torch_rechub/data/dataset.py +18 -31
- torch_rechub/models/generative/hstu.py +48 -33
- torch_rechub/serving/__init__.py +50 -0
- torch_rechub/serving/annoy.py +133 -0
- torch_rechub/serving/base.py +107 -0
- torch_rechub/serving/faiss.py +154 -0
- torch_rechub/serving/milvus.py +215 -0
- torch_rechub/trainers/ctr_trainer.py +12 -2
- torch_rechub/trainers/match_trainer.py +13 -2
- torch_rechub/trainers/mtl_trainer.py +12 -2
- torch_rechub/trainers/seq_trainer.py +34 -15
- torch_rechub/types.py +5 -0
- torch_rechub/utils/data.py +191 -145
- torch_rechub/utils/hstu_utils.py +87 -76
- torch_rechub/utils/model_utils.py +10 -12
- torch_rechub/utils/onnx_export.py +98 -45
- torch_rechub/utils/quantization.py +128 -0
- torch_rechub/utils/visualization.py +4 -12
- {torch_rechub-0.0.6.dist-info → torch_rechub-0.2.0.dist-info}/METADATA +34 -18
- {torch_rechub-0.0.6.dist-info → torch_rechub-0.2.0.dist-info}/RECORD +24 -18
- torch_rechub/trainers/matching.md +0 -3
- {torch_rechub-0.0.6.dist-info → torch_rechub-0.2.0.dist-info}/WHEEL +0 -0
- {torch_rechub-0.0.6.dist-info → torch_rechub-0.2.0.dist-info}/licenses/LICENSE +0 -0
torch_rechub/basic/layers.py
CHANGED
@@ -9,11 +9,12 @@ from .features import DenseFeature, SequenceFeature, SparseFeature
 
 
 class PredictionLayer(nn.Module):
-    """Prediction
+    """Prediction layer.
 
-
-
-
+    Parameters
+    ----------
+    task_type : {'classification', 'regression'}
+        Classification applies sigmoid to logits; regression returns logits.
     """
 
     def __init__(self, task_type='classification'):
@@ -29,24 +30,30 @@ class PredictionLayer(nn.Module):
 
 
 class EmbeddingLayer(nn.Module):
-    """General
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    """General embedding layer.
+
+    Stores per-feature embedding tables in ``embed_dict``.
+
+    Parameters
+    ----------
+    features : list
+        Feature objects to create embedding tables for.
+
+    Shape
+    -----
+    Input
+    x : dict
+        ``{feature_name: feature_value}``; sequence values shape ``(B, L)``,
+        sparse/dense values shape ``(B,)``.
+    features : list
+        Feature list for lookup.
+    squeeze_dim : bool, default False
+        Whether to flatten embeddings.
+    Output
+    - Dense only: ``(B, num_dense)``.
+    - Sparse: ``(B, num_features, embed_dim)`` or flattened.
+    - Sequence: same as sparse or ``(B, num_seq, L, embed_dim)`` when ``pooling="concat"``.
+    - Mixed: flattened sparse plus dense when ``squeeze_dim=True``.
     """
 
     def __init__(self, features):
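To make the shape contract documented above concrete, here is a minimal usage sketch. The `SparseFeature`/`SequenceFeature` constructor arguments and the expected output shape are assumptions for illustration, not taken from this diff.

```python
# Hypothetical usage sketch for EmbeddingLayer; the feature-class constructor
# arguments below are assumptions, only the forward contract comes from the docstring.
import torch
from torch_rechub.basic.features import SparseFeature, SequenceFeature
from torch_rechub.basic.layers import EmbeddingLayer

user_id = SparseFeature("user_id", vocab_size=1000, embed_dim=16)
hist_item = SequenceFeature("hist_item", vocab_size=5000, embed_dim=16, pooling="mean")

emb = EmbeddingLayer(features=[user_id, hist_item])
x = {
    "user_id": torch.randint(0, 1000, (32,)),       # sparse value: (B,)
    "hist_item": torch.randint(0, 5000, (32, 20)),  # sequence value: (B, L)
}
out = emb(x, features=[user_id, hist_item], squeeze_dim=False)
print(out.shape)  # expected (32, 2, 16): (B, num_features, embed_dim)
```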
@@ -119,16 +126,18 @@ class EmbeddingLayer(nn.Module):
 
 
 class InputMask(nn.Module):
-    """Return
-
-    Shape
-
-
-
-
-
-
-
+    """Return input masks from features.
+
+    Shape
+    -----
+    Input
+    x : dict
+        ``{feature_name: feature_value}``; sequence ``(B, L)``, sparse/dense ``(B,)``.
+    features : list or SparseFeature or SequenceFeature
+        All elements must be sparse or sequence features.
+    Output
+    - Sparse: ``(B, num_features)``
+    - Sequence: ``(B, num_seq, seq_length)``
     """
 
     def __init__(self):
@@ -151,16 +160,19 @@ class InputMask(nn.Module):
 
 
 class LR(nn.Module):
-    """Logistic
-
-
-
-
-
-
-
-
-
+    """Logistic regression module.
+
+    Parameters
+    ----------
+    input_dim : int
+        Input dimension.
+    sigmoid : bool, default False
+        Apply sigmoid to output when True.
+
+    Shape
+    -----
+    Input: ``(B, input_dim)``
+    Output: ``(B, 1)``
     """
 
     def __init__(self, input_dim, sigmoid=False):
@@ -176,11 +188,12 @@ class LR(nn.Module):
 
 
 class ConcatPooling(nn.Module):
-    """Keep
+    """Keep original sequence embedding shape.
 
-    Shape
-
-
+    Shape
+    -----
+    Input: ``(B, L, D)``
+    Output: ``(B, L, D)``
     """
 
     def __init__(self):
@@ -191,13 +204,15 @@ class ConcatPooling(nn.Module):
 
 
 class AveragePooling(nn.Module):
-    """
-
-    Shape
-
-
-
-
+    """Mean pooling over sequence embeddings.
+
+    Shape
+    -----
+    Input
+    x : ``(B, L, D)``
+    mask : ``(B, 1, L)``
+    Output
+    ``(B, D)``
     """
 
     def __init__(self):
@@ -213,13 +228,15 @@ class AveragePooling(nn.Module):
 
 
 class SumPooling(nn.Module):
-    """
-
-    Shape
-
-
-
-
+    """Sum pooling over sequence embeddings.
+
+    Shape
+    -----
+    Input
+    x : ``(B, L, D)``
+    mask : ``(B, 1, L)``
+    Output
+    ``(B, D)``
     """
 
     def __init__(self):
@@ -233,20 +250,25 @@ class SumPooling(nn.Module):
 
 
 class MLP(nn.Module):
-    """Multi
-
-
-
-
-
-
-
-
-
-
-
-
-
+    """Multi-layer perceptron with BN/activation/dropout per linear layer.
+
+    Parameters
+    ----------
+    input_dim : int
+        Input dimension of the first linear layer.
+    output_layer : bool, default True
+        If True, append a final Linear(*,1).
+    dims : list, default []
+        Hidden layer sizes.
+    dropout : float, default 0
+        Dropout probability.
+    activation : str, default 'relu'
+        Activation function (sigmoid, relu, prelu, dice, softmax).
+
+    Shape
+    -----
+    Input: ``(B, input_dim)``
+    Output: ``(B, 1)`` or ``(B, dims[-1])``
     """
 
     def __init__(self, input_dim, output_layer=True, dims=None, dropout=0, activation="relu"):
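A short sketch of how the MLP signature documented above is typically exercised; the concrete values are illustrative only.

```python
# Illustrative only: exercise the MLP contract documented in the new docstring.
import torch
from torch_rechub.basic.layers import MLP

mlp = MLP(input_dim=16, output_layer=True, dims=[64, 32], dropout=0.2, activation="relu")
x = torch.randn(8, 16)       # (B, input_dim)
print(mlp(x).shape)          # torch.Size([8, 1]) when output_layer=True

tower = MLP(input_dim=16, output_layer=False, dims=[64, 32])
print(tower(x).shape)        # torch.Size([8, 32]) -> (B, dims[-1])
```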
@@ -269,16 +291,17 @@ class MLP(nn.Module):
 
 
 class FM(nn.Module):
-    """
-    <https://arxiv.org/pdf/1703.04247.pdf>`. It is used to learn 2nd-order
-    feature interactions.
+    """Factorization Machine for 2nd-order interactions.
 
-
-
+    Parameters
+    ----------
+    reduce_sum : bool, default True
+        Sum over embed dim (inner product) when True; otherwise keep dim.
 
-    Shape
-
-
+    Shape
+    -----
+    Input: ``(B, num_features, embed_dim)``
+    Output: ``(B, 1)`` or ``(B, embed_dim)``
     """
 
     def __init__(self, reduce_sum=True):
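The 2nd-order interaction the FM docstring refers to is usually computed with the square-of-sum minus sum-of-squares identity. A minimal sketch of that computation, mirroring the documented shapes rather than the module's actual code:

```python
# Minimal sketch of the 2nd-order FM term over input (B, num_features, embed_dim);
# this mirrors the documented contract, not necessarily the exact implementation.
import torch

def fm_second_order(x, reduce_sum=True):
    square_of_sum = torch.sum(x, dim=1) ** 2      # (B, embed_dim)
    sum_of_square = torch.sum(x ** 2, dim=1)      # (B, embed_dim)
    ix = 0.5 * (square_of_sum - sum_of_square)    # (B, embed_dim)
    if reduce_sum:
        return ix.sum(dim=1, keepdim=True)        # (B, 1)
    return ix                                     # (B, embed_dim)

print(fm_second_order(torch.randn(4, 10, 8)).shape)  # torch.Size([4, 1])
```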
@@ -295,15 +318,21 @@ class FM(nn.Module):
 
 
 class CIN(nn.Module):
-    """Compressed Interaction Network
-
-
-
-
-
-
-
+    """Compressed Interaction Network.
+
+    Parameters
+    ----------
+    input_dim : int
+        Input dimension.
+    cin_size : list[int]
+        Output channels per Conv1d layer.
+    split_half : bool, default True
+        Split channels except last layer.
+
+    Shape
+    -----
+    Input: ``(B, num_features, embed_dim)``
+    Output: ``(B, 1)``
     """
 
     def __init__(self, input_dim, cin_size, split_half=True):
@@ -338,10 +367,12 @@ class CIN(nn.Module):
 
 
 class CrossLayer(nn.Module):
-    """
-
-
-
+    """Cross layer.
+
+    Parameters
+    ----------
+    input_dim : int
+        Input dimension.
     """
 
    def __init__(self, input_dim):
@@ -355,15 +386,19 @@ class CrossLayer(nn.Module):
 
 
 class CrossNetwork(nn.Module):
-    """CrossNetwork
-
-
-
-
-
-
-
+    """CrossNetwork from DCN.
+
+    Parameters
+    ----------
+    input_dim : int
+        Input dimension.
+    num_layers : int
+        Number of cross layers.
+
+    Shape
+    -----
+    Input: ``(B, *)``
+    Output: ``(B, *)``
    """
 
     def __init__(self, input_dim, num_layers):
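For context on the CrossNetwork/CrossLayer docstrings above: a DCN-style cross step typically computes x_{l+1} = x_0 * (x_l · w) + b + x_l. The sketch below illustrates that step and the documented (B, input_dim) in / (B, input_dim) out shape; it is not the module's exact code, and the tensor names are illustrative.

```python
# Hedged sketch of one DCN-style cross step, consistent with the documented shapes.
import torch

def cross_step(x0, xl, w, b):
    # x0, xl: (B, d); w, b: (d,)
    xl_w = (xl * w).sum(dim=1, keepdim=True)  # (B, 1): per-sample interaction weight
    return x0 * xl_w + b + xl                 # (B, d): cross term plus residual

B, d = 4, 8
x0 = torch.randn(B, d)
w, b = torch.randn(d), torch.zeros(d)
print(cross_step(x0, x0, w, b).shape)         # torch.Size([4, 8])
```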
@@ -384,6 +419,15 @@ class CrossNetwork(nn.Module):
 
 
 class CrossNetV2(nn.Module):
+    """DCNv2-style cross network.
+
+    Parameters
+    ----------
+    input_dim : int
+        Input dimension.
+    num_layers : int
+        Number of cross layers.
+    """
 
     def __init__(self, input_dim, num_layers):
         super().__init__()
@@ -399,10 +443,11 @@ class CrossNetV2(nn.Module):
 
 
 class CrossNetMix(nn.Module):
-    """
-
-
-
+    """CrossNetMix with MOE and nonlinear low-rank transforms.
+
+    Notes
+    -----
+    Input: float tensor ``(B, num_fields, embed_dim)``.
     """
 
     def __init__(self, input_dim, num_layers=2, low_rank=32, num_experts=4):
@@ -460,14 +505,14 @@ class CrossNetMix(nn.Module):
 
 
 class SENETLayer(nn.Module):
-    """
-
-
-
-
-
-
-
+    """SENet-style feature gating.
+
+    Parameters
+    ----------
+    num_fields : int
+        Number of feature fields.
+    reduction_ratio : int, default=3
+        Reduction ratio for the bottleneck MLP.
     """
 
     def __init__(self, num_fields, reduction_ratio=3):
@@ -483,14 +528,16 @@ class SENETLayer(nn.Module):
 
 
 class BiLinearInteractionLayer(nn.Module):
-    """
-
-
-
-
-
-
-
+    """Bilinear feature interaction (FFM-style).
+
+    Parameters
+    ----------
+    input_dim : int
+        Input dimension.
+    num_fields : int
+        Number of feature fields.
+    bilinear_type : {'field_all', 'field_each', 'field_interaction'}, default 'field_interaction'
+        Bilinear interaction variant.
     """
 
     def __init__(self, input_dim, num_fields, bilinear_type="field_interaction"):
@@ -517,18 +564,24 @@ class BiLinearInteractionLayer(nn.Module):
 
 
 class MultiInterestSA(nn.Module):
-    """
-
-
-
-
-
-
-
-
-
-
-
+    """Self-attention multi-interest module (Comirec).
+
+    Parameters
+    ----------
+    embedding_dim : int
+        Item embedding dimension.
+    interest_num : int
+        Number of interests.
+    hidden_dim : int, optional
+        Hidden dimension; defaults to ``4 * embedding_dim`` if None.
+
+    Shape
+    -----
+    Input
+    seq_emb : ``(B, L, D)``
+    mask : ``(B, L, 1)``
+    Output
+    ``(B, interest_num, D)``
     """
 
     def __init__(self, embedding_dim, interest_num, hidden_dim=None):
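The Comirec-style self-attentive extraction that the MultiInterestSA docstring describes maps a masked item sequence to a fixed number of interest vectors. A hedged sketch of that mapping follows; the weight names (w1, w2) are illustrative and not the module's attribute names.

```python
# Hedged sketch of Comirec-SA style multi-interest extraction:
# seq_emb (B, L, D) + mask (B, L, 1) -> interests (B, interest_num, D).
import torch
import torch.nn.functional as F

B, L, D, K, H = 4, 20, 16, 4, 64        # batch, seq len, embed dim, interests, hidden
seq_emb = torch.randn(B, L, D)
mask = torch.ones(B, L, 1)              # 1 = real item, 0 = padding

w1 = torch.randn(D, H)                  # illustrative projection weights
w2 = torch.randn(H, K)

att = torch.tanh(seq_emb @ w1) @ w2     # (B, L, K) attention logits
att = att.masked_fill(mask == 0, -1e9)  # ignore padded positions
att = F.softmax(att, dim=1)             # normalize over the sequence dimension
interests = att.transpose(1, 2) @ seq_emb
print(interests.shape)                  # torch.Size([4, 4, 16])
```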
@@ -555,20 +608,30 @@ class MultiInterestSA(nn.Module):
 
 
 class CapsuleNetwork(nn.Module):
-    """
-
-
-
-
-
-
-
-
-
-
-
-
-
+    """Capsule network for multi-interest (MIND/Comirec).
+
+    Parameters
+    ----------
+    embedding_dim : int
+        Item embedding dimension.
+    seq_len : int
+        Sequence length.
+    bilinear_type : {0, 1, 2}, default 2
+        0 for MIND, 2 for ComirecDR.
+    interest_num : int, default 4
+        Number of interests.
+    routing_times : int, default 3
+        Routing iterations.
+    relu_layer : bool, default False
+        Whether to apply ReLU after routing.
+
+    Shape
+    -----
+    Input
+    seq_emb : ``(B, L, D)``
+    mask : ``(B, L, 1)``
+    Output
+    ``(B, interest_num, D)``
     """
 
     def __init__(self, embedding_dim, seq_len, bilinear_type=2, interest_num=4, routing_times=3, relu_layer=False):
@@ -783,7 +846,7 @@ class HSTULayer(nn.Module):
         self.dropout = nn.Dropout(dropout)
 
         # Scaling factor for attention scores
-        self.scale = 1.0 / (dqk**0.5)
+        # self.scale = 1.0 / (dqk**0.5)  # Removed in favor of L2 norm + SiLU
 
     def forward(self, x, rel_pos_bias=None):
         """Forward pass of a single HSTU layer.
@@ -815,6 +878,10 @@ class HSTULayer(nn.Module):
         u = proj_out[..., 2 * self.n_heads * self.dqk:2 * self.n_heads * self.dqk + self.n_heads * self.dv].reshape(batch_size, seq_len, self.n_heads, self.dv)
         v = proj_out[..., 2 * self.n_heads * self.dqk + self.n_heads * self.dv:].reshape(batch_size, seq_len, self.n_heads, self.dv)
 
+        # Apply L2 normalization to Q and K (HSTU specific)
+        q = F.normalize(q, p=2, dim=-1)
+        k = F.normalize(k, p=2, dim=-1)
+
         # Transpose to (B, H, L, dqk/dv)
         q = q.transpose(1, 2)  # (B, H, L, dqk)
         k = k.transpose(1, 2)  # (B, H, L, dqk)
@@ -822,20 +889,22 @@ class HSTULayer(nn.Module):
         v = v.transpose(1, 2)  # (B, H, L, dv)
 
         # Compute attention scores: (B, H, L, L)
-
+        # Note: No scaling factor here as we use L2 norm + SiLU
+        scores = torch.matmul(q, k.transpose(-2, -1))
+
+        # Add relative position bias if provided (before masking/activation)
+        if rel_pos_bias is not None:
+            scores = scores + rel_pos_bias
 
         # Add causal mask (prevent attending to future positions)
         # For generative models this is required so that position i only attends
         # to positions <= i.
         causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool))
-
-
-        # Add relative position bias if provided
-        if rel_pos_bias is not None:
-            scores = scores + rel_pos_bias
+        # Use a large negative number for masking compatible with SiLU
+        scores = scores.masked_fill(~causal_mask.unsqueeze(0).unsqueeze(0), -1e4)
 
-        #
-        attn_weights = F.
+        # SiLU activation over attention scores (HSTU specific)
+        attn_weights = F.silu(scores)
         attn_weights = self.dropout(attn_weights)
 
         # Attention output: (B, H, L, dv)
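Taken together, these HSTULayer hunks replace scaled softmax attention with the HSTU-style pointwise variant: L2-normalized Q/K, no 1/sqrt(dqk) scale, an additive relative-position bias, a causal mask filled with -1e4, and SiLU instead of softmax. A condensed, illustrative sketch of that score path (not the file's full forward pass) is:

```python
# Condensed sketch of the attention-score path introduced above;
# shapes follow the diff's comments: q, k (B, H, L, dqk) and v (B, H, L, dv).
import torch
import torch.nn.functional as F

def hstu_attention(q, k, v, rel_pos_bias=None, dropout=None):
    q = F.normalize(q, p=2, dim=-1)                # L2-normalize queries
    k = F.normalize(k, p=2, dim=-1)                # L2-normalize keys
    scores = torch.matmul(q, k.transpose(-2, -1))  # (B, H, L, L), no 1/sqrt(dqk) scale
    if rel_pos_bias is not None:
        scores = scores + rel_pos_bias             # additive relative-position bias
    L = scores.size(-1)
    causal = torch.tril(torch.ones(L, L, dtype=torch.bool, device=scores.device))
    scores = scores.masked_fill(~causal, -1e4)     # SiLU-friendly masking value
    weights = F.silu(scores)                       # pointwise activation, no softmax
    if dropout is not None:
        weights = dropout(weights)
    return torch.matmul(weights, v)                # (B, H, L, dv)
```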
|