torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/__init__.py +14 -0
- torch_rechub/basic/activation.py +54 -54
- torch_rechub/basic/callback.py +33 -33
- torch_rechub/basic/features.py +87 -94
- torch_rechub/basic/initializers.py +92 -92
- torch_rechub/basic/layers.py +994 -720
- torch_rechub/basic/loss_func.py +223 -34
- torch_rechub/basic/metaoptimizer.py +76 -72
- torch_rechub/basic/metric.py +251 -250
- torch_rechub/models/generative/__init__.py +6 -0
- torch_rechub/models/generative/hllm.py +249 -0
- torch_rechub/models/generative/hstu.py +189 -0
- torch_rechub/models/matching/__init__.py +13 -11
- torch_rechub/models/matching/comirec.py +193 -188
- torch_rechub/models/matching/dssm.py +72 -66
- torch_rechub/models/matching/dssm_facebook.py +77 -79
- torch_rechub/models/matching/dssm_senet.py +28 -16
- torch_rechub/models/matching/gru4rec.py +85 -87
- torch_rechub/models/matching/mind.py +103 -101
- torch_rechub/models/matching/narm.py +82 -76
- torch_rechub/models/matching/sasrec.py +143 -140
- torch_rechub/models/matching/sine.py +148 -151
- torch_rechub/models/matching/stamp.py +81 -83
- torch_rechub/models/matching/youtube_dnn.py +75 -71
- torch_rechub/models/matching/youtube_sbc.py +98 -98
- torch_rechub/models/multi_task/__init__.py +7 -5
- torch_rechub/models/multi_task/aitm.py +83 -84
- torch_rechub/models/multi_task/esmm.py +56 -55
- torch_rechub/models/multi_task/mmoe.py +58 -58
- torch_rechub/models/multi_task/ple.py +116 -130
- torch_rechub/models/multi_task/shared_bottom.py +45 -45
- torch_rechub/models/ranking/__init__.py +14 -11
- torch_rechub/models/ranking/afm.py +65 -63
- torch_rechub/models/ranking/autoint.py +102 -0
- torch_rechub/models/ranking/bst.py +61 -63
- torch_rechub/models/ranking/dcn.py +38 -38
- torch_rechub/models/ranking/dcn_v2.py +59 -69
- torch_rechub/models/ranking/deepffm.py +131 -123
- torch_rechub/models/ranking/deepfm.py +43 -42
- torch_rechub/models/ranking/dien.py +191 -191
- torch_rechub/models/ranking/din.py +93 -91
- torch_rechub/models/ranking/edcn.py +101 -117
- torch_rechub/models/ranking/fibinet.py +42 -50
- torch_rechub/models/ranking/widedeep.py +41 -41
- torch_rechub/trainers/__init__.py +4 -3
- torch_rechub/trainers/ctr_trainer.py +288 -128
- torch_rechub/trainers/match_trainer.py +336 -170
- torch_rechub/trainers/matching.md +3 -0
- torch_rechub/trainers/mtl_trainer.py +356 -207
- torch_rechub/trainers/seq_trainer.py +427 -0
- torch_rechub/utils/data.py +492 -360
- torch_rechub/utils/hstu_utils.py +198 -0
- torch_rechub/utils/match.py +457 -274
- torch_rechub/utils/model_utils.py +233 -0
- torch_rechub/utils/mtl.py +136 -126
- torch_rechub/utils/onnx_export.py +220 -0
- torch_rechub/utils/visualization.py +271 -0
- torch_rechub-0.0.5.dist-info/METADATA +402 -0
- torch_rechub-0.0.5.dist-info/RECORD +64 -0
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
- torch_rechub-0.0.3.dist-info/METADATA +0 -177
- torch_rechub-0.0.3.dist-info/RECORD +0 -55
- torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
torch_rechub/basic/layers.py
CHANGED
@@ -1,720 +1,994 @@
from itertools import combinations

import torch
import torch.nn as nn
import torch.nn.functional as F

from .activation import activation_layer
from .features import DenseFeature, SequenceFeature, SparseFeature


class PredictionLayer(nn.Module):
    """Prediction Layer.

    Args:
        task_type (str): if `task_type='classification'`, then return sigmoid(x),
                    change the input logits to probability. if`task_type='regression'`, then return x.
    """

    def __init__(self, task_type='classification'):
        super(PredictionLayer, self).__init__()
        if task_type not in ["classification", "regression"]:
            raise ValueError("task_type must be classification or regression")
        self.task_type = task_type

    def forward(self, x):
        if self.task_type == "classification":
            x = torch.sigmoid(x)
        return x


