liger-kernel-nightly 0.0.1.dev20240819184814__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +366 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +307 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +63 -0
  18. liger_kernel/ops/__init__.py +141 -0
  19. liger_kernel/ops/backends/README.md +151 -0
  20. liger_kernel/ops/backends/__init__.py +13 -0
  21. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  22. liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
  23. liger_kernel/ops/backends/registry.py +61 -0
  24. liger_kernel/ops/cross_entropy.py +383 -114
  25. liger_kernel/ops/dyt.py +160 -0
  26. liger_kernel/ops/experimental/embedding.py +141 -0
  27. liger_kernel/ops/experimental/mm_int8int2.py +349 -0
  28. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  29. liger_kernel/ops/fused_linear_cross_entropy.py +346 -132
  30. liger_kernel/ops/fused_linear_jsd.py +228 -0
  31. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  32. liger_kernel/ops/geglu.py +66 -64
  33. liger_kernel/ops/group_norm.py +306 -0
  34. liger_kernel/ops/grpo_loss.py +312 -0
  35. liger_kernel/ops/jsd.py +201 -0
  36. liger_kernel/ops/kl_div.py +262 -0
  37. liger_kernel/ops/layer_norm.py +320 -0
  38. liger_kernel/ops/llama4_rope.py +225 -0
  39. liger_kernel/ops/multi_token_attention.py +207 -0
  40. liger_kernel/ops/poly_norm.py +390 -0
  41. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  42. liger_kernel/ops/rms_norm.py +484 -88
  43. liger_kernel/ops/rope.py +122 -117
  44. liger_kernel/ops/softmax.py +201 -0
  45. liger_kernel/ops/sparsemax.py +179 -0
  46. liger_kernel/ops/swiglu.py +68 -65
  47. liger_kernel/ops/tiled_mlp.py +136 -0
  48. liger_kernel/ops/tvd.py +207 -0
  49. liger_kernel/ops/utils.py +82 -3
  50. liger_kernel/transformers/__init__.py +218 -6
  51. liger_kernel/transformers/auto_model.py +38 -0
  52. liger_kernel/transformers/cross_entropy.py +52 -7
  53. liger_kernel/transformers/dyt.py +22 -0
  54. liger_kernel/transformers/experimental/__init__.py +5 -0
  55. liger_kernel/transformers/experimental/embedding.py +26 -0
  56. liger_kernel/transformers/fsdp.py +55 -0
  57. liger_kernel/transformers/functional.py +301 -0
  58. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  59. liger_kernel/transformers/fused_linear_cross_entropy.py +59 -10
  60. liger_kernel/transformers/fused_linear_jsd.py +95 -0
  61. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  62. liger_kernel/transformers/geglu.py +6 -7
  63. liger_kernel/transformers/group_norm.py +50 -0
  64. liger_kernel/transformers/grpo_loss.py +153 -0
  65. liger_kernel/transformers/jsd.py +70 -0
  66. liger_kernel/transformers/kl_div.py +12 -0
  67. liger_kernel/transformers/layer_norm.py +24 -0
  68. liger_kernel/transformers/llama4_rope.py +93 -0
  69. liger_kernel/transformers/model/falcon_h1.py +122 -0
  70. liger_kernel/transformers/model/gemma.py +261 -0
  71. liger_kernel/transformers/model/gemma2.py +283 -0
  72. liger_kernel/transformers/model/gemma3.py +332 -0
  73. liger_kernel/transformers/model/glm4.py +141 -0
  74. liger_kernel/transformers/model/glm4v.py +163 -0
  75. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  76. liger_kernel/transformers/model/gpt_oss.py +211 -0
  77. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  78. liger_kernel/transformers/model/internvl.py +157 -0
  79. liger_kernel/transformers/model/llama.py +221 -41
  80. liger_kernel/transformers/model/llama4.py +121 -0
  81. liger_kernel/transformers/model/llava.py +344 -0
  82. liger_kernel/transformers/model/loss_utils.py +95 -0
  83. liger_kernel/transformers/model/mistral.py +145 -0
  84. liger_kernel/transformers/model/mixtral.py +293 -0
  85. liger_kernel/transformers/model/mllama.py +269 -0
  86. liger_kernel/transformers/model/olmo2.py +141 -0
  87. liger_kernel/transformers/model/olmo3.py +142 -0
  88. liger_kernel/transformers/model/output_classes.py +147 -0
  89. liger_kernel/transformers/model/paligemma.py +433 -0
  90. liger_kernel/transformers/model/phi3.py +120 -0
  91. liger_kernel/transformers/model/qwen2.py +259 -0
  92. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  93. liger_kernel/transformers/model/qwen2_vl.py +159 -0
  94. liger_kernel/transformers/model/qwen3.py +136 -0
  95. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  96. liger_kernel/transformers/model/qwen3_next.py +146 -0
  97. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  98. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  99. liger_kernel/transformers/model/smollm3.py +199 -0
  100. liger_kernel/transformers/model/smolvlm.py +158 -0
  101. liger_kernel/transformers/monkey_patch.py +2816 -21
  102. liger_kernel/transformers/multi_token_attention.py +64 -0
  103. liger_kernel/transformers/poly_norm.py +42 -0
  104. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  105. liger_kernel/transformers/rms_norm.py +75 -5
  106. liger_kernel/transformers/rope.py +47 -3
  107. liger_kernel/transformers/softmax.py +12 -0
  108. liger_kernel/transformers/sparsemax.py +16 -0
  109. liger_kernel/transformers/swiglu.py +62 -6
  110. liger_kernel/transformers/tiled_mlp.py +133 -0
  111. liger_kernel/transformers/trainer/__init__.py +4 -0
  112. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  113. liger_kernel/transformers/trainer_integration.py +2 -45
  114. liger_kernel/transformers/tvd.py +13 -0
  115. liger_kernel/triton/__init__.py +1 -3
  116. liger_kernel/triton/monkey_patch.py +1 -5
  117. liger_kernel/utils.py +96 -0
  118. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/METADATA +447 -0
  119. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/NOTICE +58 -0
  120. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
  121. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +1 -1
  122. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/METADATA +0 -21
  123. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/NOTICE +0 -4
  124. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/RECORD +0 -27
  125. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
  126. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,390 @@
