nextrec 0.4.33__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43):
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/activation.py +10 -18
  3. nextrec/basic/asserts.py +1 -22
  4. nextrec/basic/callback.py +2 -2
  5. nextrec/basic/features.py +6 -37
  6. nextrec/basic/heads.py +13 -1
  7. nextrec/basic/layers.py +33 -123
  8. nextrec/basic/loggers.py +3 -2
  9. nextrec/basic/metrics.py +85 -4
  10. nextrec/basic/model.py +518 -7
  11. nextrec/basic/summary.py +88 -42
  12. nextrec/cli.py +117 -30
  13. nextrec/data/data_processing.py +8 -13
  14. nextrec/data/preprocessor.py +449 -844
  15. nextrec/loss/grad_norm.py +78 -76
  16. nextrec/models/multi_task/ple.py +1 -0
  17. nextrec/models/multi_task/share_bottom.py +1 -0
  18. nextrec/models/ranking/afm.py +4 -9
  19. nextrec/models/ranking/dien.py +7 -8
  20. nextrec/models/ranking/ffm.py +2 -2
  21. nextrec/models/retrieval/sdm.py +1 -2
  22. nextrec/models/sequential/hstu.py +0 -2
  23. nextrec/models/tree_base/base.py +1 -1
  24. nextrec/utils/__init__.py +2 -1
  25. nextrec/utils/config.py +1 -1
  26. nextrec/utils/console.py +1 -1
  27. nextrec/utils/onnx_utils.py +252 -0
  28. nextrec/utils/torch_utils.py +63 -56
  29. nextrec/utils/types.py +43 -0
  30. {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/METADATA +10 -4
  31. {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/RECORD +34 -42
  32. nextrec/models/multi_task/[pre]star.py +0 -192
  33. nextrec/models/representation/autorec.py +0 -0
  34. nextrec/models/representation/bpr.py +0 -0
  35. nextrec/models/representation/cl4srec.py +0 -0
  36. nextrec/models/representation/lightgcn.py +0 -0
  37. nextrec/models/representation/mf.py +0 -0
  38. nextrec/models/representation/s3rec.py +0 -0
  39. nextrec/models/sequential/sasrec.py +0 -0
  40. nextrec/utils/feature.py +0 -29
  41. {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/WHEEL +0 -0
  42. {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/entry_points.txt +0 -0
  43. {nextrec-0.4.33.dist-info → nextrec-0.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,26 +1,26 @@
1
1
  nextrec/__init__.py,sha256=_M3oUqyuvQ5k8Th_3wId6hQ_caclh7M5ad51XN09m98,235
2
- nextrec/__version__.py,sha256=O_0xE0g6EcJfkv7qWx5tmF2cs2K3UCW8uX8xzUqd7Rs,23
3
- nextrec/cli.py,sha256=k7gOrPfb3zmyUDxZipUNCFn-PaKCwUKbyJHhgpt-lyc,25673
2
+ nextrec/__version__.py,sha256=LBK46heutvn3KmsCrKIYu8RQikbfnjZaj2xFrXaeCzQ,22
3
+ nextrec/cli.py,sha256=ryRwHI62wv-7qQs8JbuQrAl0VHzURkGVPB9IiOGtnck,29120
4
4
  nextrec/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- nextrec/basic/activation.py,sha256=uekcJsOy8SiT0_NaDO2VNSStyYFzVikDFVLDk-VrjwQ,2949
6
- nextrec/basic/asserts.py,sha256=U1EKovV_OT7_Mm99zFvdfF2hccFREp3gdDaeRjfiBwQ,2249
7
- nextrec/basic/callback.py,sha256=7geza5iMMlMojlrIKH5A7nzvCe4IYwgUaMRh_xpblWk,12585
8
- nextrec/basic/features.py,sha256=zLijBNkKwCXv9TKxSWwvmt7aVfWn2D5JvfwukeIRqec,9174
9
- nextrec/basic/heads.py,sha256=BshykLxD41KxKuZaBxf4Fmy1Mc52b3ioJliN1BVaGlk,3374
10
- nextrec/basic/layers.py,sha256=tr8XFOcTvUHEZ6T3zJwmtKMA-u_xfzHloIkItGs821U,40084
11
- nextrec/basic/loggers.py,sha256=LAfnhdSNEzHybrXaKxCWoAML1c2A-FJF6atpfrrm_Kw,13840
12
- nextrec/basic/metrics.py,sha256=CPzENDcpO6QTDZLBtQlfAGKUYYQc0FT-eaMKJ4MURFo,23396
13
- nextrec/basic/model.py,sha256=Psm1lfAScyDmkK-TmA7pjvI_Hg1IkZ02XgnqJVmvwyw,111699
5
+ nextrec/basic/activation.py,sha256=rU-W-DHgiD3AZnMGmD014ChxklfP9BpedDTiwtdgXhA,2762
6
+ nextrec/basic/asserts.py,sha256=eaB4FZJ7Sbh9S8PLJZNiYd7Zi98ca2rDi4S7wSYCaEw,1473
7
+ nextrec/basic/callback.py,sha256=H1C7EdkLTRLtPrKwCk1Gwq41m7I6tmtUWEST5Ih9fHI,12648
8
+ nextrec/basic/features.py,sha256=hVbEEtEYFer5OLkqNEc0N0vK9QkkWE9Il8FOW0hejZQ,8528
9
+ nextrec/basic/heads.py,sha256=WqvavaH6Y8Au8dLaoUfH2AaOOWgYvjZI5US8avkQNsQ,4009
10
+ nextrec/basic/layers.py,sha256=tawggQMMHlYTGpnubxUAvDPDJe_Lpq-HpLLCSjbJV54,37320
11
+ nextrec/basic/loggers.py,sha256=jJUzUt_kMpjpV2Mqr7qBENWA1olEutTI7dFnpmndUUw,13845
12
+ nextrec/basic/metrics.py,sha256=nVz3AkKwsxj_M97CoZWyQj8_Y9ZM_Icvw_QCM6c33Bc,26262
13
+ nextrec/basic/model.py,sha256=x4QZc8lNXHzpOJLeN3qPs9G_kNC1KgKn_2K52v3-vLw,132691
14
14
  nextrec/basic/session.py,sha256=mrIsjRJhmvcAfoO1pXX-KB3SK5CCgz89wH8XDoAiGEI,4475
15
- nextrec/basic/summary.py,sha256=MkzFwWJH3K76O0Gxqm3rVfbmWHqokvK2OBDe7WFQymo,17780
15
+ nextrec/basic/summary.py,sha256=hCDVB8127GSGtlfFnfEFHWXuvW5qjCSTwowNoA1i1xE,19815
16
16
  nextrec/data/__init__.py,sha256=YZQjpty1pDCM7q_YNmiA2sa5kbujUw26ObLHWjMPjKY,1194
17
17
  nextrec/data/batch_utils.py,sha256=TbnXYqYlmK51dJthaL6dO7LTn4wyp8221I-kdgvpvDE,3542
18
- nextrec/data/data_processing.py,sha256=lhuwYxWp4Ts2bbuLGDt2LmuPrOy7pNcKczd2uVcQ4ss,6476
18
+ nextrec/data/data_processing.py,sha256=xD6afp4zc217ddKfDtHtToyDpxMDWvoqD_Vk4pIpvXU,6333
19
19
  nextrec/data/data_utils.py,sha256=0Ls1cnG9lBz0ovtyedw5vwp7WegGK_iF-F8e_3DEddo,880
20
20
  nextrec/data/dataloader.py,sha256=2sXwoiWxupKE-V1QYeZlXjK1yJyxhDtlOhknAnJF8Wk,19727
21
- nextrec/data/preprocessor.py,sha256=vZR7GnVALHMjQ3d-Bvd0mtkKj0nrkzndMib3vHY570Q,68496
21
+ nextrec/data/preprocessor.py,sha256=kOEfPy0t0M3jBA0kPIlwuSQYsuvn8yUNr_uE_NiulHU,49939
22
22
  nextrec/loss/__init__.py,sha256=rualGsY-IBvmM52q9eOBk0MyKcMkpkazcscOeDXi_SM,774
23
- nextrec/loss/grad_norm.py,sha256=YoE_XSIN1HOUcNq1dpfkIlWtMaB5Pu-SEWDaNgtRw1M,8316
23
+ nextrec/loss/grad_norm.py,sha256=I4jAs0f84I7MWmYZOMC0JRUNvBHZzhgpuox0hOtYWDg,7435
24
24
  nextrec/loss/listwise.py,sha256=mluxXQt9XiuWGvXA1nk4I0miqaKB6_GPVQqxLhAiJKs,5999
25
25
  nextrec/loss/pairwise.py,sha256=9fyH9p2u-N0-jAnNTq3X5Dje0ipj1dob8wp-yQKRra4,3493
26
26
  nextrec/loss/pointwise.py,sha256=09nzI1L5eP9raXnj3Q49bD9Clp_JmsSWUvEj7bkTzSw,7474
@@ -28,7 +28,6 @@ nextrec/models/generative/__init__.py,sha256=0MV3P-_ainPaTxmRBGWKUVCEt14KJvuvEHm
28
28
  nextrec/models/generative/tiger.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
29
29
  nextrec/models/multi_task/[pre]aitm.py,sha256=A2n0T4JEui-uHgbqITU5lpsmtnP14fQXRZM1peTPvhQ,6661
30
30
  nextrec/models/multi_task/[pre]snr_trans.py,sha256=k08tC-TI--a_Tt4_BmX0ZubzntyqwsejutYzbB5F4S4,9077
31
- nextrec/models/multi_task/[pre]star.py,sha256=BczXHPJtK7xyPbLO0fQ-w7qnzaBvLpyhG0CKMBUItCY,7057
32
31
  nextrec/models/multi_task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
33
32
  nextrec/models/multi_task/apg.py,sha256=dNveyottpHTd811_3KVZQQqgffDzmX_rY1EMiv1oKeo,13536
34
33
  nextrec/models/multi_task/cross_stitch.py,sha256=o48ZZFKWXz-w4MWboVdxDxifM0V3_n2Is2nZ2HoJkfw,9441
@@ -37,19 +36,19 @@ nextrec/models/multi_task/esmm.py,sha256=iTRMcjCznyKlwJMGWgqXl9CkQrvGTsfjEjpU_5X
37
36
  nextrec/models/multi_task/hmoe.py,sha256=Ge0AC7oLcEwvt75AwGG5MtClCX6L46H7HRDqD9h-kpQ,7584
38
37
  nextrec/models/multi_task/mmoe.py,sha256=vc19O4N4_64-_oucOGI-P_LaxsQaGvrYGg5Z2sxP49w,7609
39
38
  nextrec/models/multi_task/pepnet.py,sha256=I0MXLXAKBRTN1Vp1QOZD47IrljDDB3UOPHZYUH1cghU,13526
40
- nextrec/models/multi_task/ple.py,sha256=tr8mendvt87u7a4P54lQBYCHtsWDae7VrN9fOqHcXXo,12154
39
+ nextrec/models/multi_task/ple.py,sha256=h8xPqd7BFM76GbQL6RzEbi2EAB8igrkFGS4zqhtWDEc,12155
41
40
  nextrec/models/multi_task/poso.py,sha256=jr-RaLl5UnZc1HcEIK3HrNnc_g17BImyJb43d4rEXpE,18218
42
- nextrec/models/multi_task/share_bottom.py,sha256=NYK2B9k7BaXxzX4VN3dE0_quwsOQGM7JElm2tBdb9MY,5237
41
+ nextrec/models/multi_task/share_bottom.py,sha256=yM5--iqwEFQwpeQy_SmY8Vdo8a1Exi0-LNKSYeJz3hc,5238
43
42
  nextrec/models/ranking/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
44
- nextrec/models/ranking/afm.py,sha256=2IylgCi7DBZzF7bdWFWvsEBlnRmyradxdyJiz5P4Nwg,9397
43
+ nextrec/models/ranking/afm.py,sha256=QPGBNlR9qjS5V5F-A8BnRBPcwwU6qcXQa96cNk9xlsY,9187
45
44
  nextrec/models/ranking/autoint.py,sha256=5g0tMtvkt3FiCjqmU76K7w7z-h3nMJbP0eVt5Ph4bbU,7259
46
45
  nextrec/models/ranking/dcn.py,sha256=qG3O2oL9ZLu9NBOJST06JeEh0k6UTXs-X1mQe_H4QCE,6679
47
46
  nextrec/models/ranking/dcn_v2.py,sha256=807wASeG56WYD7mEDEKaERw3r-Jpas6WhDzO0HGEk9I,10382
48
47
  nextrec/models/ranking/deepfm.py,sha256=1fuc-9f8DKDUH7EAY-XX5lITJ7Qw7ge6PTWFfQ7wld8,4404
49
- nextrec/models/ranking/dien.py,sha256=cluW44zOBHadf3oup6h7MpTexDXjH_VPkYcqsNBT80Y,18361
48
+ nextrec/models/ranking/dien.py,sha256=9OM2vz1Umqlvny6UUySTEz5g_jf8I4creox9J7oV82A,18320
50
49
  nextrec/models/ranking/din.py,sha256=J6-S72_KJYLrzUmdrh6aAx-Qc1C08ZY7lY_KFtAkJz0,8855
51
50
  nextrec/models/ranking/eulernet.py,sha256=0nOBfccfvukSZLUNOCcB_azCh52DGJq-s9doyEGMN8E,11484
52
- nextrec/models/ranking/ffm.py,sha256=v15x2-rExcrEYdcPf2IxEgx-ImDSevhkhi4Oe4GbloY,10512
51
+ nextrec/models/ranking/ffm.py,sha256=v8_whymddfY7u0F9rD4VTUQJXnyYYJYvbgwiR7DKuII,10432
53
52
  nextrec/models/ranking/fibinet.py,sha256=ejR1vNh5XM23SD7mfT686kuv3cmf5gKfkj0z_iMQqNA,7283
54
53
  nextrec/models/ranking/fm.py,sha256=SlFtbtnrZbeRnCHf-kUAMaeLV_wDgLiaBwPGeAO_ycM,3795
55
54
  nextrec/models/ranking/lr.py,sha256=0gmqPED-z7k4WVRy11WSLhYfS4bJQPRQTzQbe-rUITg,3227
@@ -58,38 +57,31 @@ nextrec/models/ranking/pnn.py,sha256=NU537ySMMXtndtPuiwCCQhZhBpGk6X6PRDa55GFdqt0
58
57
  nextrec/models/ranking/widedeep.py,sha256=i37lqEUZUtxNXSkkhQjBF9QP6yreGdV_jcphofHOzW4,4267
59
58
  nextrec/models/ranking/xdeepfm.py,sha256=vqDZRlrY9tgoqvnTXWe8xUbFx_28VztgVVc6ukzevRc,7507
60
59
  nextrec/models/representation/__init__.py,sha256=O3QHMMXBszwM-mTl7bA3wawNZvDGet-QIv6Ys5GHGJ8,190
61
- nextrec/models/representation/autorec.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
62
- nextrec/models/representation/bpr.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
63
- nextrec/models/representation/cl4srec.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
64
- nextrec/models/representation/lightgcn.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
65
- nextrec/models/representation/mf.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
66
60
  nextrec/models/representation/rqvae.py,sha256=ytSXblWj3iYo76y_8mATm5w6C_YSAh2tq4MUFG-ngBc,29296
67
- nextrec/models/representation/s3rec.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
68
61
  nextrec/models/retrieval/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
69
62
  nextrec/models/retrieval/dssm.py,sha256=AmXDr62M6tivBg1P4MQ8f4cZnl1TGHxRBvZj05zWw64,6887
70
63
  nextrec/models/retrieval/dssm_v2.py,sha256=5ZH3dfNfRCDE69k8KG8BZJixaGOSVvQHB9uIDPMLPk4,5953
71
64
  nextrec/models/retrieval/mind.py,sha256=I0qVj39ApweRGW3qDNLca5vsNtJwRe7gBLh1pedsexY,14061
72
- nextrec/models/retrieval/sdm.py,sha256=1Y2gidG7WKuuGFaaQ8BcBGhQYoyyLPyhpRTo_xE1pmc,9987
65
+ nextrec/models/retrieval/sdm.py,sha256=h9TqVmSJ8YF7hgPci784nAlBg1LazB641c4iEeuiLDg,9956
73
66
  nextrec/models/retrieval/youtube_dnn.py,sha256=hLyR4liuusJIjRg4vuaSoSEecYgDICipXnNFiA3o3oY,6351
74
- nextrec/models/sequential/hstu.py,sha256=iZcYLp44r23nHYNhGwD25JfH85DBrFwHOTg1WpHvLe8,18983
75
- nextrec/models/sequential/sasrec.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
67
+ nextrec/models/sequential/hstu.py,sha256=XFq-IERFg2ohqg03HkP6YinQaZUXljtYayUmvU-N_IY,18916
76
68
  nextrec/models/tree_base/__init__.py,sha256=ssGpU1viVidr1tgKCyvmkUXe32bUmwb5qPEkTh7ce70,342
77
- nextrec/models/tree_base/base.py,sha256=6aXYM_nA1XkKfo3KfWSYBoO6IqPIafLrdzJ334ZBaZY,26558
69
+ nextrec/models/tree_base/base.py,sha256=HcxUkctNgizY9LOBh2qXY6gUiYoy2JXA7j11NwUfWT4,26562
78
70
  nextrec/models/tree_base/catboost.py,sha256=hXINyx7iianwDxOZx3SLm0i-YP1jiC3HcAeqP9A2i4A,3434
79
71
  nextrec/models/tree_base/lightgbm.py,sha256=VilMU7SgfHR5LAaaoQo-tY1vkzpSvWovIrgaSeuJ1-A,2263
80
72
  nextrec/models/tree_base/xgboost.py,sha256=thOmDIC_nitno_k2mcH2cj2VcS07f9veTG01FMOO-28,1957
81
- nextrec/utils/__init__.py,sha256=Td-TC1IoTeb0KV-EgPy_vTHmmxmE6tO9q7Gmgsk1p-A,2672
82
- nextrec/utils/config.py,sha256=SUKVgWrsCkJvKLBwcHQHls859jhdPzXk7_3DYoyIXzE,20481
83
- nextrec/utils/console.py,sha256=RA3ZTjtUQXvueouSmXJNLkRjeUGQZesphwWjFMTbV4I,13577
73
+ nextrec/utils/__init__.py,sha256=a29_8gjGzTj1A8r4oOi6LUyFhC97v4ePLOPKbNtCJ6M,2702
74
+ nextrec/utils/config.py,sha256=Ngd_u8ZS5br4lIqrBJ_ecLquMF4KJi6TPAGqLZg8H4s,20485
75
+ nextrec/utils/console.py,sha256=RnSUplJnyanSQ6TyMQkP7S1j2rGMver1DbFVqNH6_1k,13581
84
76
  nextrec/utils/data.py,sha256=pSL96mWjWfW_RKE-qlUSs9vfiYnFZAaRirzA6r7DB6s,24994
85
77
  nextrec/utils/embedding.py,sha256=akAEc062MG2cD7VIOllHaqtwzAirQR2gq5iW7oKpGAU,1449
86
- nextrec/utils/feature.py,sha256=E3NOFIW8gAoRXVrDhCSonzg8k7nMUZyZzMfCq9k73_A,623
87
78
  nextrec/utils/loss.py,sha256=GBWQGpDaYkMJySpdG078XbeUNXUC34PVqFy0AqNS9N0,4578
88
79
  nextrec/utils/model.py,sha256=PI9y8oWz1lhktgapZsiXb8rTr2NrFFlc80tr4yOFHik,5334
89
- nextrec/utils/torch_utils.py,sha256=UQpWS7F3nITYqvx2KRBaQJc9oTowRkIvowhuQLt6NFM,11953
90
- nextrec/utils/types.py,sha256=G88DHXFv-mbg-XY-7Xwwh1qvh6WM9jpAsBjw5VuBcio,1559
91
- nextrec-0.4.33.dist-info/METADATA,sha256=f9PQhSjuU2I32jNDBnVA5YA7K0yiTgnrV0S3QVPSHQU,23188
92
- nextrec-0.4.33.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
93
- nextrec-0.4.33.dist-info/entry_points.txt,sha256=NN-dNSdfMRTv86bNXM7d3ZEPW2BQC6bRi7QP7i9cIps,45
94
- nextrec-0.4.33.dist-info/licenses/LICENSE,sha256=COP1BsqnEUwdx6GCkMjxOo5v3pUe4-Go_CdmQmSfYXM,1064
95
- nextrec-0.4.33.dist-info/RECORD,,
80
+ nextrec/utils/onnx_utils.py,sha256=KIVV_ELYzj3kCswfsSBZ1F2OnSwRJnXj7sxDBwBoBaA,8668
81
+ nextrec/utils/torch_utils.py,sha256=fxViD6Pah0qnXtpvem6ncuLV7y58Q_gyktfvkZQo_JI,12207
82
+ nextrec/utils/types.py,sha256=LFwYCBRo5WeYUh5LSCuyP1Lg9ez0Ih00Es3fUttGAFw,2273
83
+ nextrec-0.5.0.dist-info/METADATA,sha256=wE43qgqOUL8C9FFdfp3E6UfqMP5gjo24aGaG6wCYsdM,23532
84
+ nextrec-0.5.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
85
+ nextrec-0.5.0.dist-info/entry_points.txt,sha256=NN-dNSdfMRTv86bNXM7d3ZEPW2BQC6bRi7QP7i9cIps,45
86
+ nextrec-0.5.0.dist-info/licenses/LICENSE,sha256=COP1BsqnEUwdx6GCkMjxOo5v3pUe4-Go_CdmQmSfYXM,1064
87
+ nextrec-0.5.0.dist-info/RECORD,,
@@ -1,192 +0,0 @@
1
- """
2
- Date: create on 01/01/2026 - prerelease version: still need to align with the source paper
3
- Checkpoint: edit on 01/14/2026
4
- Author: Yang Zhou, zyaztec@gmail.com
5
- Reference:
6
- - [1] Sheng XR, Zhao L, Zhou G, Ding X, Dai B, Luo Q, Yang S, Lv J, Zhang C, Deng H, Zhu X. One Model to Serve All: Star Topology Adaptive Recommender for Multi-Domain CTR Prediction. arXiv preprint arXiv:2101.11427, 2021.
7
- URL: https://arxiv.org/abs/2101.11427
8
- - [2] MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation: https://github.com/alipay/MMLRec-A-Unified-Multi-Task-and-Multi-Scenario-Learning-Benchmark-for-Recommendation/
9
-
10
- STAR uses shared-specific linear layers to adapt representations per task while
11
- optionally reusing shared parameters. It can also apply domain-specific batch
12
- normalization on the first hidden layer when a domain mask is provided.
13
- """
14
-
15
- from __future__ import annotations
16
-
17
- import torch
18
- import torch.nn as nn
19
-
20
- from nextrec.basic.activation import activation_layer
21
- from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
22
- from nextrec.basic.heads import TaskHead
23
- from nextrec.basic.layers import DomainBatchNorm, EmbeddingLayer
24
- from nextrec.basic.model import BaseModel
25
- from nextrec.utils.types import TaskTypeInput, TaskTypeName
26
-
27
-
28
- class SharedSpecificLinear(nn.Module):
29
- """
30
- Shared-specific linear layer: task-specific projection plus optional shared one.
31
- """
32
-
33
- def __init__(
34
- self,
35
- input_dim: int,
36
- output_dim: int,
37
- nums_task: int,
38
- use_shared: bool = True,
39
- ) -> None:
40
- super().__init__()
41
- self.use_shared = use_shared
42
- self.shared = nn.Linear(input_dim, output_dim) if use_shared else None
43
- self.specific = nn.ModuleList(
44
- [nn.Linear(input_dim, output_dim) for _ in range(nums_task)]
45
- )
46
-
47
- def forward(self, x: torch.Tensor, task_idx: int) -> torch.Tensor:
48
- output = self.specific[task_idx](x)
49
- if self.use_shared and self.shared is not None:
50
- output = output + self.shared(x)
51
- return output
52
-
53
-
54
- class STAR(BaseModel):
55
- """
56
- STAR: shared-specific multi-task tower with optional domain-specific batch norm.
57
- """
58
-
59
- @property
60
- def model_name(self) -> str:
61
- return "STAR"
62
-
63
- @property
64
- def default_task(self) -> TaskTypeName | list[TaskTypeName]:
65
- nums_task = self.nums_task if hasattr(self, "nums_task") else None
66
- if nums_task is not None and nums_task > 0:
67
- return ["binary"] * nums_task
68
- return ["binary"]
69
-
70
- def __init__(
71
- self,
72
- dense_features: list[DenseFeature] | None = None,
73
- sparse_features: list[SparseFeature] | None = None,
74
- sequence_features: list[SequenceFeature] | None = None,
75
- target: list[str] | str | None = None,
76
- task: TaskTypeInput | list[TaskTypeInput] | None = None,
77
- mlp_params: dict | None = None,
78
- use_shared: bool = True,
79
- **kwargs,
80
- ) -> None:
81
- dense_features = dense_features or []
82
- sparse_features = sparse_features or []
83
- sequence_features = sequence_features or []
84
- mlp_params = mlp_params or {}
85
- mlp_params.setdefault("hidden_dims", [256, 128])
86
- mlp_params.setdefault("activation", "relu")
87
- mlp_params.setdefault("dropout", 0.0)
88
- mlp_params.setdefault("norm_type", "none")
89
-
90
- if target is None:
91
- target = []
92
- elif isinstance(target, str):
93
- target = [target]
94
-
95
- self.nums_task = len(target) if target else 1
96
-
97
- super().__init__(
98
- dense_features=dense_features,
99
- sparse_features=sparse_features,
100
- sequence_features=sequence_features,
101
- target=target,
102
- task=task,
103
- **kwargs,
104
- )
105
-
106
- if not mlp_params["hidden_dims"]:
107
- raise ValueError("mlp_params['hidden_dims'] must not be empty.")
108
-
109
- norm_type = mlp_params["norm_type"]
110
- self.dnn_use_bn = norm_type == "batch_norm"
111
- self.dnn_dropout = mlp_params["dropout"]
112
-
113
- self.embedding = EmbeddingLayer(features=self.all_features)
114
- input_dim = self.embedding.input_dim
115
-
116
- layer_units = [input_dim] + list(mlp_params["hidden_dims"])
117
- self.star_layers = nn.ModuleList(
118
- [
119
- SharedSpecificLinear(
120
- input_dim=layer_units[idx],
121
- output_dim=layer_units[idx + 1],
122
- nums_task=self.nums_task,
123
- use_shared=use_shared,
124
- )
125
- for idx in range(len(mlp_params["hidden_dims"]))
126
- ]
127
- )
128
- self.activation_layers = nn.ModuleList(
129
- [
130
- activation_layer(mlp_params["activation"])
131
- for _ in range(len(mlp_params["hidden_dims"]))
132
- ]
133
- )
134
- if mlp_params["dropout"] > 0:
135
- self.dropout_layers = nn.ModuleList(
136
- [
137
- nn.Dropout(mlp_params["dropout"])
138
- for _ in range(len(mlp_params["hidden_dims"]))
139
- ]
140
- )
141
- else:
142
- self.dropout_layers = nn.ModuleList(
143
- [nn.Identity() for _ in range(len(mlp_params["hidden_dims"]))]
144
- )
145
-
146
- self.domain_bn = (
147
- DomainBatchNorm(
148
- num_features=mlp_params["hidden_dims"][0], num_domains=self.nums_task
149
- )
150
- if self.dnn_use_bn
151
- else None
152
- )
153
-
154
- self.final_layer = SharedSpecificLinear(
155
- input_dim=mlp_params["hidden_dims"][-1],
156
- output_dim=1,
157
- nums_task=self.nums_task,
158
- use_shared=use_shared,
159
- )
160
- self.prediction_layer = TaskHead(
161
- task_type=self.task, task_dims=[1] * self.nums_task
162
- )
163
-
164
- self.grad_norm_shared_modules = ["embedding", "star_layers", "final_layer"]
165
- self.register_regularization_weights(
166
- embedding_attr="embedding",
167
- include_modules=["star_layers", "final_layer"],
168
- )
169
-
170
- def forward(
171
- self, x: dict[str, torch.Tensor], domain_mask: torch.Tensor | None = None
172
- ) -> torch.Tensor:
173
- input_flat = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
174
-
175
- task_outputs = []
176
- for task_idx in range(self.nums_task):
177
- output = input_flat
178
- for layer_idx, layer in enumerate(self.star_layers):
179
- output = layer(output, task_idx)
180
- output = self.activation_layers[layer_idx](output)
181
- output = self.dropout_layers[layer_idx](output)
182
- if (
183
- layer_idx == 0
184
- and self.dnn_use_bn
185
- and self.domain_bn is not None
186
- and domain_mask is not None
187
- ):
188
- output = self.domain_bn(output, domain_mask)
189
- task_outputs.append(self.final_layer(output, task_idx))
190
-
191
- logits = torch.cat(task_outputs, dim=1)
192
- return self.prediction_layer(logits)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
nextrec/utils/feature.py DELETED
@@ -1,29 +0,0 @@
1
- """
2
- Feature processing utilities for NextRec
3
-
4
- Date: create on 03/12/2025
5
- Checkpoint: edit on 27/12/2025
6
- Author: Yang Zhou, zyaztec@gmail.com
7
- """
8
-
9
- import numbers
10
- from typing import Any
11
-
12
-
13
- def to_list(value: str | list[str] | None) -> list[str]:
14
- if value is None:
15
- return []
16
- if isinstance(value, str):
17
- return [value]
18
- return list(value)
19
-
20
-
21
- def as_float(value: Any) -> float | None:
22
- if isinstance(value, numbers.Number):
23
- return float(value)
24
- if hasattr(value, "item"):
25
- try:
26
- return float(value.item())
27
- except Exception:
28
- return None
29
- return None