class EmbeddingLayer(nn.Module):
    """General Embedding Layer.
    We save all the feature embeddings in embed_dict: `{feature_name : embedding table}`.


    Args:
        features (list): the list of `Feature Class`. It is means all the features which we want to create a embedding table.

    Shape:
        - Input:
            x (dict): {feature_name: feature_value}, sequence feature value is a 2D tensor with shape:`(batch_size, seq_len)`,\
            sparse/dense feature value is a 1D tensor with shape `(batch_size)`.
            features (list): the list of `Feature Class`. It is means the current features which we want to do embedding lookup.
            squeeze_dim (bool): whether to squeeze dim of output (default = `False`).
        - Output:
            - if input Dense: `(batch_size, num_features_dense)`.
            - if input Sparse: `(batch_size, num_features, embed_dim)` or `(batch_size, num_features * embed_dim)`.
            - if input Sequence: same with input sparse or `(batch_size, num_features_seq, seq_length, embed_dim)` when `pooling=="concat"`.
            - if input Dense and Sparse/Sequence: `(batch_size, num_features_sparse * embed_dim)`. Note we must squeeze_dim for concat dense value with sparse embedding.
    """

    def __init__(self, features):
        super().__init__()
        self.features = features
        self.embed_dict = nn.ModuleDict()
        self.n_dense = 0

        for fea in features:
            if fea.name in self.embed_dict:  # exist
                continue
            if isinstance(fea, SparseFeature) and fea.shared_with is None:
                self.embed_dict[fea.name] = fea.get_embedding_layer()
            elif isinstance(fea, SequenceFeature) and fea.shared_with is None:
                self.embed_dict[fea.name] = fea.get_embedding_layer()
            elif isinstance(fea, DenseFeature):
                self.n_dense += 1

    def forward(self, x, features, squeeze_dim=False):
        sparse_emb, dense_values = [], []
        sparse_exists, dense_exists = False, False
        for fea in features:
            if isinstance(fea, SparseFeature):
                if fea.shared_with is None:
                    sparse_emb.append(self.embed_dict[fea.name](x[fea.name].long()).unsqueeze(1))
                else:
                    sparse_emb.append(self.embed_dict[fea.shared_with](x[fea.name].long()).unsqueeze(1))
            elif isinstance(fea, SequenceFeature):
                if fea.pooling == "sum":
                    pooling_layer = SumPooling()
                elif fea.pooling == "mean":
                    pooling_layer = AveragePooling()
                elif fea.pooling == "concat":
                    pooling_layer = ConcatPooling()
                else:
                    raise ValueError("Sequence pooling method supports only pooling in %s, got %s." % (["sum", "mean"], fea.pooling))
                fea_mask = InputMask()(x, fea)
                if fea.shared_with is None:
                    sparse_emb.append(pooling_layer(self.embed_dict[fea.name](x[fea.name].long()), fea_mask).unsqueeze(1))
                else:
                    sparse_emb.append(pooling_layer(self.embed_dict[fea.shared_with](x[fea.name].long()), fea_mask).unsqueeze(1))  # shared specific sparse feature embedding
            else:
                dense_values.append(x[fea.name].float() if x[fea.name].float().dim() > 1 else x[fea.name].float().unsqueeze(1))  # .unsqueeze(1).unsqueeze(1)

        if len(dense_values) > 0:
            dense_exists = True
            dense_values = torch.cat(dense_values, dim=1)
        if len(sparse_emb) > 0:
            sparse_exists = True
            # TODO: support concat dynamic embed_dim in dim 2
            # [batch_size, num_features, embed_dim]
            sparse_emb = torch.cat(sparse_emb, dim=1)

        if squeeze_dim:  # Note: if the emb_dim of sparse features is different, we must squeeze_dim
            if dense_exists and not sparse_exists:  # only input dense features
                return dense_values
            elif not dense_exists and sparse_exists:
                # squeeze dim to : [batch_size, num_features*embed_dim]
                return sparse_emb.flatten(start_dim=1)
            elif dense_exists and sparse_exists:
                # concat dense value with sparse embedding
                return torch.cat((sparse_emb.flatten(start_dim=1), dense_values), dim=1)
            else:
                raise ValueError("The input features can note be empty")
        else:
            if sparse_exists:
                return sparse_emb  # [batch_size, num_features, embed_dim]
            else:
                raise ValueError("If keep the original shape:[batch_size, num_features, embed_dim], expected %s in feature list, got %s" % ("SparseFeatures", features))


class InputMask(nn.Module):
    """Return inputs mask from given features

    Shape:
        - Input:
            x (dict): {feature_name: feature_value}, sequence feature value is a 2D tensor with shape:`(batch_size, seq_len)`,\
            sparse/dense feature value is a 1D tensor with shape `(batch_size)`.
            features (list or SparseFeature or SequenceFeature): Note that the elements in features are either all instances of SparseFeature or all instances of SequenceFeature.
        - Output:
            - if input Sparse: `(batch_size, num_features)`
            - if input Sequence: `(batch_size, num_features_seq, seq_length)`
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, features):
        mask = []
        if not isinstance(features, list):
            features = [features]
        for fea in features:
            if isinstance(fea, SparseFeature) or isinstance(fea, SequenceFeature):
                if fea.padding_idx is not None:
                    fea_mask = x[fea.name].long() != fea.padding_idx
                else:
                    fea_mask = x[fea.name].long() != -1
                mask.append(fea_mask.unsqueeze(1).float())
            else:
                raise ValueError("Only SparseFeature or SequenceFeature support to get mask.")
        return torch.cat(mask, dim=1)


class LR(nn.Module):
    """Logistic Regression Module. It is the one Non-linear
    transformation for input feature.

    Args:
        input_dim (int): input size of Linear module.
        sigmoid (bool): whether to add sigmoid function before output.

    Shape:
        - Input: `(batch_size, input_dim)`
        - Output: `(batch_size, 1)`
    """

    def __init__(self, input_dim, sigmoid=False):
        super().__init__()
        self.sigmoid = sigmoid
        self.fc = nn.Linear(input_dim, 1, bias=True)

    def forward(self, x):
        if self.sigmoid:
            return torch.sigmoid(self.fc(x))
        else:
            return self.fc(x)


class ConcatPooling(nn.Module):
    """Keep the origin sequence embedding shape

    Shape:
        - Input: `(batch_size, seq_length, embed_dim)`
        - Output: `(batch_size, seq_length, embed_dim)`
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, mask=None):
        return x