1
+ import operator
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from liger_kernel.ops.utils import calculate_settings
8
+ from liger_kernel.ops.utils import compare_version
9
+ from liger_kernel.ops.utils import ensure_contiguous
10
+ from liger_kernel.utils import get_npu_multi_processor_count
11
+ from liger_kernel.utils import is_npu_available
12
+
13
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
14
+ try:
15
+ from triton.language.extra.libdevice import rsqrt
16
+ except ModuleNotFoundError:
17
+ from triton.language.extra.cuda.libdevice import rsqrt
18
+ else:
19
+ from triton.language.math import rsqrt
20
+
21
+
22
@triton.jit
def _poly_norm_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    W_ptr,  # weight: [3] for [w0, w1, w2]
    B_ptr,  # bias: scalar
    RSTD_ptr,  # cache rstd for backward: shape (n_rows, 3)
    RSTD_row_stride,
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """
    PolyNorm forward for a single row (one program instance per row).

    PolyNorm formula:
        y = w0 * norm(x^3) + w1 * norm(x^2) + w2 * norm(x) + b
        where norm(u) = u / sqrt(mean(u^2) + eps)

    Intermediate math is done in float32. The x^3 branch squares x^3, i.e.
    forms x^6, which overflows float16 (max ~65504) once |x| exceeds ~6.35
    and loses most of its precision in bfloat16. The backward kernel already
    upcasts everything to float32, so this also keeps the cached rstd values
    consistent with how backward consumes them. The output is cast back to
    Y's element type when stored.

    The three reciprocal RMS values (rstd_3, rstd_2, rstd_1) are written to
    RSTD[row, 0..2] for reuse in the backward pass.

    Reference:
        1. https://github.com/BryceZhuo/PolyCom/
        2. https://arxiv.org/pdf/2411.03884
    """
    # int64 row index so row_idx * stride cannot overflow for large inputs.
    row_idx = tl.program_id(0).to(tl.int64)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # Advance base pointers to this program's row.
    Y_ptr += row_idx * Y_row_stride
    X_ptr += row_idx * X_row_stride
    RSTD_ptr += row_idx * RSTD_row_stride

    # Load input row and upcast (see docstring). Padded columns are 0, so they
    # do not contribute to any of the sums below.
    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)

    # Load weights and bias in float32 as well.
    w0 = tl.load(W_ptr + 0).to(tl.float32)
    w1 = tl.load(W_ptr + 1).to(tl.float32)
    w2 = tl.load(W_ptr + 2).to(tl.float32)
    b = tl.load(B_ptr).to(tl.float32)

    # Compute x^3, x^2, x
    X_pow3 = X_row * X_row * X_row
    X_pow2 = X_row * X_row
    X_pow1 = X_row

    # norm(x^3): u * rsqrt(mean(u^2) + eps); mean divides by the true n_cols.
    mean_square_3 = tl.sum(X_pow3 * X_pow3, axis=0) / n_cols
    rstd_3 = rsqrt(mean_square_3 + eps)
    norm_x3 = X_pow3 * rstd_3

    # norm(x^2)
    mean_square_2 = tl.sum(X_pow2 * X_pow2, axis=0) / n_cols
    rstd_2 = rsqrt(mean_square_2 + eps)
    norm_x2 = X_pow2 * rstd_2

    # norm(x)
    mean_square_1 = tl.sum(X_pow1 * X_pow1, axis=0) / n_cols
    rstd_1 = rsqrt(mean_square_1 + eps)
    norm_x1 = X_pow1 * rstd_1

    # Cache rstd values for backward (RSTD is allocated as float32 by the host).
    tl.store(RSTD_ptr + 0, rstd_3)
    tl.store(RSTD_ptr + 1, rstd_2)
    tl.store(RSTD_ptr + 2, rstd_1)

    # y = w0 * norm(x^3) + w1 * norm(x^2) + w2 * norm(x) + b
    Y_row = w0 * norm_x3 + w1 * norm_x2 + w2 * norm_x1 + b

    # Cast back to the output's storage dtype.
    tl.store(Y_ptr + col_offsets, Y_row.to(Y_ptr.dtype.element_ty), mask=mask)
