liger-kernel 0.5.5__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registries, and is provided for informational purposes only.
Files changed (39)
  1. liger_kernel/chunked_loss/functional.py +2 -0
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +17 -2
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +346 -0
  4. liger_kernel/chunked_loss/grpo_loss.py +134 -60
  5. liger_kernel/chunked_loss/jsd_loss.py +12 -7
  6. liger_kernel/ops/cross_entropy.py +3 -2
  7. liger_kernel/ops/dyt.py +225 -0
  8. liger_kernel/ops/fused_linear_jsd.py +2 -1
  9. liger_kernel/ops/jsd.py +32 -12
  10. liger_kernel/ops/kl_div.py +15 -8
  11. liger_kernel/ops/layer_norm.py +14 -1
  12. liger_kernel/ops/rms_norm.py +12 -1
  13. liger_kernel/transformers/__init__.py +133 -15
  14. liger_kernel/transformers/dyt.py +20 -0
  15. liger_kernel/transformers/functional.py +5 -0
  16. liger_kernel/transformers/gema3_rms.py +8 -0
  17. liger_kernel/transformers/model/gemma.py +17 -20
  18. liger_kernel/transformers/model/gemma2.py +17 -21
  19. liger_kernel/transformers/model/gemma3.py +335 -0
  20. liger_kernel/transformers/model/llama.py +17 -19
  21. liger_kernel/transformers/model/llava.py +369 -0
  22. liger_kernel/transformers/model/loss_utils.py +64 -0
  23. liger_kernel/transformers/model/mistral.py +28 -25
  24. liger_kernel/transformers/model/mixtral.py +20 -26
  25. liger_kernel/transformers/model/mllama.py +17 -19
  26. liger_kernel/transformers/model/olmo2.py +17 -20
  27. liger_kernel/transformers/model/paligemma.py +397 -0
  28. liger_kernel/transformers/model/phi3.py +17 -19
  29. liger_kernel/transformers/model/qwen2.py +17 -19
  30. liger_kernel/transformers/model/qwen2_5_vl.py +9 -10
  31. liger_kernel/transformers/model/qwen2_vl.py +9 -10
  32. liger_kernel/transformers/monkey_patch.py +392 -13
  33. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/METADATA +11 -6
  34. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/RECORD +38 -31
  35. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/WHEEL +1 -1
  36. liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -240
  37. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info/licenses}/LICENSE +0 -0
  38. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info/licenses}/NOTICE +0 -0
  39. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/top_level.txt +0 -0
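Taken together, the file list shows where 0.5.7 grew relative to 0.5.5: a new chunked fused_linear_ppo.py base (replacing fused_linear_rlhf.py, whose removal is shown at the end of this diff), a DyT op and module (ops/dyt.py, transformers/dyt.py), new model integrations for Gemma3, LLaVA and PaliGemma plus a shared loss_utils.py, and a much larger monkey_patch.py and transformers/__init__.py. The patching entry points live in liger_kernel.transformers; the sketch below uses only entry points that predate this release (apply_liger_kernel_to_llama, AutoLigerKernelForCausalLM). The exact new apply_* functions added for the new model files are not visible in this summary, so none are named here.

    # Minimal usage sketch of liger_kernel's patching API (pre-existing entry points only).
    # The checkpoint path is a placeholder.
    import transformers
    from liger_kernel.transformers import AutoLigerKernelForCausalLM, apply_liger_kernel_to_llama

    # Option 1: patch the Hugging Face Llama modeling code in place, then load as usual.
    apply_liger_kernel_to_llama()
    model = transformers.AutoModelForCausalLM.from_pretrained("path/to/llama-checkpoint")

    # Option 2: let the auto class pick and apply the matching patch for the checkpoint.
    model = AutoLigerKernelForCausalLM.from_pretrained("path/to/llama-checkpoint")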
@@ -5,45 +5,48 @@ liger_kernel/chunked_loss/README.md,sha256=0FmkFC3hKBqyoDT5uTlIYmrvRkF-EOCR1y-EB
  liger_kernel/chunked_loss/__init__.py,sha256=ATu-xX5Fc49Cr6yBOGBRNTo593ZrU5ZCsIuvoIbJWw4,603
  liger_kernel/chunked_loss/cpo_loss.py,sha256=Gzz1eU4kgcbdubFVRy55e8A1Cr-r45UgNicXwZIjmBU,5454
  liger_kernel/chunked_loss/dpo_loss.py,sha256=xZwGqS04si9zXyob95SAdalC-hajZg8fWINqiqffN8k,5855
- liger_kernel/chunked_loss/functional.py,sha256=THWWpCnRVhTVfnPnyvQjdBvo1JDtxhwLmtZE_yiBBqM,817
- liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=FJh7k3sry-fqnBApLSngf7h-lHQEiXtOY_tiRDVanPM,11022
+ liger_kernel/chunked_loss/functional.py,sha256=9G3nKm-Bi7uoZRFkL8wwGMl6juDl4bSzDvTa5GHZPzg,955
+ liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=ooR-qnZCyWJN935oHCSWLaKKKyaYERyhNczRGi1VOiw,11935
+ liger_kernel/chunked_loss/fused_linear_ppo.py,sha256=AA19cpv6D8mo5RbSK5GRCcZoOSnpxV_Z1eJlAsC5eic,13434
  liger_kernel/chunked_loss/fused_linear_preference.py,sha256=ojB42jYPu0c4ki96Ft-hy7Sf6fh_WikG-aWNrlZzSio,18362
- liger_kernel/chunked_loss/fused_linear_rlhf.py,sha256=wGujqwLz91mOE9MmdenhBIKvbmswhwtINMCpcP7D74c,9050
  liger_kernel/chunked_loss/fused_linear_unpaired_preference.py,sha256=RiuK3UtRwH9T6jZ36sA8Urj-TVuOLOO2syLg_JOQapY,13437
- liger_kernel/chunked_loss/grpo_loss.py,sha256=axED3628yKODu1v7PMAvSd08WZqwNQvJOTUYMgcihdQ,6665
- liger_kernel/chunked_loss/jsd_loss.py,sha256=j2_1AYLu0FW2VQJIEr1J1qHsWd5VUo6C3aedglHVH4Y,6771
+ liger_kernel/chunked_loss/grpo_loss.py,sha256=eh6mErFUZsSQrgRRefuXdk-LG0gS7Rg2r-U9CtbH3eU,10834
+ liger_kernel/chunked_loss/jsd_loss.py,sha256=u2ahkuHsbhpNaKcpBCz5gCMDk9ou-P04DHji592dIBo,7067
  liger_kernel/chunked_loss/kto_loss.py,sha256=llVCe6DkcpCo57seGWoMikaQVFApx764jsmSbQyqwQY,7529
  liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmMRj7OpYU8s,4872
  liger_kernel/chunked_loss/simpo_loss.py,sha256=fy2w8KbhMrBv7b1jdIeH3bBFxY52bPQPZb3KwBvmurM,5385
  liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- liger_kernel/ops/cross_entropy.py,sha256=yKKhN63I7r9NxJye4wTLBvvKAyrXQt6jf4nBo3lJyVg,18860
+ liger_kernel/ops/cross_entropy.py,sha256=T5oSsqOS1y-Iea5o9v_BSU-_mIEXqWAT1oX_m59NcA4,18941
+ liger_kernel/ops/dyt.py,sha256=YD1-buHz9VmIX838VKzLc-lm5CeUQ4LAskGDWBUMQHA,6187
  liger_kernel/ops/fused_linear_cross_entropy.py,sha256=1Y3Uk_TCSjqKgoG2eot1ptnWXJXXQESqGvOmqAW1gsM,10912
- liger_kernel/ops/fused_linear_jsd.py,sha256=Seshez2qaM6HiTQ8_HEqSwhaeVruNT1SvIM4ZrAPBEU,9602
+ liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHuncvSV5c,9683
  liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
  liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
- liger_kernel/ops/jsd.py,sha256=0jNeRxpcNI5ckxCdoCNyO5GEedLIuzx3lz6KAiksc4o,6109
- liger_kernel/ops/kl_div.py,sha256=MnfuYqqQESON1X2Swy064x1urKtMFdgeSWd60VttBXI,8420
- liger_kernel/ops/layer_norm.py,sha256=6roQjioyg-9O2qLPV8nL4U0-5UH80tdzOMTWwjvDnn8,7961
+ liger_kernel/ops/jsd.py,sha256=onHp5T3MbvJaVz5Vup7Ww6EQp_HTaZeayTjJk6FgQMY,7042
+ liger_kernel/ops/kl_div.py,sha256=ZjGdDLKWksHT9dZ0xF_TDgAkj5cuMTwwT5tr9E-_24o,8734
+ liger_kernel/ops/layer_norm.py,sha256=vWCyOm-F2GMAilB-ozJcFeUQQLCJoTE_uiXq-_0uYuI,8356
  liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML5vBU,8481
- liger_kernel/ops/rms_norm.py,sha256=PWLJcdIKU5e-8BuYFHd9Cqlq6wmr6fUXKi9zQD4LetU,11727
+ liger_kernel/ops/rms_norm.py,sha256=PP27OIBmV9By63i13jot9ylDowW0nuxY_JFIkaPLgL4,12078
  liger_kernel/ops/rope.py,sha256=ofmBOkUpZZO-Q8Z5B_LOFYYLD-YT-8WnJ4vGOrDYouI,8943
  liger_kernel/ops/swiglu.py,sha256=KmgMjaJQnbLLgZn2nEpbwHU_xpnYRweCyrLQSVvM1vA,3015
  liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
  liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
  liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
  liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
- liger_kernel/transformers/__init__.py,sha256=4bwMPQhGHxmZ-WTFAMD9m-s0PYyfcvIRxhq_h3b0Wz0,2363
+ liger_kernel/transformers/__init__.py,sha256=SH30Pt2ZqyQY-mmWQldg_r-5koowuymTIoU4F4e1KHk,6419
  liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
  liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
- liger_kernel/transformers/functional.py,sha256=ShLD3eb--XKNtllznCrOYTbo4f-1KVwzi0KLMICdrn4,4942
+ liger_kernel/transformers/dyt.py,sha256=QMqqc14pkE0WhpRZvapfnNAun-6C0C_tHExL2ZJuCUA,648
+ liger_kernel/transformers/functional.py,sha256=4h9Pdx_iINBqfv2Zod_c27qOpYXDDwbdVgatQ9_XBmI,5089
  liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=09Rt7FZzLH42VOcIbQ4dlQd0o3Rlb4vk6fqiOQ7WTD8,1778
  liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
  liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
+ liger_kernel/transformers/gema3_rms.py,sha256=LTmZOXe6WEnv6ZroW-kU1TE2B36-z5v8OLmKr3XEVFo,353
  liger_kernel/transformers/group_norm.py,sha256=6qMAWOprr4SzP0YhNVNGQIBpM5aUHplUD2VuGJrMBz0,2173
  liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
  liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
  liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
- liger_kernel/transformers/monkey_patch.py,sha256=9ud9tv1LI9WIa9UDu0abGIiusIIkayO1fjAUMWgwwT0,47096
+ liger_kernel/transformers/monkey_patch.py,sha256=QpfNU7MmVDGlBWIZ2RLTSyh0vuZ-si7H37SL-qOliUs,64393
  liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
  liger_kernel/transformers/rms_norm.py,sha256=GqCEJuGt0YdqqlMcToE0Wp4A8YFquDa4UUSyH2uFW2A,1191
  liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
@@ -52,24 +55,28 @@ liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx
  liger_kernel/transformers/tvd.py,sha256=XrRfyJIqN6HFxXk8MYyFVZM1OLz3mtSbRZvWfZ_JerQ,450
  liger_kernel/transformers/experimental/embedding.py,sha256=2P0QYdlFyFrG5OqTzTa1wcRgDSyjBMv5i1a7BrDPDQw,881
  liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- liger_kernel/transformers/model/gemma.py,sha256=ky89b3aWPaeTGRMC-745KgixtQIRXzNAiCORAMLn9yo,9654
- liger_kernel/transformers/model/gemma2.py,sha256=27NcoZjEqP9Lqb4Wf0EKqTbr2HTGiHPhrVyPCRGPz6s,10767
- liger_kernel/transformers/model/llama.py,sha256=3LJFXKFDKvEakaWPc_NicSFst4Y_hdSMrdl1UDK1EcA,10330
- liger_kernel/transformers/model/mistral.py,sha256=MVRksI5_j_8WJu8znOHKCdSI5jSu-S7cdFYzt9m_vIQ,5180
- liger_kernel/transformers/model/mixtral.py,sha256=jpZJkpl625Q-JHWarj2MqT5mRaSsiCtg0c9vVyvOdCY,11430
- liger_kernel/transformers/model/mllama.py,sha256=qWexBdskuN3gPJvPUwt4J0nU675tGD6W7wxgRZ9Bifg,11145
- liger_kernel/transformers/model/olmo2.py,sha256=yyksS6E4fuWd8asEW8rEDBKqZpFmP4ITCM_bjIDZaoY,5124
- liger_kernel/transformers/model/phi3.py,sha256=biRa8fph9qdnQmkD9I21t5XIjpIt1i6UKU4uk8Up8pU,10292
- liger_kernel/transformers/model/qwen2.py,sha256=14UuPjxB-tjqWn85Tn4fqBFvVhVsth5iPEt8kJSMiew,9581
- liger_kernel/transformers/model/qwen2_5_vl.py,sha256=l71WBfX0ptrisoURIRwXJH7MQ2vGKOvcRYMNsrydwlQ,9455
- liger_kernel/transformers/model/qwen2_vl.py,sha256=yMLqsfSYcvhClUpTUjGoADiOxfLB2B8240VdrPP0c8s,9851
+ liger_kernel/transformers/model/gemma.py,sha256=-JoHKWjtYPpxHQa6QbCwnzX_cctRZG2ZTsaUv-dmOt4,9816
+ liger_kernel/transformers/model/gemma2.py,sha256=n4MZupFGDMvtnvkvkNhRrxXS3ZF341BVfyLjrOXp10g,10923
+ liger_kernel/transformers/model/gemma3.py,sha256=ge3JYchiKvX1G1Zp00jX2zmQK2K7ymJoZAxbb2ggslw,16102
+ liger_kernel/transformers/model/llama.py,sha256=UVXQLRW7rCU5vPab54dLNS3ER37eM446peHX00Yz6eA,10493
+ liger_kernel/transformers/model/llava.py,sha256=b0pEagjUbu2-eS9xegjyfl1DwIXLwZcNpff55ibaMbA,17601
+ liger_kernel/transformers/model/loss_utils.py,sha256=WWAMdiONPaXpIvxyOim_0igLrYh0yyOok5Q9_L9xvZw,1787
+ liger_kernel/transformers/model/mistral.py,sha256=RacuKcckuDK6oSraCGD0R0bm-fE0K3q-lkYaAC56C2E,5481
+ liger_kernel/transformers/model/mixtral.py,sha256=gLcqGabdv1XnuciS9b-TpkTDnGL8K32Hoq9j2vZMBRY,11502
+ liger_kernel/transformers/model/mllama.py,sha256=75mxtmMsNd_q8KlKeawj2uMP6v2KjDuUi4nsUKM5jqA,11308
+ liger_kernel/transformers/model/olmo2.py,sha256=rSzSALikEGkk0w3PLNQPrqg-ioN8TpWCXkAlg3LtCdI,5189
+ liger_kernel/transformers/model/paligemma.py,sha256=GNReT6tVZt3ON6aaa9ovg8mnu1hYocSx9OhgC7b-_28,19191
+ liger_kernel/transformers/model/phi3.py,sha256=ebITCrmwmb4z66CbSrZl1kD6BsP52IcSAR8uwUTp9nc,10455
+ liger_kernel/transformers/model/qwen2.py,sha256=QaoTDrJv2wIuAM8QMoeWVvgNl0N5gHzIrew9QGG7kXc,9744
+ liger_kernel/transformers/model/qwen2_5_vl.py,sha256=70BnHZjx6eQWTwi3zc5SMwxTeOOA4Tbdkfy6IYRcTaM,9289
+ liger_kernel/transformers/model/qwen2_vl.py,sha256=zo4O9fShNHYqSLrzLGqQYWSMtJI6UHaSY7zvMCYWyD8,9685
  liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7HHWHwku25A-GYL0WU,193
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
- liger_kernel-0.5.5.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
- liger_kernel-0.5.5.dist-info/METADATA,sha256=PRpIrVa7cvCW-D7zMA6qpsQ1iJogiK6POWpYUbYHYr4,22411
- liger_kernel-0.5.5.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
- liger_kernel-0.5.5.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
- liger_kernel-0.5.5.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
- liger_kernel-0.5.5.dist-info/RECORD,,
+ liger_kernel-0.5.7.dist-info/licenses/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+ liger_kernel-0.5.7.dist-info/licenses/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+ liger_kernel-0.5.7.dist-info/METADATA,sha256=DJsNdftnn0V8KyBJyssCU0NnPtOvl3J7ldrKWs_aPRs,23340
+ liger_kernel-0.5.7.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+ liger_kernel-0.5.7.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+ liger_kernel-0.5.7.dist-info/RECORD,,
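Each RECORD row above is path,hash,size: the hash column is the file's SHA-256 digest encoded as unpadded urlsafe base64, per the wheel RECORD format. A minimal sketch for checking one entry against a locally unpacked wheel follows; the local path is an assumption, and the expected value is copied from the 0.5.7 RECORD lines above.

    import base64
    import hashlib
    from pathlib import Path

    def record_hash(path: str) -> str:
        # SHA-256 digest in the unpadded urlsafe-base64 form used by wheel RECORD files.
        digest = hashlib.sha256(Path(path).read_bytes()).digest()
        return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")

    expected = "sha256=AA19cpv6D8mo5RbSK5GRCcZoOSnpxV_Z1eJlAsC5eic"  # fused_linear_ppo.py, from the RECORD diff
    assert record_hash("liger_kernel/chunked_loss/fused_linear_ppo.py") == expected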
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (76.0.0)
+ Generator: setuptools (78.1.0)
  Root-Is-Purelib: true
  Tag: py3-none-any

@@ -1,240 +0,0 @@
- from abc import abstractmethod
- from functools import partial
-
- import torch
- import torch.nn.functional as F
-
-
- class LigerFusedLinearRLHFBase(torch.autograd.Function):
-     @abstractmethod
-     def rlhf_loss_fn(*args, **kwargs):
-         """
-         To be extended by subclasses.
-         """
-         raise NotImplementedError("RLHF loss function must be implemented.")
-
-     @staticmethod
-     def forward(
-         cls,
-         ctx,
-         _input,
-         weight,
-         attention_mask,
-         rewards,
-         bias=None,
-         num_generations=4,
-         beta=0.1,
-         compiled=True,
-         use_ref_model=False,
-         ref_input=None,
-         ref_weight=None,
-         ref_bias=None,
-         chunk_size=1,
-     ):
-         """Chunked forward pass for RLHF loss computation.
-
-         Args:
-             cls: The class
-             ctx: Context for backward
-             _input: Input tensor
-             weight: Weight tensor
-             attention_mask: Attention mask tensor
-             rewards: Rewards tensor
-             bias: Bias tensor
-             num_generations: Number of generations per prompt
-             beta: Weight for the KL penalty
-             compiled: Whether to use torch compile
-             use_ref_model: Whether to use a reference model
-             ref_input: Reference model input tensor
-             ref_weight: Reference model weight tensor
-             ref_bias: Reference model bias tensor
-             chunk_size: Size of chunks for processing in other loss modules
-         """
-         # Save for backward
-         ctx.beta = beta
-         ctx.rewards = rewards
-
-         # Initialize accumulators
-         loss_acc = torch.zeros((), device=_input.device)
-         grad_weight = torch.zeros_like(weight)  # [V, H]
-         grad_inputs = []
-         grad_bias = torch.zeros_like(bias) if bias is not None else None  # [V]
-         aggregated_metrics = []
-
-         # Create a partial function with fixed arguments
-         compute_loss = partial(
-             LigerFusedLinearRLHFBase._compute_chunk_loss,
-             beta=beta,
-             use_ref_model=use_ref_model,
-             ref_weight=ref_weight,
-             ref_bias=ref_bias,
-             rlhf_loss_fn=cls.rlhf_loss_fn,
-         )
-
-         def fused_fwd_bwd(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk):
-             """Fused forward and backward for a chunk."""
-             if bias is not None:
-                 return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 5), has_aux=True)(
-                     input_chunk,  # arg 0
-                     weight,  # arg 1
-                     attention_mask_chunk,  # arg 2
-                     rewards_chunk,  # arg 3
-                     ref_input_chunk,  # arg 4
-                     bias,  # arg 5
-                 )
-             else:
-                 return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
-                     input_chunk,  # arg 0
-                     weight,  # arg 1
-                     attention_mask_chunk,  # arg 2
-                     rewards_chunk,  # arg 3
-                     ref_input_chunk,  # arg 4
-                 )
-
-         def accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk=None):
-             if bias is not None:
-                 (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
-                     input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
-                 )
-                 grad_bias.add_(chunk_grad_bias)
-             else:
-                 (chunk_grad_input, chunk_grad_weight), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
-                     input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
-                 )
-
-             # Accumulate gradients and loss
-             grad_weight.add_(chunk_grad_weight)
-             grad_inputs.append(chunk_grad_input)
-             loss_acc.add_(chunk_loss)
-
-             # Initialize storage for metrics on first chunk
-             if len(aggregated_metrics) == 0:
-                 for metric in chunk_metrics:
-                     if metric.ndim == 0:
-                         aggregated_metrics.append(torch.zeros((), device=metric.device))
-                     else:
-                         aggregated_metrics.append([])
-
-             # Accumulate metrics
-             for i, metric in enumerate(chunk_metrics):
-                 if metric.ndim == 0:
-                     aggregated_metrics[i].add_(metric)
-                 else:
-                     aggregated_metrics[i].append(metric)
-
-         if compiled:
-             accumulate_chunk = torch.compile(accumulate_chunk)
-
-         # Process input in chunks based on num_generations
-         chunks = max(1, _input.shape[0] // num_generations)
-         _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
-         _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
-         _rewards_chunks = torch.chunk(rewards, chunks=chunks, dim=0)
-         _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0) if use_ref_model else [None] * chunks
-
-         for input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk in zip(
-             _input_chunks, _attention_mask_chunks, _rewards_chunks, _ref_input_chunks
-         ):
-             # Mark dynamic dimensions
-             torch._dynamo.mark_dynamic(input_chunk, 1)
-             torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
-             if ref_input_chunk is not None:
-                 torch._dynamo.mark_dynamic(ref_input_chunk, 1)
-
-             accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk)
-
-         # Scale accumulated loss by number of chunks since we're averaging
-         loss_acc = loss_acc / chunks
-
-         # Combine gradients
-         grad_input = torch.cat(grad_inputs, dim=0)
-
-         # Save for backward
-         ctx.save_for_backward(grad_input, grad_weight, grad_bias)
-
-         # Finalize metrics
-         final_metrics = []
-         for metric in aggregated_metrics:
-             if isinstance(metric, list):
-                 final_metrics.append(torch.cat(metric, dim=0))
-             else:
-                 final_metrics.append(metric / chunks)
-
-         return loss_acc, tuple(final_metrics)
-
-     @staticmethod
-     def _compute_chunk_loss(
-         input_chunk,
-         weight,
-         attention_mask_chunk,
-         rewards_chunk,
-         ref_input_chunk=None,
-         bias=None,
-         beta=0.1,
-         use_ref_model=False,
-         ref_weight=None,
-         ref_bias=None,
-         rlhf_loss_fn=None,
-     ):
-         """Compute loss for a single chunk."""
-         # Get policy log probabilities using chunk_forward
-         log_probs, _, logits_mean = LigerFusedLinearRLHFBase.chunk_forward(input_chunk, weight, bias=bias)
-
-         # Get reference log probabilities if needed
-         ref_log_probs = None
-         if use_ref_model and ref_input_chunk is not None:
-             with torch.no_grad():
-                 ref_log_probs, _, _ = LigerFusedLinearRLHFBase.chunk_forward(ref_input_chunk, ref_weight, bias=ref_bias)
-
-         # Compute chunk loss and metrics using the provided loss function
-         chunk_loss, chunk_metrics = rlhf_loss_fn(
-             log_probs=log_probs,
-             attention_mask=attention_mask_chunk,
-             rewards=rewards_chunk,
-             ref_log_probs=ref_log_probs,
-             beta=beta,
-         )
-
-         return chunk_loss, (logits_mean, *chunk_metrics)
-
-     @staticmethod
-     def chunk_forward(input_chunk, weight, bias=None):
-         """Forward pass computation for a single chunk without explicit reshaping."""
-         # Directly compute logits via batched matrix multiplication: [B, T, H] @ [H, V] -> [B, T, V]
-         logits = torch.matmul(input_chunk, weight.t())
-         if bias is not None:
-             logits = logits + bias  # Broadcasts bias to [B, T, V]
-
-         # Compute log probabilities using softmax over the last dimension
-         log_probs = F.log_softmax(logits.float(), dim=-1)
-
-         # Monitoring: compute mean of logits
-         batch_size, seq_len, _ = input_chunk.shape
-         logits_mean = logits.sum() / (batch_size * seq_len * weight.shape[0])
-         return log_probs, logits, logits_mean
-
-     @staticmethod
-     def backward(ctx, grad_output, *grad_metrics):
-         """Backward pass for RLHF loss."""
-         grad_input, grad_weight, grad_bias = ctx.saved_tensors
-         if grad_output != 1.0:
-             grad_input = grad_input * grad_output
-             grad_weight = grad_weight * grad_output
-             if grad_bias is not None:
-                 grad_bias = grad_bias * grad_output
-
-         return (
-             grad_input,
-             grad_weight,
-             None,  # grad_attention_mask
-             None,  # grad_rewards
-             grad_bias,
-             None,  # grad_num_generations
-             None,  # grad_beta
-             None,  # grad_compiled
-             None,  # grad_use_ref_model
-             None,  # grad_ref_input
-             None,  # grad_ref_weight
-             None,  # grad_ref_bias
-             None,  # grad_chunk_size
-         )
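The removed base class above left its loss to subclasses: rlhf_loss_fn is abstract ("To be extended by subclasses"), and forward invokes it once per chunk with log_probs, attention_mask, rewards, ref_log_probs, and beta. Per the file list at the top, this role passes to chunked_loss/fused_linear_ppo.py in 0.5.7. The sketch below is illustrative only: a hypothetical subclass that shows the extension point the removed class defined, with a placeholder loss rather than liger-kernel's actual GRPO/PPO math, and without the autograd.Function.apply wiring a real subclass would also add.

    import torch

    class NaiveRLHFFunction(LigerFusedLinearRLHFBase):  # hypothetical name, illustrative only
        @staticmethod
        def rlhf_loss_fn(log_probs, attention_mask, rewards, ref_log_probs=None, beta=0.1):
            # Placeholder sequence score: log-prob of the greedy token, summed over unmasked positions.
            token_logp = log_probs.max(dim=-1).values             # [B, T]
            seq_logp = (token_logp * attention_mask).sum(dim=-1)  # [B]
            loss = -(rewards * seq_logp).mean()
            if ref_log_probs is not None:
                # Simple KL(policy || reference) penalty, weighted by beta.
                kl = ((log_probs.exp() * (log_probs - ref_log_probs)).sum(-1) * attention_mask).sum(-1)
                loss = loss + beta * kl.mean()
            # Return (loss, metrics); the base class accumulates scalar metrics across chunks.
            return loss, (seq_logp.detach().mean(),)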