class AveragePooling(nn.Module):
    """Pooling the sequence embedding matrix by `mean`.

    Shape:
        - Input
            x: `(batch_size, seq_length, embed_dim)`
            mask: `(batch_size, 1, seq_length)`
        - Output: `(batch_size, embed_dim)`
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, mask=None):
        if mask is None:
            return torch.mean(x, dim=1)
        else:
            sum_pooling_matrix = torch.bmm(mask, x).squeeze(1)
            non_padding_length = mask.sum(dim=-1)
            return sum_pooling_matrix / (non_padding_length.float() + 1e-16)


class SumPooling(nn.Module):
    """Pooling the sequence embedding matrix by `sum`.

    Shape:
        - Input
            x: `(batch_size, seq_length, embed_dim)`
            mask: `(batch_size, 1, seq_length)`
        - Output: `(batch_size, embed_dim)`
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, mask=None):
        if mask is None:
            return torch.sum(x, dim=1)
        else:
            return torch.bmm(mask, x).squeeze(1)


class MLP(nn.Module):
    """Multi Layer Perceptron Module, it is the most widely used module for
    learning feature. Note we default add `BatchNorm1d` and `Activation`
    `Dropout` for each `Linear` Module.

    Args:
        input dim (int): input size of the first Linear Layer.
        output_layer (bool): whether this MLP module is the output layer. If `True`, then append one Linear(*,1) module.
        dims (list): output size of Linear Layer (default=[]).
        dropout (float): probability of an element to be zeroed (default = 0.5).
        activation (str): the activation function, support `[sigmoid, relu, prelu, dice, softmax]` (default='relu').

    Shape:
        - Input: `(batch_size, input_dim)`
        - Output: `(batch_size, 1)` or `(batch_size, dims[-1])`
    """

    def __init__(self, input_dim, output_layer=True, dims=None, dropout=0, activation="relu"):
        super().__init__()
        if dims is None:
            dims = []
        layers = list()
        for i_dim in dims:
            layers.append(nn.Linear(input_dim, i_dim))
            layers.append(nn.BatchNorm1d(i_dim))
            layers.append(activation_layer(activation))
            layers.append(nn.Dropout(p=dropout))
            input_dim = i_dim
        if output_layer:
            layers.append(nn.Linear(input_dim, 1))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)


class FM(nn.Module):
    """The Factorization Machine module, mentioned in the `DeepFM paper
    <https://arxiv.org/pdf/1703.04247.pdf>`. It is used to learn 2nd-order
    feature interactions.

    Args:
        reduce_sum (bool): whether to sum in embed_dim (default = `True`).

    Shape:
        - Input: `(batch_size, num_features, embed_dim)`
        - Output: `(batch_size, 1)`` or ``(batch_size, embed_dim)`
    """

    def __init__(self, reduce_sum=True):
        super().__init__()
        self.reduce_sum = reduce_sum

    def forward(self, x):
        square_of_sum = torch.sum(x, dim=1)**2
        sum_of_square = torch.sum(x**2, dim=1)
        ix = square_of_sum - sum_of_square
        if self.reduce_sum:
            ix = torch.sum(ix, dim=1, keepdim=True)
        return 0.5 * ix