95
+
96
+
97
@triton.jit
def _poly_norm_backward_kernel(
    dY_ptr,
    dY_row_stride,
    dX_ptr,
    dX_row_stride,
    X_ptr,
    X_row_stride,
    W_ptr,
    RSTD_ptr,
    RSTD_row_stride,
    dW_ptr,  # shape: (n_programs, 3) -- one partial dW row per program
    dW_row_stride,
    dB_ptr,  # shape: (n_programs,) -- one partial dB per program
    n_rows,
    n_cols,
    rows_per_program: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    PolyNorm backward kernel.

    Each program handles the contiguous row range
    [row_block_id * rows_per_program, min((row_block_id + 1) * rows_per_program, n_rows)).
    It writes dX for those rows and stores its partial dW (3 values) and dB
    (1 value) into row `row_block_id` of dW_ptr / dB_ptr. The host sums the
    partials afterwards.

    Closed-form input gradient, for p in {3, 2, 1} with weights (w0, w1, w2):
        dL/dx_i = sum_p w_p * [ p * x_i^(p-1) * g_i * r_p
                                - (p / d) * x_i^(2p-1) * r_p^3 * S_p ]
    where:
        - r_p = rstd_p = 1 / sqrt(mean(x^(2p)) + eps), cached by the forward pass
        - S_p = sum_j(g_j * x_j^p) over the row
        - g   = dY row
        - d   = n_cols

    Parameter gradients:
        dW_p = sum over rows of r_p * S_p
        dB   = sum over rows and columns of g

    All arithmetic is performed in float32.

    NOTE(review): rows_per_program is a tl.constexpr, so each distinct value
    (it depends on n_rows) triggers a separate Triton compilation.
    """
    row_block_id = tl.program_id(0).to(tl.int64)
    row_start = row_block_id * rows_per_program
    # Clamp the last program's range to n_rows; programs whose row_start is
    # past n_rows run zero iterations and store zero partials.
    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # Per-program scalar accumulators for weight and bias gradients
    dW0_acc = 0.0
    dW1_acc = 0.0
    dW2_acc = 0.0
    dB_acc = 0.0

    # Load weights once per program
    w0 = tl.load(W_ptr + 0).to(tl.float32)
    w1 = tl.load(W_ptr + 1).to(tl.float32)
    w2 = tl.load(W_ptr + 2).to(tl.float32)

    # Advance base pointers to the first row owned by this program
    dY_ptr += row_start * dY_row_stride
    dX_ptr += row_start * dX_row_stride
    X_ptr += row_start * X_row_stride
    RSTD_ptr += row_start * RSTD_row_stride

    for _ in range(row_start, row_end):
        # Load upstream gradient and input row (masked lanes read as 0)
        dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
        X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)

        # Cached reciprocal RMS values written by the forward kernel
        rstd_3 = tl.load(RSTD_ptr + 0).to(tl.float32)
        rstd_2 = tl.load(RSTD_ptr + 1).to(tl.float32)
        rstd_1 = tl.load(RSTD_ptr + 2).to(tl.float32)

        # Powers of x
        X_pow3 = X_row * X_row * X_row
        X_pow2 = X_row * X_row
        X_pow1 = X_row

        # dB = sum(dY)
        dB_acc += tl.sum(dY_row, axis=0)

        # p = 3 term: contribution of w0 * norm(x^3) (x^5 written out explicitly)
        S_3 = tl.sum(dY_row * X_pow3, axis=0)  # scalar
        grad_x_3 = w0 * (
            3.0 * X_pow2 * rstd_3 * dY_row
            - (3.0 / n_cols) * X_row * X_row * X_row * X_row * X_row * (rstd_3 * rstd_3 * rstd_3) * S_3
        )

        # p = 2 term: contribution of w1 * norm(x^2)
        S_2 = tl.sum(dY_row * X_pow2, axis=0)  # scalar
        grad_x_2 = w1 * (
            2.0 * X_row * rstd_2 * dY_row - (2.0 / n_cols) * X_row * X_row * X_row * (rstd_2 * rstd_2 * rstd_2) * S_2
        )

        # p = 1 term: contribution of w2 * norm(x)
        S_1 = tl.sum(dY_row * X_pow1, axis=0)  # scalar
        grad_x_1 = w2 * (1.0 * rstd_1 * dY_row - (1.0 / n_cols) * X_row * (rstd_1 * rstd_1 * rstd_1) * S_1)

        # dW_p = sum_i g_i * norm_p(x)_i = rstd_p * S_p
        dW0_acc += rstd_3 * S_3
        dW1_acc += rstd_2 * S_2
        dW2_acc += rstd_1 * S_1

        # Total input gradient
        dX_row = grad_x_3 + grad_x_2 + grad_x_1

        # Store dX; when the host passes dX == dY (in-place), this overwrites the
        # dY row that was already loaded above.
        tl.store(dX_ptr + col_offsets, dX_row, mask=mask)

        # Move to the next row
        dY_ptr += dY_row_stride
        dX_ptr += dX_row_stride
        X_ptr += X_row_stride
        RSTD_ptr += RSTD_row_stride

    # Write this program's partial parameter gradients
    tl.store(dW_ptr + row_block_id * dW_row_stride + 0, dW0_acc)
    tl.store(dW_ptr + row_block_id * dW_row_stride + 1, dW1_acc)
    tl.store(dW_ptr + row_block_id * dW_row_stride + 2, dW2_acc)
    tl.store(dB_ptr + row_block_id, dB_acc)
206
+
207
+
208
def poly_norm_forward(X, W, B, eps=1e-6):
    """
    PolyNorm forward pass.

    Args:
        X: input tensor of shape (*, H), where H is the hidden dimension.
            It is viewed as (n_rows, H), so it must be viewable that way
            (e.g. contiguous).
        W: weight tensor with exactly 3 elements [w0, w1, w2].
        B: bias tensor with exactly 1 element.
        eps: epsilon for numerical stability.

    Returns:
        Y: output tensor with the same shape as X.
        X: the input reshaped to (n_rows, H) (for backward).
        RSTD: float32 tensor of shape (n_rows, 3) with cached rstd values
            (for backward).
        BLOCK_SIZE: block size used.
        num_warps: number of warps used.

    Raises:
        ValueError: if W does not have exactly 3 elements or B does not have
            exactly 1 element. This is checked before any allocation or
            kernel launch. Explicit exceptions are used instead of `assert`
            so the check survives `python -O`.
    """
    # Validate first. The kernel reads W[0..2] and B[0] unconditionally, so a
    # (3, k) weight would otherwise pass a shape[0] check and be misread.
    if W.numel() != 3:
        raise ValueError("Weight tensor must have shape (3,)")
    if B.numel() != 1:
        raise ValueError("Bias must be a scalar")

    shape = X.shape
    dim = shape[-1]
    X = X.view(-1, dim)
    n_rows, n_cols = X.shape
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
    # RSTD caches (rstd_3, rstd_2, rstd_1) for each row, always in float32.
    RSTD = torch.empty((n_rows, 3), dtype=torch.float32, device=X.device)

    # XPU-specific optimization
    kernel_args = {"grf_mode": "large"} if X.device.type == "xpu" else {}

    # One program per row.
    _poly_norm_forward_kernel[(n_rows,)](
        Y,
        Y.stride(0),
        X,
        X.stride(0),
        W,
        B,
        RSTD,
        RSTD.stride(0),
        n_cols,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        **kernel_args,
    )

    return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps
262
+
263
+
264
def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
    """
    PolyNorm backward pass.

    Args:
        dY: gradient of the output, shape (*, H).
        X: forward input, already reshaped to 2D (n_rows, H).
        W: weight tensor (3 elements).
        RSTD: float32 (n_rows, 3) rstd values cached by the forward pass.
        BLOCK_SIZE: block size used by the forward pass.
        num_warps: number of warps used by the forward pass.
        in_place: if truthy, dX is written into dY's storage (saves memory).

    Returns:
        dX: gradient w.r.t. the input, same shape as dY.
        dW: gradient w.r.t. the weight, shape (3,), in W's dtype.
        dB: gradient w.r.t. the bias as a 0-dim tensor in W's dtype.
    """
    import math

    shape = dY.shape
    dim = shape[-1]
    dY = dY.view(-1, dim)
    n_rows, n_cols = dY.shape

    # Launch one program per SM / EU / NPU core (1 on other devices). Each
    # program walks a contiguous run of rows and writes one partial dW/dB row.
    device_type = X.device.type
    if device_type == "cuda":
        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
    elif device_type == "xpu":
        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
    elif device_type == "npu":
        sm_count = get_npu_multi_processor_count()
    else:
        sm_count = 1

    # The kernel stores every column (< n_cols) of every row in [0, n_rows),
    # so an uninitialized buffer is sufficient; no zero-fill is needed.
    dX = dY if in_place else torch.empty_like(dY)

    # Each program writes its own slot, so these need no initialization either.
    _dW = torch.empty((sm_count, 3), dtype=torch.float32, device=W.device)
    _dB = torch.empty((sm_count,), dtype=torch.float32, device=W.device)

    rows_per_program = math.ceil(n_rows / sm_count)
    grid = (sm_count,)

    # XPU-specific optimization
    kernel_args = {"grf_mode": "large"} if device_type == "xpu" else {}

    _poly_norm_backward_kernel[grid](
        dY,
        dY.stride(0),
        dX,
        dX.stride(0),
        X,
        X.stride(0),
        W,
        RSTD,
        RSTD.stride(0),
        _dW,
        _dW.stride(0),
        _dB,
        n_rows,
        n_cols,
        rows_per_program,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        **kernel_args,
    )

    # Reduce per-program partial gradients.
    dX = dX.view(*shape)
    dW = _dW.sum(dim=0).to(W.dtype)
    dB = _dB.sum().to(W.dtype)

    return dX, dW, dB
343
+
344
+
345
class LigerPolyNormFunction(torch.autograd.Function):
    """
    PolyNorm autograd Function backed by Triton forward/backward kernels.

    PolyNorm formula:
        y = w0 * norm(x^3) + w1 * norm(x^2) + w2 * norm(x) + b
        where norm(u) = u / sqrt(mean(u^2) + eps)

    Backward uses the closed-form gradient:
        dL/dx_i = sum_p w_p * [p * x_i^(p-1) * g_i / D_p - (p / d) * x_i^(2p-1) * S_p / D_p^3]
    with D_p = RMS(x^p), S_p = sum(g * x^p) over the row, and d = hidden size.
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, X, W, B, eps=1e-6, in_place=True):
        """
        Args:
            X: input tensor of shape (B, T, H) or (BxT, H).
            W: weight tensor with 3 elements [w0, w1, w2].
            B: bias tensor with exactly one element (0-dim or shape (1,), etc.).
            eps: epsilon for numerical stability.
            in_place: whether backward may overwrite grad_output with dX
                (saves memory).

        Returns:
            Y: output tensor with the same shape as X.
        """
        Y, X, RSTD, BLOCK_SIZE, num_warps = poly_norm_forward(X, W, B, eps)
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.in_place = in_place
        # poly_norm_backward returns dB as a 0-dim tensor in W's dtype, but any
        # one-element bias is accepted here. Autograd rejects a () gradient for
        # a (1,)-shaped input, so remember the bias' shape/dtype to match it.
        ctx.bias_shape = B.shape
        ctx.bias_dtype = B.dtype
        ctx.save_for_backward(X, W, RSTD)
        return Y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output):
        """
        Args:
            grad_output: gradient of the output.

        Returns:
            dX, dW, dB: gradients w.r.t. X, W, B (dB shaped like B), then
            None for eps and in_place.
        """
        X, W, RSTD = ctx.saved_tensors
        dX, dW, dB = poly_norm_backward(grad_output, X, W, RSTD, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place)
        dB = dB.reshape(ctx.bias_shape).to(ctx.bias_dtype)
        return dX, dW, dB, None, None
