liger-kernel 0.6.0__py3-none-any.whl → 0.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,3 @@
-import math
 import operator
 
 import torch
@@ -43,30 +42,45 @@ def _layer_norm_forward_kernel(
     https://arxiv.org/abs/1607.06450
     https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
     """
-    row_idx = tl.program_id(0)
+    row_idx = tl.program_id(0).to(tl.int64)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
-    Y_ptr += row_idx * Y_row_stride
-    X_ptr += row_idx * X_row_stride
-    Mean_ptr += row_idx * Mean_row_stride
-    RSTD_ptr += row_idx * RSTD_row_stride
-
-    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
-    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
-    B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)
-
-    mean = tl.sum(X_row, axis=0) / n_cols
-    Xmm = tl.where(mask, X_row - mean, 0)
-    var = tl.sum(Xmm * Xmm, axis=0) / n_cols
+    # Pre-load weights and bias in fp32 to avoid repeated conversions
+    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+    B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0.0)
+    W_f32 = W_row.to(tl.float32)
+    B_f32 = B_row.to(tl.float32)
+
+    # Calculate pointers for this row
+    row_X_ptr = X_ptr + row_idx * X_row_stride
+    row_Y_ptr = Y_ptr + row_idx * Y_row_stride
+    row_Mean_ptr = Mean_ptr + row_idx * Mean_row_stride
+    row_RSTD_ptr = RSTD_ptr + row_idx * RSTD_row_stride
+
+    # Load input data and convert to fp32 for numerical stability
+    X_row = tl.load(row_X_ptr + col_offsets, mask=mask, other=0.0)
+    X_f32 = X_row.to(tl.float32)
+
+    # Compute statistics in fp32 for numerical stability
+    n_cols_f32 = n_cols.to(tl.float32)
+    mean = tl.sum(X_f32, axis=0) / n_cols_f32
+    X_centered = X_f32 - mean
+    # Apply mask to variance calculation to exclude contributions from masked elements
+    X_centered_masked = tl.where(mask, X_centered, 0.0)
+    var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols_f32
     rstd = rsqrt(var + eps)
 
-    tl.store(Mean_ptr, mean)
-    tl.store(RSTD_ptr, rstd)
+    # Store statistics (convert back to original dtype only once)
+    tl.store(row_Mean_ptr, mean.to(X_row.dtype))
+    tl.store(row_RSTD_ptr, rstd.to(X_row.dtype))
 
-    Y_row = Xmm * rstd * W_row + B_row
+    # Fused normalization and affine transformation
+    # Y = (X - mean) * rstd * W + B = X_centered * rstd * W + B
+    Y_f32 = X_centered * rstd * W_f32 + B_f32
 
-    tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
+    # Store output (single conversion back to original dtype)
+    tl.store(row_Y_ptr + col_offsets, Y_f32.to(X_row.dtype), mask=mask)
 
 
 @triton.jit
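The rewritten forward kernel assigns one program to one row, performs every reduction in fp32, and converts back to the input dtype exactly once per store. For readers who want to check the arithmetic outside Triton, here is a minimal per-row sketch in plain PyTorch; the function name is illustrative and not part of the package:

import torch

def layer_norm_row_reference(x_row, w, b, eps):
    # Mirror the kernel: statistics in fp32, a single cast back on output.
    x = x_row.float()
    mean = x.mean()                 # tl.sum(X_f32, axis=0) / n_cols_f32
    var = ((x - mean) ** 2).mean()  # population variance, same 1/n_cols divisor
    rstd = torch.rsqrt(var + eps)   # rsqrt(var + eps)
    y = (x - mean) * rstd * w.float() + b.float()
    return y.to(x_row.dtype), mean.to(x_row.dtype), rstd.to(x_row.dtype)

Applied row by row to a 2-D input, this agrees with torch.nn.functional.layer_norm up to dtype rounding, which makes it a convenient cross-check for the fp32 accumulation strategy above.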
@@ -81,73 +95,87 @@ def _layer_norm_backward_kernel(
     DY_ptr,  # pointer to output grad, shape (n_rows, n_cols)
     stride_x,  # stride of each row in input
     stride_dx,  # stride of each row in input grad
-    stride_dw,  # stride of each row in weights grad
-    stride_db,  # stride of each row in bias grad
     stride_dy,  # stride of each row in output grad
-    n_rows,
     n_cols,
-    rows_per_program: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     dtype: tl.constexpr,
+    atomic_dtype: tl.constexpr,
 ):
     """
     References:
     https://arxiv.org/abs/1607.06450
     https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
-    https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
-    https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
     """
-    row_block_id = tl.program_id(0)
-    row_start = row_block_id * rows_per_program
-    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
+    row_idx = tl.program_id(0).to(tl.int64)
     cols = tl.arange(0, BLOCK_SIZE)
     mask = cols < n_cols
 
-    dw_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
-    db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
-
-    X_ptr += row_start * stride_x
-    Mean_ptr += row_start
-    RSTD_ptr += row_start
-    DX_ptr += row_start * stride_dx
-    DY_ptr += row_start * stride_dy
-
-    for _ in range(row_start, row_end):
-        x = tl.load(X_ptr + cols, mask=mask, other=0.0)
-        w = tl.load(W_ptr + cols, mask=mask, other=0.0)
-        dy = tl.load(DY_ptr + cols, mask=mask, other=0.0)
-        mean = tl.load(Mean_ptr)
-        rstd = tl.load(RSTD_ptr)
-
-        x_hat = (x - mean) * rstd
-        wdy = w * dy
-        c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
-        c2 = tl.sum(wdy, axis=0) / n_cols
-        dx = (wdy - (x_hat * c1 + c2)) * rstd
-        tl.store(DX_ptr + cols, dx.to(dtype), mask=mask)
-
-        dw_row += dy * x_hat
-        db_row += dy
-
-        X_ptr += stride_x
-        Mean_ptr += 1
-        RSTD_ptr += 1
-        DX_ptr += stride_dx
-        DY_ptr += stride_dy
-
-    tl.store(DW_ptr + row_block_id * stride_dw + cols, dw_row.to(dtype), mask=mask)
-    tl.store(DB_ptr + row_block_id * stride_db + cols, db_row.to(dtype), mask=mask)
+    # Pre-load weights once (same optimization as forward pass)
+    w = tl.load(W_ptr + cols, mask=mask, other=0.0)
+    w_f32 = w.to(tl.float32)
+    n_cols_f32 = n_cols.to(tl.float32)
+
+    # Calculate pointers for this specific row
+    row_X_ptr = X_ptr + row_idx * stride_x
+    row_DX_ptr = DX_ptr + row_idx * stride_dx
+    row_DY_ptr = DY_ptr + row_idx * stride_dy
+    row_Mean_ptr = Mean_ptr + row_idx
+    row_RSTD_ptr = RSTD_ptr + row_idx
+
+    # Load data for this row
+    x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
+    dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
+    mean = tl.load(row_Mean_ptr)
+    rstd = tl.load(row_RSTD_ptr)
+
+    # Convert to fp32 for numerical stability
+    x_f32 = x.to(tl.float32)
+    dy_f32 = dy.to(tl.float32)
+    mean_f32 = mean.to(tl.float32)
+    rstd_f32 = rstd.to(tl.float32)
+
+    # Compute backward pass for this row
+    x_hat = (x_f32 - mean_f32) * rstd_f32
+    wdy = w_f32 * dy_f32
+    c1 = tl.sum(x_hat * wdy, axis=0) / n_cols_f32
+    c2 = tl.sum(wdy, axis=0) / n_cols_f32
+    dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
+
+    # Store input gradient
+    tl.store(row_DX_ptr + cols, dx.to(dtype), mask=mask)
+
+    # Accumulate weight and bias gradients using atomic operations
+    dw = dy_f32 * x_hat
+    db = dy_f32
+    tl.atomic_add(DW_ptr + cols, dw.to(atomic_dtype), mask=mask)
+    tl.atomic_add(DB_ptr + cols, db.to(atomic_dtype), mask=mask)
 
 
 def layer_norm_forward(X, W, B, eps):
+    """
+    Args:
+        X: Input tensor of shape (..., hidden_size)
+        W: Weight tensor of shape (hidden_size,)
+        B: Bias tensor of shape (hidden_size,)
+        eps: Small constant for numerical stability
+
+    Returns:
+        Tuple of (output, input, mean, rstd, block_size, num_warps)
+    """
     shape = X.shape
     dim = shape[-1]
     X = X.view(-1, dim)
     n_rows, n_cols = X.shape
+
+    # Calculate optimal block size and warp configuration
     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+    # Allocate output tensors
     Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
     Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
     RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
+
+    # Validate input dimensions
     if X.shape[1] != W.shape[0]:
         raise ValueError(
             f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
@@ -159,7 +187,9 @@ def layer_norm_forward(X, W, B, eps):
     if X.device.type == "xpu":
         kernel_args["grf_mode"] = "large"
 
-    _layer_norm_forward_kernel[(n_rows,)](
+    # Launch kernel with one thread block per row for optimal performance
+    grid = (n_rows,)
+    _layer_norm_forward_kernel[grid](
         Y,
         Y.stride(0),
         X,
@@ -176,35 +206,43 @@ def layer_norm_forward(X, W, B, eps):
         eps,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
-        **kernel_args,  # XPU-specific optimization
+        **kernel_args,
     )
+
     return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
 
 
 def layer_norm_backward(dY, X, W, B, Mean, RSTD):
+    """
+    Args:
+        dY: Gradient of output
+        X: Input tensor
+        W: Weight tensor
+        B: Bias tensor
+        Mean: Pre-computed mean
+        RSTD: Pre-computed reciprocal standard deviation
+
+    Returns:
+        Tuple of (input_grad, weight_grad, bias_grad)
+    """
     shape = dY.shape
     dim = shape[-1]
     dY = dY.view(-1, dim)
     n_rows, n_cols = dY.shape
 
-    sm_count = 1
-    if X.device.type == "cuda":
-        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
-    elif X.device.type == "xpu":
-        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
-
+    # Allocate gradient tensors
     DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
-    _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
-    _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
+    # Use float32 for weight/bias gradients if bfloat16 (due to atomic_add limitation)
+    grad_dtype = torch.float32 if W.dtype == torch.bfloat16 else W.dtype
+    DW = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
+    DB = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
 
+    # Calculate optimal block size and warp configuration
     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
     if n_cols > BLOCK_SIZE:
-        raise RuntimeError(
-            f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
-        )
+        raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
 
-    rows_per_program = math.ceil(n_rows / sm_count)
-    grid = (sm_count,)
+    # Determine dtype for triton operations
     triton_dtype = (
         tl.float32
         if X.dtype == torch.float32
@@ -212,41 +250,40 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
         if X.dtype == torch.bfloat16
         else tl.float16
         if X.dtype == torch.float16
-        else tl.float32  # fallback to float32 for other types
+        else tl.float32  # fallback
     )
 
+    # Use float32 for atomic operations if bfloat16 is not supported
+    atomic_dtype = tl.float32 if triton_dtype == tl.bfloat16 else triton_dtype
+
+    kernel_args = {"num_warps": num_warps}
     # XPU-specific optimization
-    kernel_args = {}
     if X.device.type == "xpu":
         kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
 
+    # Launch kernel with one thread block per row for optimal performance
+    grid = (n_rows,)
     _layer_norm_backward_kernel[grid](
         X,
         W,
         Mean,
         RSTD,
         DX,
-        _DW,
-        _DB,
+        DW,
+        DB,
         dY,
         X.stride(0),
         DX.stride(0),
-        _DW.stride(0),
-        _DB.stride(0),
         dY.stride(0),
-        n_rows,
         n_cols,
-        rows_per_program,
         BLOCK_SIZE=BLOCK_SIZE,
         dtype=triton_dtype,
-        **kernel_args,  # XPU-specific optimization
+        atomic_dtype=atomic_dtype,
+        **kernel_args,
     )
 
-    DW = _DW.sum(dim=0).to(W.dtype)
-    DB = _DB.sum(dim=0).to(W.dtype)
-
     DX = DX.view(*shape)
-    return DX, DW, DB
+    return DX, DW.to(W.dtype), DB.to(W.dtype)
 
 
 class LigerLayerNormFunction(torch.autograd.Function):
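The backward kernel drops the per-SM partial buffers (_DW, _DB) and the rows_per_program loop: each program now handles exactly one row, writes dx directly, and folds its weight and bias contributions into single DW/DB vectors via tl.atomic_add (kept in fp32 when the parameters are bf16, since bf16 atomics are not supported). The per-row math is the standard LayerNorm backward; a hedged PyTorch sketch, with an illustrative function name:

import torch

def layer_norm_backward_row_reference(dy, x, w, mean, rstd):
    # One row, all in fp32, mirroring the kernel.
    n_cols = x.numel()
    x_hat = (x.float() - mean) * rstd
    wdy = w.float() * dy.float()
    c1 = (x_hat * wdy).sum() / n_cols
    c2 = wdy.sum() / n_cols
    dx = (wdy - (x_hat * c1 + c2)) * rstd
    dw = dy.float() * x_hat  # per-row term accumulated into DW by tl.atomic_add
    db = dy.float()          # per-row term accumulated into DB by tl.atomic_add
    return dx, dw, db

Summing dw and db over all rows reproduces the DW/DB that layer_norm_backward returns after the final cast back to W.dtype.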
@@ -0,0 +1,225 @@
+import torch
+import triton
+import triton.language as tl
+
+
+def _prepare_freqs(freqs_cis: torch.Tensor, seq_len: int, head_dim_half: int):
+    # Split or unpack complex frequencies into real and imag parts
+    if freqs_cis.is_complex():
+        freqs_real = freqs_cis.real
+        freqs_imag = freqs_cis.imag
+    else:
+        # Already split: last dim should be 2*head_dim_half
+        if freqs_cis.shape[-1] == 2 * head_dim_half:
+            freqs_real = freqs_cis[..., :head_dim_half]
+            freqs_imag = freqs_cis[..., head_dim_half:]
+        else:
+            raise ValueError(
+                f"Unexpected freqs_cis shape for non-complex input: {freqs_cis.shape}, expected last dim = {2 * head_dim_half}"
+            )
+
+    # Canonicalize to shape (seq_len, head_dim_half):
+    # 1) Ensure the last dimension is head_dim_half
+    if freqs_real.shape[-1] != head_dim_half:
+        raise ValueError(f"Unexpected last dim for freqs: {freqs_real.shape[-1]} (expected {head_dim_half})")
+    # 2) Flatten all leading dims to a single row dimension
+    freqs_real = freqs_real.reshape(-1, head_dim_half)
+    freqs_imag = freqs_imag.reshape(-1, head_dim_half)
+    # 3) If we have fewer rows than seq_len, allow broadcasting when single row
+    if freqs_real.shape[0] < seq_len:
+        if freqs_real.shape[0] == 1:
+            freqs_real = freqs_real.expand(seq_len, -1)
+            freqs_imag = freqs_imag.expand(seq_len, -1)
+        else:
+            raise ValueError(f"Insufficient rows in freqs: {freqs_real.shape[0]} < seq_len={seq_len}")
+    # 4) If we have more rows than seq_len (e.g., batch present), take the first seq_len rows
+    elif freqs_real.shape[0] > seq_len:
+        freqs_real = freqs_real[:seq_len]
+        freqs_imag = freqs_imag[:seq_len]
+
+    return freqs_real, freqs_imag
+
+
+def _maybe_to_dtype(t: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    return t if t.dtype == dtype else t.to(dtype)
+
+
+def _maybe_contiguous(t: torch.Tensor) -> torch.Tensor:
+    return t if t.is_contiguous() else t.contiguous()
+
+
+def _cast_and_contiguous(q, k, freqs_real, freqs_imag):
+    # Choose compute dtype: use fp32 only when inputs are fp32; otherwise keep input dtype for performance
+    compute_dtype = torch.float32 if q.dtype == torch.float32 else q.dtype
+
+    # Make sure q/k share the same dtype before casting to compute dtype
+    if k.dtype != q.dtype:
+        k = k.to(q.dtype)
+
+    q = _maybe_contiguous(_maybe_to_dtype(q, compute_dtype))
+    k = _maybe_contiguous(_maybe_to_dtype(k, compute_dtype))
+    freqs_real = _maybe_contiguous(_maybe_to_dtype(freqs_real, compute_dtype))
+    freqs_imag = _maybe_contiguous(_maybe_to_dtype(freqs_imag, compute_dtype))
+    return q, k, freqs_real, freqs_imag
+
+
+@triton.jit
+def _llama4_rope_kernel(
+    q_ptr,
+    k_ptr,
+    freqs_real_ptr,
+    freqs_imag_ptr,
+    q_row_stride,
+    k_row_stride,
+    q_head_stride,
+    k_head_stride,
+    freqs_row_stride,
+    seq_len,
+    batch_size,
+    imag_sign,
+    head_dim_half: tl.constexpr,
+    n_q_heads: tl.constexpr,
+    n_k_heads: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    H100-optimized RoPE kernel with improved parallelization across heads and dimensions.
+    Grid: (batch*seq, head)
+    """
+    # 2D grid
+    pid_bs = tl.program_id(0)  # over batch*seq
+    pid_h = tl.program_id(1)  # over heads
+
+    batch_idx = pid_bs // seq_len
+    seq_idx = pid_bs % seq_len
+
+    # Bounds check
+    if batch_idx >= batch_size or seq_idx >= seq_len:
+        return
+
+    # Base pointers for this (batch, seq) position
+    base_offset = batch_idx * seq_len + seq_idx
+    q_base = q_ptr + base_offset * q_row_stride
+    k_base = k_ptr + base_offset * k_row_stride
+
+    # Tiling over dim/2
+    for d_start in tl.static_range(0, head_dim_half, BLOCK_SIZE):
+        d_indices = d_start + tl.arange(0, BLOCK_SIZE)
+        mask_d = d_indices < head_dim_half
+
+        # Load frequencies once per tile (freqs layout: [seq_len, head_dim_half])
+        freq_idx = d_indices
+        freqs_real = tl.load(freqs_real_ptr + seq_idx * freqs_row_stride + freq_idx, mask=mask_d, other=0.0)
+        freqs_imag = tl.load(freqs_imag_ptr + seq_idx * freqs_row_stride + freq_idx, mask=mask_d, other=0.0)
+        freqs_imag = freqs_imag * imag_sign
+
+        # Process one query head per program in pid_h
+        if pid_h < n_q_heads:
+            q_head_ptr = q_base + pid_h * q_head_stride
+            q_real = tl.load(q_head_ptr + d_indices * 2, mask=mask_d, other=0.0)
+            q_imag = tl.load(q_head_ptr + d_indices * 2 + 1, mask=mask_d, other=0.0)
+
+            # Complex multiply with FMAs: (a+ib)*(c+id) = (a*c - b*d) + i(a*d + b*c)
+            new_q_real = tl.math.fma(q_real, freqs_real, -(q_imag * freqs_imag))
+            new_q_imag = tl.math.fma(q_real, freqs_imag, q_imag * freqs_real)
+
+            tl.store(q_head_ptr + d_indices * 2, new_q_real, mask=mask_d)
+            tl.store(q_head_ptr + d_indices * 2 + 1, new_q_imag, mask=mask_d)
+
+        # Process one key head per program in pid_h
+        if pid_h < n_k_heads:
+            k_head_ptr = k_base + pid_h * k_head_stride
+            k_real = tl.load(k_head_ptr + d_indices * 2, mask=mask_d, other=0.0)
+            k_imag = tl.load(k_head_ptr + d_indices * 2 + 1, mask=mask_d, other=0.0)
+
+            new_k_real = tl.math.fma(k_real, freqs_real, -(k_imag * freqs_imag))
+            new_k_imag = tl.math.fma(k_real, freqs_imag, k_imag * freqs_real)
+
+            tl.store(k_head_ptr + d_indices * 2, new_k_real, mask=mask_d)
+            tl.store(k_head_ptr + d_indices * 2 + 1, new_k_imag, mask=mask_d)
+
+
+def _select_kernel_meta(head_dim_half: int):
+    # Heuristic tuning for block size and num_warps
+    if head_dim_half >= 256:
+        return 128, 8
+    if head_dim_half >= 96:
+        return 128, 4
+    if head_dim_half >= 48:
+        return 64, 4
+    if head_dim_half >= 24:
+        return 32, 2
+    return 16, 2
+
+
+def llama4_rope_forward(q, k, freqs_cis, BLOCK_SIZE: int = None, imag_sign: float = 1.0):
+    # Save original dtype for casting back
+    original_dtype = q.dtype
+
+    batch_size, seq_len, n_q_heads, head_dim = q.shape
+    _, _, n_k_heads, _ = k.shape
+    head_dim_half = head_dim // 2
+
+    # Prepare frequencies
+    freqs_real, freqs_imag = _prepare_freqs(freqs_cis, seq_len, head_dim_half)
+
+    # Cast to appropriate dtype and make contiguous only when needed
+    q, k, freqs_real, freqs_imag = _cast_and_contiguous(q, k, freqs_real, freqs_imag)
+
+    # H100-optimized meta-params
+    if BLOCK_SIZE is None:
+        BLOCK_SIZE, num_warps = _select_kernel_meta(head_dim_half)
+    else:
+        # Provide a default num_warps if caller pins BLOCK_SIZE
+        _, num_warps = _select_kernel_meta(head_dim_half)
+
+    # 2D grid: one program per (batch, seq, head)
+    n_heads_max = max(n_q_heads, n_k_heads)
+    grid = (batch_size * seq_len, n_heads_max)
+
+    # Launch kernel
+    _llama4_rope_kernel[grid](
+        q,
+        k,
+        freqs_real,
+        freqs_imag,
+        q.stride(1),
+        k.stride(1),
+        q.stride(2),
+        k.stride(2),
+        freqs_real.stride(0),
+        seq_len,
+        batch_size,
+        imag_sign,
+        head_dim_half,
+        n_q_heads,
+        n_k_heads,
+        BLOCK_SIZE,
+        num_warps=num_warps,
+        num_stages=2,
+    )
+
+    # Cast back to original dtype only if it differs from compute dtype
+    if q.dtype != original_dtype:
+        q = q.to(original_dtype)
+    if k.dtype != original_dtype:
+        k = k.to(original_dtype)
+
+    return q, k
+
+
+class LigerLlama4RopeFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, q, k, freqs_cis, BLOCK_SIZE: int = None):
+        q_out, k_out = llama4_rope_forward(q, k, freqs_cis, BLOCK_SIZE, imag_sign=1.0)
+        ctx.save_for_backward(freqs_cis.detach() if isinstance(freqs_cis, torch.Tensor) else freqs_cis)
+        ctx.BLOCK_SIZE = BLOCK_SIZE
+        return q_out, k_out
+
+    @staticmethod
+    def backward(ctx, dq, dk):
+        (freqs_cis,) = ctx.saved_tensors
+        BLOCK_SIZE = getattr(ctx, "BLOCK_SIZE", None)
+        # Use imag_sign=-1.0 for conjugate without materializing a new tensor
+        dq_out, dk_out = llama4_rope_forward(dq, dk, freqs_cis, BLOCK_SIZE, imag_sign=-1.0)
+        return dq_out, dk_out, None
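The new module implements a Llama4-style RoPE directly on interleaved (real, imag) channel pairs and reuses the same kernel for the backward pass with imag_sign=-1.0, i.e. rotation by the complex conjugate. A minimal usage sketch, assuming the class above is in scope, tensors laid out as (batch, seq_len, n_heads, head_dim), and an illustrative rope base of 500000.0:

import torch

batch, seq_len, n_q_heads, n_k_heads, head_dim = 2, 128, 8, 2, 64
q = torch.randn(batch, seq_len, n_q_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, seq_len, n_k_heads, head_dim, device="cuda", dtype=torch.bfloat16)

# Complex frequencies: one rotation angle per (position, channel pair).
inv_freq = 1.0 / (500000.0 ** (torch.arange(0, head_dim, 2, device="cuda").float() / head_dim))
theta = torch.outer(torch.arange(seq_len, device="cuda").float(), inv_freq)
freqs_cis = torch.polar(torch.ones_like(theta), theta)  # (seq_len, head_dim // 2), complex64

q_rot, k_rot = LigerLlama4RopeFunction.apply(q, k, freqs_cis)

Because the kernel stores its results back into the q and k buffers, the inputs are rotated in place whenever they are already contiguous and in the compute dtype; the backward call applies the inverse rotation to the incoming gradients instead of materializing conjugated frequencies.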
@@ -63,7 +63,7 @@ def _rms_norm_forward_kernel(
     3. https://arxiv.org/pdf/1910.07467
     """
 
-    row_idx = tl.program_id(0)
+    row_idx = tl.program_id(0).to(tl.int64)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
@@ -137,7 +137,7 @@ def _rms_norm_backward_kernel(
     dw = sum(dy * (x / RMS)). summation over BxT dimension
     """
 
-    row_block_id = tl.program_id(0)
+    row_block_id = tl.program_id(0).to(tl.int64)
     row_start = row_block_id * rows_per_program
     row_end = min((row_block_id + 1) * rows_per_program, n_rows)
     col_offsets = tl.arange(0, BLOCK_SIZE)
liger_kernel/ops/rope.py CHANGED
@@ -32,7 +32,7 @@ def _triton_rope(
 
     # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
    # stride: (seq_len * head_dim, head_dim, 1)
-    pid = tl.program_id(0)
+    pid = tl.program_id(0).to(tl.int64)
 
     # locate start address
     q_ptr = q_ptr + pid * q_row_stride
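A detail shared by the layer_norm, rms_norm, and rope hunks above: tl.program_id(0) is now cast to tl.int64 before it is multiplied by a row stride. With the default 32-bit index, the element offset row_idx * row_stride wraps around once it exceeds 2**31 - 1, which silently corrupts loads and stores on large-batch or long-sequence workloads. A worked example with hypothetical sizes:

# Hypothetical sizes: hidden size 8192 and 300_000 flattened rows (batch * seq_len).
row_stride = 8192
row_idx = 300_000 - 1
offset = row_idx * row_stride  # 2_457_591_808 elements into the tensor
print(offset > 2**31 - 1)      # True: a 32-bit offset would overflow here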
@@ -5,13 +5,20 @@ from typing import TYPE_CHECKING
 # Always-safe imports (independent of 'transformers')
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss  # noqa: F401
 from liger_kernel.transformers.dyt import LigerDyT  # noqa: F401
+from liger_kernel.transformers.fused_add_rms_norm import LigerFusedAddRMSNorm  # noqa: F401
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss  # noqa: F401
 from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD  # noqa: F401
 from liger_kernel.transformers.geglu import LigerGEGLUMLP  # noqa: F401
 from liger_kernel.transformers.jsd import LigerJSD  # noqa: F401
+from liger_kernel.transformers.kl_div import LigerKLDIVLoss  # noqa: F401
 from liger_kernel.transformers.layer_norm import LigerLayerNorm  # noqa: F401
+from liger_kernel.transformers.llama4_rope import liger_llama4_text_rotary_pos_emb  # noqa: F401
+from liger_kernel.transformers.llama4_rope import liger_llama4_vision_rotary_pos_emb  # noqa: F401
+from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention  # noqa: F401
 from liger_kernel.transformers.rms_norm import LigerRMSNorm  # noqa: F401
 from liger_kernel.transformers.rope import liger_rotary_pos_emb  # noqa: F401
+from liger_kernel.transformers.softmax import LigerSoftmax  # noqa: F401
+from liger_kernel.transformers.sparsemax import LigerSparsemax  # noqa: F401
 from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP  # noqa: F401
 from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP  # noqa: F401
 from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP  # noqa: F401
@@ -28,6 +35,7 @@ if TYPE_CHECKING:
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4  # noqa: F401
@@ -43,6 +51,7 @@ if TYPE_CHECKING:
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smollm3  # noqa: F401
 
 
 # Check if 'transformers' is installed
@@ -85,6 +94,7 @@ def __getattr__(name: str):
         "apply_liger_kernel_to_gemma3",
         "apply_liger_kernel_to_gemma3_text",
         "apply_liger_kernel_to_glm4",
+        "apply_liger_kernel_to_glm4v",
         "apply_liger_kernel_to_granite",
         "apply_liger_kernel_to_llama",
         "apply_liger_kernel_to_llava",
@@ -100,6 +110,7 @@ def __getattr__(name: str):
         "apply_liger_kernel_to_qwen2_vl",
         "apply_liger_kernel_to_qwen3",
         "apply_liger_kernel_to_qwen3_moe",
+        "apply_liger_kernel_to_smollm3",
     }
 
     if name in monkey_patch_symbols:
@@ -119,13 +130,20 @@ __all__ = [
     "LigerGEGLUMLP",
     "LigerJSD",
     "LigerLayerNorm",
+    "LigerFusedAddRMSNorm",
     "LigerRMSNorm",
     "liger_rotary_pos_emb",
+    "liger_llama4_text_rotary_pos_emb",
+    "liger_llama4_vision_rotary_pos_emb",
     "LigerBlockSparseTop2MLP",
     "LigerPhi3SwiGLUMLP",
     "LigerQwen3MoeSwiGLUMLP",
     "LigerSwiGLUMLP",
     "LigerTVDLoss",
+    "LigerKLDIVLoss",
+    "LigerMultiTokenAttention",
+    "LigerSoftmax",
+    "LigerSparsemax",
 ]
 
 # Add transformer-dependent symbols only if available
@@ -140,6 +158,7 @@ if _TRANSFORMERS_AVAILABLE:
             "apply_liger_kernel_to_gemma3",
             "apply_liger_kernel_to_gemma3_text",
             "apply_liger_kernel_to_glm4",
+            "apply_liger_kernel_to_glm4v",
             "apply_liger_kernel_to_granite",
             "apply_liger_kernel_to_llama",
             "apply_liger_kernel_to_llava",
@@ -155,5 +174,6 @@ if _TRANSFORMERS_AVAILABLE:
             "apply_liger_kernel_to_qwen2_vl",
             "apply_liger_kernel_to_qwen3",
             "apply_liger_kernel_to_qwen3_moe",
+            "apply_liger_kernel_to_smollm3",
         ]
     )
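On the Python side, the 0.6.2 package __init__ re-exports the additions so they are importable from the package root. A short sketch; the no-argument call assumes apply_liger_kernel_to_smollm3 follows the same convention as the other apply_liger_kernel_to_* entry points (all options defaulted, patching the Hugging Face modeling code in place before the model is instantiated):

from liger_kernel.transformers import LigerFusedAddRMSNorm, LigerKLDIVLoss, LigerSoftmax, LigerSparsemax
from liger_kernel.transformers import apply_liger_kernel_to_smollm3

# Assumption: defaults enable the fused kernels for SmolLM3; call before from_pretrained().
apply_liger_kernel_to_smollm3()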
@@ -0,0 +1,5 @@
+from liger_kernel.transformers.experimental.embedding import LigerEmbedding  # noqa: F401
+
+__all__ = [
+    "LigerEmbedding",
+]
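The new experimental __init__ makes LigerEmbedding importable from the experimental namespace. A minimal sketch, assuming its constructor mirrors torch.nn.Embedding(num_embeddings, embedding_dim):

import torch
from liger_kernel.transformers.experimental import LigerEmbedding

emb = LigerEmbedding(32_000, 4096).to("cuda")  # assumption: nn.Embedding-style signature
token_ids = torch.randint(0, 32_000, (2, 128), device="cuda")
hidden = emb(token_ids)  # expected shape: (2, 128, 4096)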