class CIN(nn.Module):
    """Compressed Interaction Network

    Args:
        input_dim (int): input dim of input tensor.
        cin_size (list[int]): out channels of Conv1d.

    Shape:
        - Input: `(batch_size, num_features, embed_dim)`
        - Output: `(batch_size, 1)`
    """

    def __init__(self, input_dim, cin_size, split_half=True):
        super().__init__()
        self.num_layers = len(cin_size)
        self.split_half = split_half
        self.conv_layers = torch.nn.ModuleList()
        prev_dim, fc_input_dim = input_dim, 0
        for i in range(self.num_layers):
            cross_layer_size = cin_size[i]
            self.conv_layers.append(torch.nn.Conv1d(input_dim * prev_dim, cross_layer_size, 1, stride=1, dilation=1, bias=True))
            if self.split_half and i != self.num_layers - 1:
                cross_layer_size //= 2
            prev_dim = cross_layer_size
            fc_input_dim += prev_dim
        self.fc = torch.nn.Linear(fc_input_dim, 1)

    def forward(self, x):
        xs = list()
        x0, h = x.unsqueeze(2), x
        for i in range(self.num_layers):
            x = x0 * h.unsqueeze(1)
            batch_size, f0_dim, fin_dim, embed_dim = x.shape
            x = x.view(batch_size, f0_dim * fin_dim, embed_dim)
            x = F.relu(self.conv_layers[i](x))
            if self.split_half and i != self.num_layers - 1:
                x, h = torch.split(x, x.shape[1] // 2, dim=1)
            else:
                h = x
            xs.append(x)
        return self.fc(torch.sum(torch.cat(xs, dim=1), 2))


class CrossLayer(nn.Module):
    """
    Cross layer.
    Args:
        input_dim (int): input dim of input tensor
    """

    def __init__(self, input_dim):
        super(CrossLayer, self).__init__()
        self.w = torch.nn.Linear(input_dim, 1, bias=False)
        self.b = torch.nn.Parameter(torch.zeros(input_dim))

    def forward(self, x_0, x_i):
        x = self.w(x_i) * x_0 + self.b
        return x


class CrossNetwork(nn.Module):
    """CrossNetwork mentioned in the DCN paper.

    Args:
        input_dim (int): input dim of input tensor

    Shape:
        - Input: `(batch_size, *)`
        - Output: `(batch_size, *)`

    """

    def __init__(self, input_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        self.w = torch.nn.ModuleList([torch.nn.Linear(input_dim, 1, bias=False) for _ in range(num_layers)])
        self.b = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)])

    def forward(self, x):
        """
        :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
        """
        x0 = x
        for i in range(self.num_layers):
            xw = self.w[i](x)
            x = x0 * xw + self.b[i] + x
        return x


class CrossNetV2(nn.Module):

    def __init__(self, input_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        self.w = torch.nn.ModuleList([torch.nn.Linear(input_dim, input_dim, bias=False) for _ in range(num_layers)])
        self.b = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)])

    def forward(self, x):
        x0 = x
        for i in range(self.num_layers):
            x = x0 * self.w[i](x) + self.b[i] + x
        return x


class CrossNetMix(nn.Module):
    """ CrossNetMix improves CrossNetwork by:
        1. add MOE to learn feature interactions in different subspaces
        2. add nonlinear transformations in low-dimensional space
    :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
    """

    def __init__(self, input_dim, num_layers=2, low_rank=32, num_experts=4):
        super(CrossNetMix, self).__init__()
        self.num_layers = num_layers
        self.num_experts = num_experts

        # U: (input_dim, low_rank)
        self.u_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(torch.empty(num_experts, input_dim, low_rank))) for i in range(self.num_layers)])
        # V: (input_dim, low_rank)
        self.v_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(torch.empty(num_experts, input_dim, low_rank))) for i in range(self.num_layers)])
        # C: (low_rank, low_rank)
        self.c_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(torch.empty(num_experts, low_rank, low_rank))) for i in range(self.num_layers)])
        self.gating = nn.ModuleList([nn.Linear(input_dim, 1, bias=False) for i in range(self.num_experts)])

        self.bias = torch.nn.ParameterList([nn.Parameter(nn.init.zeros_(torch.empty(input_dim, 1))) for i in range(self.num_layers)])

    def forward(self, x):
        x_0 = x.unsqueeze(2)  # (bs, in_features, 1)
        x_l = x_0
        for i in range(self.num_layers):
            output_of_experts = []
            gating_score_experts = []
            for expert_id in range(self.num_experts):
                # (1) G(x_l)
                # compute the gating score by x_l
                gating_score_experts.append(self.gating[expert_id](x_l.squeeze(2)))

                # (2) E(x_l)
                # project the input x_l to $\mathbb{R}^{r}$
                v_x = torch.matmul(self.v_list[i][expert_id].t(), x_l)  # (bs, low_rank, 1)

                # nonlinear activation in low rank space
                v_x = torch.tanh(v_x)
                v_x = torch.matmul(self.c_list[i][expert_id], v_x)
                v_x = torch.tanh(v_x)

                # project back to $\mathbb{R}^{d}$
                uv_x = torch.matmul(self.u_list[i][expert_id], v_x)  # (bs, in_features, 1)

                dot_ = uv_x + self.bias[i]
                dot_ = x_0 * dot_  # Hadamard-product

                output_of_experts.append(dot_.squeeze(2))

            # (3) mixture of low-rank experts
            output_of_experts = torch.stack(output_of_experts, 2)  # (bs, in_features, num_experts)
            gating_score_experts = torch.stack(gating_score_experts, 1)  # (bs, num_experts, 1)
            moe_out = torch.matmul(output_of_experts, gating_score_experts.softmax(1))
            x_l = moe_out + x_l  # (bs, in_features, 1)

        x_l = x_l.squeeze()  # (bs, in_features)
        return x_l