@@ -0,0 +1,222 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+
5
+
6
@triton.jit
def _triton_qwen2vl_mrope(
    q_ptr,
    k_ptr,
    cos,
    sin,
    sl,
    bs: tl.constexpr,
    n_qh: tl.constexpr,
    n_kh: tl.constexpr,
    hd: tl.constexpr,
    pad_n_qh: tl.constexpr,
    pad_n_kh: tl.constexpr,
    pad_hd: tl.constexpr,
    mrope_section_t: tl.constexpr,
    mrope_section_h: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BACKWARD_PASS: tl.constexpr = False,
):
    """
    In-place multimodal rotary embedding (M-RoPE) for one token.

    Each program instance handles one token: its q row (n_qh * hd contiguous
    elements) and k row (n_kh * hd contiguous elements). The offsets assume
    q/k are laid out as contiguous (bsz, seq_len, n_head, hd) and cos/sin as
    contiguous (3, bsz, seq_len, hd), with the three sections (t, h, w)
    bs * sl * hd elements apart.

    The first half of the head dimension is split into three sections:
    [0, t), [t, t + h) and [t + h, hd // 2). These take their cos/sin values
    from the t, h and w tables respectively. Only the left half of cos/sin is
    read; the right half is assumed to duplicate it.

    Forward:  y1 = x1 * cos - x2 * sin,  y2 = x2 * cos + x1 * sin
    Backward: the transposed rotation (sign of the sin terms flipped).

    BLOCK_SIZE is accepted but not read in the body.
    """
    # int64 so the pid * (n_heads * hd) element offsets below cannot overflow
    # int32 for very large batch * seq_len * heads * head_dim.
    pid = tl.program_id(0).to(tl.int64)

    # locate start address of this token's q and k rows
    q_ptr = q_ptr + pid * (n_qh * hd)
    k_ptr = k_ptr + pid * (n_kh * hd)

    # ####################################################################
    # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
    # m of this program instance
    # ####################################################################

    # 1. Program instances form a 1D grid of size bsz * seq_len that flattens
    #    [bsz, seq_len] with seq_len varying fastest. Within one (bsz, seq_len,
    #    hd) section of cos/sin this token's row therefore starts at pid * hd,
    #    so the batch/sequence indices never need to be computed separately.
    # 2. We only need the left half of cos and sin because the right half is
    #    just a clone of the left half.
    # NOTE(review): the section stride bs * sl * hd is still computed at the
    # integer width of `sl`.
    t_end = mrope_section_t
    h_end = t_end + mrope_section_h

    t_cos = cos + pid * hd
    h_cos = t_cos + bs * sl * hd
    w_cos = h_cos + bs * sl * hd
    t_sin = sin + pid * hd
    h_sin = t_sin + bs * sl * hd
    w_sin = h_sin + bs * sl * hd

    # Each column of the left half is covered by exactly one section mask;
    # the masked-out loads are 0, so summing selects the right section.
    cos_offsets = tl.arange(0, pad_hd // 2)
    t_mask = cos_offsets < t_end
    h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
    w_mask = (h_end <= cos_offsets) & (cos_offsets < hd // 2)
    t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
    h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
    w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0)
    t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0)
    h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0)
    w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0)
    cos_row = t_cos_row + h_cos_row + w_cos_row
    sin_row = t_sin_row + h_sin_row + w_sin_row

    # ####################################################################
    # Load the left and right half of q and k for the current
    # program instance (i.e. for the current token) separately
    # ####################################################################
    # left half of the head
    first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)

    # right half of the head
    second_half_q_offsets = first_half_q_offsets + (hd // 2)
    second_half_k_offsets = first_half_k_offsets + (hd // 2)
    second_q_mask = first_q_mask
    second_k_mask = first_k_mask
    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)

    if not BACKWARD_PASS:
        # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
    else:
        # with some math, we can get:
        # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin]
        new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row
        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
        new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row
        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

        new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row
        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
        new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row
        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
108
+
109
+
110
def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
    """
    Apply M-RoPE to q and k with the Triton kernel (forward direction).

    Args:
        q: (bsz, n_q_head, seq_len, head_dim) query tensor.
        k: (bsz, n_kv_head, seq_len, head_dim) key tensor.
        cos, sin: rotary tables; made contiguous before the launch.
        mrope_section: sequence whose first two entries are the temporal and
            height section sizes.

    Returns:
        (q, k, cos, sin): rotated q/k transposed back to
        (bsz, n_head, seq_len, head_dim), plus the contiguous cos/sin that
        were passed to the kernel.
    """
    # Triton addresses raw storage, so go back to the physical
    # (bsz, seq_len, n_head, head_dim) layout. The original author notes the
    # inputs are expected to be transposed views that become contiguous here.
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)

    bsz, seq_len, n_q_head, head_dim = q.shape
    n_kv_head = k.shape[2]

    padded_q_heads = triton.next_power_of_2(n_q_head)
    padded_kv_heads = triton.next_power_of_2(n_kv_head)
    padded_head_dim = triton.next_power_of_2(head_dim)
    block = max(padded_q_heads, padded_kv_heads)

    # .contiguous() returns the same tensor when it is already contiguous.
    q, k, cos, sin = (t.contiguous() for t in (q, k, cos, sin))

    num_tokens = bsz * seq_len
    _triton_qwen2vl_mrope[(num_tokens,)](
        q,
        k,
        cos,
        sin,
        seq_len,
        bsz,
        n_q_head,
        n_kv_head,
        head_dim,
        padded_q_heads,
        padded_kv_heads,
        padded_head_dim,
        mrope_section[0],
        mrope_section[1],
        BLOCK_SIZE=block,
        BACKWARD_PASS=False,
    )
    return q.transpose(1, 2), k.transpose(1, 2), cos, sin
150
+
151
+
152
def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
    """
    Back-propagate through M-RoPE: apply the inverse rotation to dq and dk.

    Args:
        dq: (bsz, n_q_head, seq_len, head_dim) gradient w.r.t. rotated q.
        dk: (bsz, n_kv_head, seq_len, head_dim) gradient w.r.t. rotated k.
        cos, sin: the tables saved by the forward pass (passed to the kernel
            as-is; not made contiguous here).
        mrope_section: sequence whose first two entries are the temporal and
            height section sizes.

    Returns:
        (dq, dk) transposed back to (bsz, n_head, seq_len, head_dim).
    """
    grad_q = dq.transpose(1, 2)
    grad_k = dk.transpose(1, 2)

    bsz, seq_len, n_q_head, head_dim = grad_q.shape
    n_kv_head = grad_k.shape[2]

    padded_q_heads = triton.next_power_of_2(n_q_head)
    padded_kv_heads = triton.next_power_of_2(n_kv_head)
    padded_head_dim = triton.next_power_of_2(head_dim)
    block = max(padded_q_heads, padded_kv_heads)

    # The kernel rotates in place, on contiguous storage.
    grad_q = grad_q.contiguous()
    grad_k = grad_k.contiguous()

    # Same kernel as forward; BACKWARD_PASS flips the sign of the sin terms.
    num_tokens = bsz * seq_len
    _triton_qwen2vl_mrope[(num_tokens,)](
        grad_q,
        grad_k,
        cos,
        sin,
        seq_len,
        bsz,
        n_q_head,
        n_kv_head,
        head_dim,
        padded_q_heads,
        padded_kv_heads,
        padded_head_dim,
        mrope_section[0],
        mrope_section[1],
        BLOCK_SIZE=block,
        BACKWARD_PASS=True,
    )
    return grad_q.transpose(1, 2), grad_k.transpose(1, 2)
189
+
190
+
191
class LigerQwen2VLMRopeFunction(torch.autograd.Function):
    """
    Triton implementation of the Qwen2VL Multimodal Rotary Positional Embedding (M-RoPE) operation.

    Please find the corresponding HuggingFace implementation here:
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
    """

    @staticmethod
    def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1):
        """
        q size: (bsz, n_q_head, seq_len, head_dim)
        k size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (3, bsz, seq_len, head_dim)
        sin size: (3, bsz, seq_len, head_dim)

        unsqueeze_dim is accepted for signature compatibility but not used.
        """
        q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
        ctx.save_for_backward(cos, sin)
        ctx.mrope_section = mrope_section
        return q, k

    # autograd.Function requires backward to be a staticmethod, matching forward.
    @staticmethod
    def backward(ctx, dq, dk):
        """
        dq size: (bsz, n_q_head, seq_len, head_dim)
        dk size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (3, bsz, seq_len, head_dim)
        sin size: (3, bsz, seq_len, head_dim)

        Returns gradients for (q, k) and None for cos, sin, mrope_section
        and unsqueeze_dim.
        """
        cos, sin = ctx.saved_tensors
        mrope_section = ctx.mrope_section
        dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
        return dq, dk, None, None, None, None