class SENETLayer(nn.Module):
    """
    A weighted feature gating system in the SENet paper
    Args:
        num_fields (int): number of feature fields

    Shape:
        - num_fields: `(batch_size, *)`
        - Output: `(batch_size, *)`
    """

    def __init__(self, num_fields, reduction_ratio=3):
        super(SENETLayer, self).__init__()
        reduced_size = max(1, int(num_fields / reduction_ratio))
        self.mlp = nn.Sequential(nn.Linear(num_fields, reduced_size, bias=False), nn.ReLU(), nn.Linear(reduced_size, num_fields, bias=False), nn.ReLU())

    def forward(self, x):
        z = torch.mean(x, dim=-1, out=None)
        a = self.mlp(z)
        v = x * a.unsqueeze(-1)
        return v


class BiLinearInteractionLayer(nn.Module):
    """
    Bilinear feature interaction module, which is an improved model of the FFM model
    Args:
        num_fields (int): number of feature fields
        bilinear_type(str): the type bilinear interaction function
    Shape:
        - num_fields: `(batch_size, *)`
        - Output: `(batch_size, *)`
    """

    def __init__(self, input_dim, num_fields, bilinear_type="field_interaction"):
        super(BiLinearInteractionLayer, self).__init__()
        self.bilinear_type = bilinear_type
        if self.bilinear_type == "field_all":
            self.bilinear_layer = nn.Linear(input_dim, input_dim, bias=False)
        elif self.bilinear_type == "field_each":
            self.bilinear_layer = nn.ModuleList([nn.Linear(input_dim, input_dim, bias=False) for i in range(num_fields)])
        elif self.bilinear_type == "field_interaction":
            self.bilinear_layer = nn.ModuleList([nn.Linear(input_dim, input_dim, bias=False) for i, j in combinations(range(num_fields), 2)])
        else:
            raise NotImplementedError()

    def forward(self, x):
        feature_emb = torch.split(x, 1, dim=1)
        if self.bilinear_type == "field_all":
            bilinear_list = [self.bilinear_layer(v_i) * v_j for v_i, v_j in combinations(feature_emb, 2)]
        elif self.bilinear_type == "field_each":
            bilinear_list = [self.bilinear_layer[i](feature_emb[i]) * feature_emb[j] for i, j in combinations(range(len(feature_emb)), 2)]
        elif self.bilinear_type == "field_interaction":
            bilinear_list = [self.bilinear_layer[i](v[0]) * v[1] for i, v in enumerate(combinations(feature_emb, 2))]
        return torch.cat(bilinear_list, dim=1)


class MultiInterestSA(nn.Module):
    """MultiInterest Attention mentioned in the Comirec paper.

    Args:
        embedding_dim (int): embedding dim of item embedding
        interest_num (int): num of interest
        hidden_dim (int): hidden dim

    Shape:
        - Input: seq_emb : (batch,seq,emb)
                 mask : (batch,seq,1)
        - Output: `(batch_size, interest_num, embedding_dim)`

    """

    def __init__(self, embedding_dim, interest_num, hidden_dim=None):
        super(MultiInterestSA, self).__init__()
        self.embedding_dim = embedding_dim
        self.interest_num = interest_num
        if hidden_dim is None:
            self.hidden_dim = self.embedding_dim * 4
        self.W1 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.hidden_dim), requires_grad=True)
        self.W2 = torch.nn.Parameter(torch.rand(self.hidden_dim, self.interest_num), requires_grad=True)
        self.W3 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.embedding_dim), requires_grad=True)

    def forward(self, seq_emb, mask=None):
        H = torch.einsum('bse, ed -> bsd', seq_emb, self.W1).tanh()
        if mask is not None:
            A = torch.einsum('bsd, dk -> bsk', H, self.W2) + - \
                1.e9 * (1 - mask.float())
            A = F.softmax(A, dim=1)
        else:
            A = F.softmax(torch.einsum('bsd, dk -> bsk', H, self.W2), dim=1)
        A = A.permute(0, 2, 1)
        multi_interest_emb = torch.matmul(A, seq_emb)
        return multi_interest_emb


class CapsuleNetwork(nn.Module):
    """CapsuleNetwork mentioned in the Comirec and MIND paper.

    Args:
        hidden_size (int): embedding dim of item embedding
        seq_len (int): length of the item sequence
        bilinear_type (int): 0 for MIND, 2 for ComirecDR
        interest_num (int): num of interest
        routing_times (int): routing times

    Shape:
        - Input: seq_emb : (batch,seq,emb)
                 mask : (batch,seq,1)
        - Output: `(batch_size, interest_num, embedding_dim)`

    """

    def __init__(self, embedding_dim, seq_len, bilinear_type=2, interest_num=4, routing_times=3, relu_layer=False):
        super(CapsuleNetwork, self).__init__()
        self.embedding_dim = embedding_dim  # h
        self.seq_len = seq_len  # s
        self.bilinear_type = bilinear_type
        self.interest_num = interest_num
        self.routing_times = routing_times

        self.relu_layer = relu_layer
        self.stop_grad = True
        self.relu = nn.Sequential(nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), nn.ReLU())
        if self.bilinear_type == 0:  # MIND
            self.linear = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
        elif self.bilinear_type == 1:
            self.linear = nn.Linear(self.embedding_dim, self.embedding_dim * self.interest_num, bias=False)
        else:
            self.w = nn.Parameter(torch.Tensor(1, self.seq_len, self.interest_num * self.embedding_dim, self.embedding_dim))

    def forward(self, item_eb, mask):
        if self.bilinear_type == 0:
            item_eb_hat = self.linear(item_eb)
            item_eb_hat = item_eb_hat.repeat(1, 1, self.interest_num)
        elif self.bilinear_type == 1:
            item_eb_hat = self.linear(item_eb)
        else:
            u = torch.unsqueeze(item_eb, dim=2)
            item_eb_hat = torch.sum(self.w[:, :self.seq_len, :, :] * u, dim=3)

        item_eb_hat = torch.reshape(item_eb_hat, (-1, self.seq_len, self.interest_num, self.embedding_dim))
        item_eb_hat = torch.transpose(item_eb_hat, 1, 2).contiguous()
        item_eb_hat = torch.reshape(item_eb_hat, (-1, self.interest_num, self.seq_len, self.embedding_dim))

        if self.stop_grad:
            item_eb_hat_iter = item_eb_hat.detach()
        else:
            item_eb_hat_iter = item_eb_hat

        if self.bilinear_type > 0:
            capsule_weight = torch.zeros(item_eb_hat.shape[0], self.interest_num, self.seq_len, device=item_eb.device, requires_grad=False)
        else:
            capsule_weight = torch.randn(item_eb_hat.shape[0], self.interest_num, self.seq_len, device=item_eb.device, requires_grad=False)

        for i in range(self.routing_times):  # dynamic routing, propagated 3 times
            atten_mask = torch.unsqueeze(mask, 1).repeat(1, self.interest_num, 1)
            paddings = torch.zeros_like(atten_mask, dtype=torch.float)

            capsule_softmax_weight = F.softmax(capsule_weight, dim=-1)
            capsule_softmax_weight = torch.where(torch.eq(atten_mask, 0), paddings, capsule_softmax_weight)
            capsule_softmax_weight = torch.unsqueeze(capsule_softmax_weight, 2)

            if i < 2:
                interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat_iter)
                cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
                scalar_factor = cap_norm / \
                    (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
                interest_capsule = scalar_factor * interest_capsule

                delta_weight = torch.matmul(item_eb_hat_iter, torch.transpose(interest_capsule, 2, 3).contiguous())
                delta_weight = torch.reshape(delta_weight, (-1, self.interest_num, self.seq_len))
                capsule_weight = capsule_weight + delta_weight
            else:
                interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat)
                cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
                scalar_factor = cap_norm / \
                    (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
                interest_capsule = scalar_factor * interest_capsule

        interest_capsule = torch.reshape(interest_capsule, (-1, self.interest_num, self.embedding_dim))

        if self.relu_layer:
            interest_capsule = self.relu(interest_capsule)

        return interest_capsule


class FFM(nn.Module):
    """The Field-aware Factorization Machine module, mentioned in the `FFM paper
    <https://dl.acm.org/doi/abs/10.1145/2959100.2959134>`. It explicitly models
    multi-channel second-order feature interactions, with each feature filed
    corresponding to one channel.

    Args:
        num_fields (int): number of feature fields.
        reduce_sum (bool): whether to sum in embed_dim (default = `True`).

    Shape:
        - Input: `(batch_size, num_fields, num_fields, embed_dim)`
        - Output: `(batch_size, num_fields*(num_fields-1)/2, 1)` or `(batch_size, num_fields*(num_fields-1)/2, embed_dim)`
    """

    def __init__(self, num_fields, reduce_sum=True):
        super().__init__()
        self.num_fields = num_fields
        self.reduce_sum = reduce_sum

    def forward(self, x):
        # compute (non-redundant) second order field-aware feature crossings
        crossed_embeddings = []
        for i in range(self.num_fields - 1):
            for j in range(i + 1, self.num_fields):
                crossed_embeddings.append(x[:, i, j, :] * x[:, j, i, :])
        crossed_embeddings = torch.stack(crossed_embeddings, dim=1)

        # if reduce_sum is true, the crossing operation is effectively inner
        # product, other wise Hadamard-product
        if self.reduce_sum:
            crossed_embeddings = torch.sum(crossed_embeddings, dim=-1, keepdim=True)
        return crossed_embeddings


class CEN(nn.Module):
    """The Compose-Excitation Network module, mentioned in the `FAT-DeepFFM paper
    <https://arxiv.org/abs/1905.06336>`, a modified version of
    `Squeeze-and-Excitation Network” (SENet) (Hu et al., 2017)`. It is used to
    highlight the importance of second-order feature crosses.

    Args:
        embed_dim (int): the dimensionality of categorical value embedding.
        num_field_crosses (int): the number of second order crosses between feature fields.
        reduction_ratio (int): the between the dimensions of input layer and hidden layer of the MLP module.

    Shape:
        - Input: `(batch_size, num_fields, num_fields, embed_dim)`
        - Output: `(batch_size, num_fields*(num_fields-1)/2 * embed_dim)`
    """

    def __init__(self, embed_dim, num_field_crosses, reduction_ratio):
        super().__init__()

        # convolution weight (Eq.7 FAT-DeepFFM)
        self.u = torch.nn.Parameter(torch.rand(num_field_crosses, embed_dim), requires_grad=True)

        # two FC layers that computes the field attention
        self.mlp_att = MLP(num_field_crosses, dims=[num_field_crosses // reduction_ratio, num_field_crosses], output_layer=False, activation="relu")

    def forward(self, em):
        # compute descriptor vector (Eq.7 FAT-DeepFFM), output shape
        # [batch_size, num_field_crosses]
        d = F.relu((self.u.squeeze(0) * em).sum(-1))

        # compute field attention (Eq.9), output shape [batch_size,
        # num_field_crosses]
        s = self.mlp_att(d)

        # rescale original embedding with field attention (Eq.10), output shape
        # [batch_size, num_field_crosses, embed_dim]
        aem = s.unsqueeze(-1) * em
        return aem.flatten(start_dim=1)


# ============ HSTU Layers (newly added) ============


class HSTULayer(nn.Module):
    """Single HSTU layer.

    This layer implements the core HSTU "sequential transduction unit": a
    multi-head self-attention block with gating and a position-wise FFN, plus
    residual connections and LayerNorm.

    Args:
        d_model (int): Hidden dimension of the model. Default: 512.
        n_heads (int): Number of attention heads. Default: 8.
        dqk (int): Dimension of query/key per head. Default: 64.
        dv (int): Dimension of value per head. Default: 64.
        dropout (float): Dropout rate applied in the layer. Default: 0.1.
        use_rel_pos_bias (bool): Whether to use relative position bias.

    Shape:
        - Input: ``(batch_size, seq_len, d_model)``
        - Output: ``(batch_size, seq_len, d_model)``

    Example:
        >>> layer = HSTULayer(d_model=512, n_heads=8)
        >>> x = torch.randn(32, 256, 512)
        >>> output = layer(x)
        >>> output.shape
        torch.Size([32, 256, 512])
    """

    def __init__(self, d_model=512, n_heads=8, dqk=64, dv=64, dropout=0.1, use_rel_pos_bias=True):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.dqk = dqk
        self.dv = dv
        self.dropout_rate = dropout
        self.use_rel_pos_bias = use_rel_pos_bias

        # Validate dimensions
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        # Projection 1: d_model -> 2*n_heads*dqk + 2*n_heads*dv
        proj1_out_dim = 2 * n_heads * dqk + 2 * n_heads * dv
        self.proj1 = nn.Linear(d_model, proj1_out_dim)

        # Projection 2: n_heads*dv -> d_model
        self.proj2 = nn.Linear(n_heads * dv, d_model)

        # Feed-forward network (FFN)
        # Standard Transformer uses 4*d_model as the hidden dimension of FFN
        ffn_hidden_dim = 4 * d_model
        self.ffn = nn.Sequential(nn.Linear(d_model, ffn_hidden_dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(ffn_hidden_dim, d_model), nn.Dropout(dropout))

        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Scaling factor for attention scores
        self.scale = 1.0 / (dqk**0.5)

    def forward(self, x, rel_pos_bias=None):
        """Forward pass of a single HSTU layer.

        Args:
            x (Tensor): Input tensor of shape ``(batch_size, seq_len, d_model)``.
            rel_pos_bias (Tensor, optional): Relative position bias of shape
                ``(1, n_heads, seq_len, seq_len)``.

        Returns:
            Tensor: Output tensor of shape ``(batch_size, seq_len, d_model)``.
        """
        batch_size, seq_len, _ = x.shape

        # Residual connection
        residual = x

        # Layer normalization
        x = self.norm1(x)

        # Projection 1: (B, L, D) -> (B, L, 2*H*dqk + 2*H*dv)
        proj_out = self.proj1(x)

        # Split into Q, K, U, V
        # Q, K: (B, L, H, dqk)
        # U, V: (B, L, H, dv)
        q = proj_out[..., :self.n_heads * self.dqk].reshape(batch_size, seq_len, self.n_heads, self.dqk)
        k = proj_out[..., self.n_heads * self.dqk:2 * self.n_heads * self.dqk].reshape(batch_size, seq_len, self.n_heads, self.dqk)
        u = proj_out[..., 2 * self.n_heads * self.dqk:2 * self.n_heads * self.dqk + self.n_heads * self.dv].reshape(batch_size, seq_len, self.n_heads, self.dv)
        v = proj_out[..., 2 * self.n_heads * self.dqk + self.n_heads * self.dv:].reshape(batch_size, seq_len, self.n_heads, self.dv)

        # Transpose to (B, H, L, dqk/dv)
        q = q.transpose(1, 2)  # (B, H, L, dqk)
        k = k.transpose(1, 2)  # (B, H, L, dqk)
        u = u.transpose(1, 2)  # (B, H, L, dv)
        v = v.transpose(1, 2)  # (B, H, L, dv)

        # Compute attention scores: (B, H, L, L)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # Add causal mask (prevent attending to future positions)
        # For generative models this is required so that position i only attends
        # to positions <= i.
        causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool))
        scores = scores.masked_fill(~causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))

        # Add relative position bias if provided
        if rel_pos_bias is not None:
            scores = scores + rel_pos_bias

        # Softmax over attention scores
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Attention output: (B, H, L, dv)
        attn_output = torch.matmul(attn_weights, v)

        # Gating mechanism: apply a learned gate on top of attention output
        # First transpose back to (B, L, H, dv)
        attn_output = attn_output.transpose(1, 2)  # (B, L, H, dv)
        u = u.transpose(1, 2)  # (B, L, H, dv)

        # Apply element-wise gate: (B, L, H, dv)
        gated_output = attn_output * torch.sigmoid(u)

        # Merge heads: (B, L, H*dv)
        gated_output = gated_output.reshape(batch_size, seq_len, self.n_heads * self.dv)

        # Projection 2: (B, L, H*dv) -> (B, L, D)
        output = self.proj2(gated_output)
        output = self.dropout(output)

        # Residual connection
        output = output + residual

        # Second residual block: LayerNorm + FFN + residual connection
        residual = output
        output = self.norm2(output)
        output = self.ffn(output)
        output = output + residual

        return output


class HSTUBlock(nn.Module):
    """Stacked HSTU block.

    This block stacks multiple :class:`HSTULayer` layers to form a deep HSTU
    encoder for sequential recommendation.

    Args:
        d_model (int): Hidden dimension of the model. Default: 512.
        n_heads (int): Number of attention heads. Default: 8.
        n_layers (int): Number of stacked HSTU layers. Default: 4.
        dqk (int): Dimension of query/key per head. Default: 64.
        dv (int): Dimension of value per head. Default: 64.
        dropout (float): Dropout rate applied in each layer. Default: 0.1.
        use_rel_pos_bias (bool): Whether to use relative position bias.

    Shape:
        - Input: ``(batch_size, seq_len, d_model)``
        - Output: ``(batch_size, seq_len, d_model)``

    Example:
        >>> block = HSTUBlock(d_model=512, n_heads=8, n_layers=4)
        >>> x = torch.randn(32, 256, 512)
        >>> output = block(x)
        >>> output.shape
        torch.Size([32, 256, 512])
    """

    def __init__(self, d_model=512, n_heads=8, n_layers=4, dqk=64, dv=64, dropout=0.1, use_rel_pos_bias=True):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers

        # Create a stack of HSTULayer modules
        self.layers = nn.ModuleList([HSTULayer(d_model=d_model, n_heads=n_heads, dqk=dqk, dv=dv, dropout=dropout, use_rel_pos_bias=use_rel_pos_bias) for _ in range(n_layers)])

    def forward(self, x, rel_pos_bias=None):
        """Forward pass through all stacked HSTULayer modules.

        Args:
            x (Tensor): Input tensor of shape ``(batch_size, seq_len, d_model)``.
            rel_pos_bias (Tensor, optional): Relative position bias shared across
                all layers.

        Returns:
            Tensor: Output tensor of shape ``(batch_size, seq_len, d_model)``.
        """
        for layer in self.layers:
            x = layer(x, rel_pos_bias=rel_pos_bias)
        return x


class InteractingLayer(nn.Module):
    """Multi-head Self-Attention based Interacting Layer, used in AutoInt model.

    Args:
        embed_dim (int): the embedding dimension.
        num_heads (int): the number of attention heads (default=2).
        dropout (float): the dropout rate (default=0.0).
        residual (bool): whether to use residual connection (default=True).

    Shape:
        - Input: `(batch_size, num_fields, embed_dim)`
        - Output: `(batch_size, num_fields, embed_dim)`
    """

    def __init__(self, embed_dim, num_heads=2, dropout=0.0, residual=True):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim**-0.5
        self.residual = residual

        self.W_Q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_K = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_V = nn.Linear(embed_dim, embed_dim, bias=False)

        # Residual connection
        self.W_Res = nn.Linear(embed_dim, embed_dim, bias=False) if residual else None
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def forward(self, x):
        """
        Args:
            x: input tensor with shape (batch_size, num_fields, embed_dim)
        """
        batch_size, num_fields, embed_dim = x.shape

        # Linear projections
        Q = self.W_Q(x)  # (batch_size, num_fields, embed_dim)
        K = self.W_K(x)  # (batch_size, num_fields, embed_dim)
        V = self.W_V(x)  # (batch_size, num_fields, embed_dim)

        # Reshape for multi-head attention
        # (batch_size, num_heads, num_fields, head_dim)
        Q = Q.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention
        # (batch_size, num_heads, num_fields, num_fields)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        attn_weights = F.softmax(attn_scores, dim=-1)

        if self.dropout is not None:
            attn_weights = self.dropout(attn_weights)

        # Apply attention to values
        # (batch_size, num_heads, num_fields, head_dim)
        attn_output = torch.matmul(attn_weights, V)

        # Concatenate heads
        # (batch_size, num_fields, embed_dim)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, num_fields, embed_dim)

        # Residual connection
        if self.residual and self.W_Res is not None:
            attn_output = attn_output + self.W_Res(x)

        return F.relu(attn_output)