liger-kernel-nightly 0.6.0.dev20250718080702__py3-none-any.whl → 0.6.0.dev20250719041256__py3-none-any.whl

This diff compares two publicly available versions of this package as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in their public registry.
@@ -1,4 +1,3 @@
- import math
  import operator

  import torch
@@ -43,30 +42,45 @@ def _layer_norm_forward_kernel(
  https://arxiv.org/abs/1607.06450
  https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
  """
- row_idx = tl.program_id(0)
+ row_idx = tl.program_id(0).to(tl.int64)
  col_offsets = tl.arange(0, BLOCK_SIZE)
  mask = col_offsets < n_cols

- Y_ptr += row_idx * Y_row_stride
- X_ptr += row_idx * X_row_stride
- Mean_ptr += row_idx * Mean_row_stride
- RSTD_ptr += row_idx * RSTD_row_stride
-
- X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
- W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
- B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)
-
- mean = tl.sum(X_row, axis=0) / n_cols
- Xmm = tl.where(mask, X_row - mean, 0)
- var = tl.sum(Xmm * Xmm, axis=0) / n_cols
+ # Pre-load weights and bias in fp32 to avoid repeated conversions
+ W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+ B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0.0)
+ W_f32 = W_row.to(tl.float32)
+ B_f32 = B_row.to(tl.float32)
+
+ # Calculate pointers for this row
+ row_X_ptr = X_ptr + row_idx * X_row_stride
+ row_Y_ptr = Y_ptr + row_idx * Y_row_stride
+ row_Mean_ptr = Mean_ptr + row_idx * Mean_row_stride
+ row_RSTD_ptr = RSTD_ptr + row_idx * RSTD_row_stride
+
+ # Load input data and convert to fp32 for numerical stability
+ X_row = tl.load(row_X_ptr + col_offsets, mask=mask, other=0.0)
+ X_f32 = X_row.to(tl.float32)
+
+ # Compute statistics in fp32 for numerical stability
+ n_cols_f32 = n_cols.to(tl.float32)
+ mean = tl.sum(X_f32, axis=0) / n_cols_f32
+ X_centered = X_f32 - mean
+ # Apply mask to variance calculation to exclude contributions from masked elements
+ X_centered_masked = tl.where(mask, X_centered, 0.0)
+ var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols_f32
  rstd = rsqrt(var + eps)

- tl.store(Mean_ptr, mean)
- tl.store(RSTD_ptr, rstd)
+ # Store statistics (convert back to original dtype only once)
+ tl.store(row_Mean_ptr, mean.to(X_row.dtype))
+ tl.store(row_RSTD_ptr, rstd.to(X_row.dtype))

- Y_row = Xmm * rstd * W_row + B_row
+ # Fused normalization and affine transformation
+ # Y = (X - mean) * rstd * W + B = X_centered * rstd * W + B
+ Y_f32 = X_centered * rstd * W_f32 + B_f32

- tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
+ # Store output (single conversion back to original dtype)
+ tl.store(row_Y_ptr + col_offsets, Y_f32.to(X_row.dtype), mask=mask)


  @triton.jit
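
For context on the hunk above: the forward kernel now uses an int64 program id for row indexing and keeps all statistics in fp32, converting back to the input dtype only when storing. A minimal sanity check against PyTorch's eager layer norm, assuming a CUDA device and the layer_norm_forward signature shown later in this diff (illustrative sketch, not part of the package):

    import torch
    import torch.nn.functional as F

    from liger_kernel.ops.layer_norm import layer_norm_forward

    # Example shapes/dtypes chosen for illustration only.
    x = torch.randn(4, 2048, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(2048, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(2048, device="cuda", dtype=torch.bfloat16)

    # Per the docstring added in this diff, layer_norm_forward returns
    # (Y, X, Mean, RSTD, BLOCK_SIZE, num_warps).
    y, _, mean, rstd, _, _ = layer_norm_forward(x, w, b, eps=1e-6)
    ref = F.layer_norm(x, (2048,), weight=w, bias=b, eps=1e-6)

    # Both paths accumulate statistics in fp32, so the bf16 outputs should
    # only differ by final-rounding noise.
    print((y - ref).abs().max())
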
@@ -81,73 +95,87 @@ def _layer_norm_backward_kernel(
  DY_ptr, # pointer to output grad, shape (n_rows, n_cols)
  stride_x, # stride of each row in input
  stride_dx, # stride of each row in input grad
- stride_dw, # stride of each row in weights grad
- stride_db, # stride of each row in bias grad
  stride_dy, # stride of each row in output grad
- n_rows,
  n_cols,
- rows_per_program: tl.constexpr,
  BLOCK_SIZE: tl.constexpr,
  dtype: tl.constexpr,
+ atomic_dtype: tl.constexpr,
  ):
  """
  References:
  https://arxiv.org/abs/1607.06450
  https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
- https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
- https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
  """
- row_block_id = tl.program_id(0)
- row_start = row_block_id * rows_per_program
- row_end = min((row_block_id + 1) * rows_per_program, n_rows)
+ row_idx = tl.program_id(0).to(tl.int64)
  cols = tl.arange(0, BLOCK_SIZE)
  mask = cols < n_cols

- dw_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
- db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
-
- X_ptr += row_start * stride_x
- Mean_ptr += row_start
- RSTD_ptr += row_start
- DX_ptr += row_start * stride_dx
- DY_ptr += row_start * stride_dy
-
- for _ in range(row_start, row_end):
- x = tl.load(X_ptr + cols, mask=mask, other=0.0)
- w = tl.load(W_ptr + cols, mask=mask, other=0.0)
- dy = tl.load(DY_ptr + cols, mask=mask, other=0.0)
- mean = tl.load(Mean_ptr)
- rstd = tl.load(RSTD_ptr)
-
- x_hat = (x - mean) * rstd
- wdy = w * dy
- c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
- c2 = tl.sum(wdy, axis=0) / n_cols
- dx = (wdy - (x_hat * c1 + c2)) * rstd
- tl.store(DX_ptr + cols, dx.to(dtype), mask=mask)
-
- dw_row += dy * x_hat
- db_row += dy
-
- X_ptr += stride_x
- Mean_ptr += 1
- RSTD_ptr += 1
- DX_ptr += stride_dx
- DY_ptr += stride_dy
-
- tl.store(DW_ptr + row_block_id * stride_dw + cols, dw_row.to(dtype), mask=mask)
- tl.store(DB_ptr + row_block_id * stride_db + cols, db_row.to(dtype), mask=mask)
+ # Pre-load weights once (same optimization as forward pass)
+ w = tl.load(W_ptr + cols, mask=mask, other=0.0)
+ w_f32 = w.to(tl.float32)
+ n_cols_f32 = n_cols.to(tl.float32)
+
+ # Calculate pointers for this specific row
+ row_X_ptr = X_ptr + row_idx * stride_x
+ row_DX_ptr = DX_ptr + row_idx * stride_dx
+ row_DY_ptr = DY_ptr + row_idx * stride_dy
+ row_Mean_ptr = Mean_ptr + row_idx
+ row_RSTD_ptr = RSTD_ptr + row_idx
+
+ # Load data for this row
+ x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
+ dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
+ mean = tl.load(row_Mean_ptr)
+ rstd = tl.load(row_RSTD_ptr)
+
+ # Convert to fp32 for numerical stability
+ x_f32 = x.to(tl.float32)
+ dy_f32 = dy.to(tl.float32)
+ mean_f32 = mean.to(tl.float32)
+ rstd_f32 = rstd.to(tl.float32)
+
+ # Compute backward pass for this row
+ x_hat = (x_f32 - mean_f32) * rstd_f32
+ wdy = w_f32 * dy_f32
+ c1 = tl.sum(x_hat * wdy, axis=0) / n_cols_f32
+ c2 = tl.sum(wdy, axis=0) / n_cols_f32
+ dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
+
+ # Store input gradient
+ tl.store(row_DX_ptr + cols, dx.to(dtype), mask=mask)
+
+ # Accumulate weight and bias gradients using atomic operations
+ dw = dy_f32 * x_hat
+ db = dy_f32
+ tl.atomic_add(DW_ptr + cols, dw.to(atomic_dtype), mask=mask)
+ tl.atomic_add(DB_ptr + cols, db.to(atomic_dtype), mask=mask)


  def layer_norm_forward(X, W, B, eps):
+ """
+ Args:
+ X: Input tensor of shape (..., hidden_size)
+ W: Weight tensor of shape (hidden_size,)
+ B: Bias tensor of shape (hidden_size,)
+ eps: Small constant for numerical stability
+
+ Returns:
+ Tuple of (output, input, mean, rstd, block_size, num_warps)
+ """
  shape = X.shape
  dim = shape[-1]
  X = X.view(-1, dim)
  n_rows, n_cols = X.shape
+
+ # Calculate optimal block size and warp configuration
  BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+ # Allocate output tensors
  Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
  Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
  RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
+
+ # Validate input dimensions
  if X.shape[1] != W.shape[0]:
  raise ValueError(
  f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
@@ -159,7 +187,9 @@ def layer_norm_forward(X, W, B, eps):
  if X.device.type == "xpu":
  kernel_args["grf_mode"] = "large"

- _layer_norm_forward_kernel[(n_rows,)](
+ # Launch kernel with one thread block per row for optimal performance
+ grid = (n_rows,)
+ _layer_norm_forward_kernel[grid](
  Y,
  Y.stride(0),
  X,
@@ -176,35 +206,43 @@ def layer_norm_forward(X, W, B, eps):
  eps,
  BLOCK_SIZE=BLOCK_SIZE,
  num_warps=num_warps,
- **kernel_args, # XPU-specific optimization
+ **kernel_args,
  )
+
  return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps


  def layer_norm_backward(dY, X, W, B, Mean, RSTD):
+ """
+ Args:
+ dY: Gradient of output
+ X: Input tensor
+ W: Weight tensor
+ B: Bias tensor
+ Mean: Pre-computed mean
+ RSTD: Pre-computed reciprocal standard deviation
+
+ Returns:
+ Tuple of (input_grad, weight_grad, bias_grad)
+ """
  shape = dY.shape
  dim = shape[-1]
  dY = dY.view(-1, dim)
  n_rows, n_cols = dY.shape

- sm_count = 1
- if X.device.type == "cuda":
- sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
- elif X.device.type == "xpu":
- sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
-
+ # Allocate gradient tensors
  DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
- _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
- _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
+ # Use float32 for weight/bias gradients if bfloat16 (due to atomic_add limitation)
+ grad_dtype = torch.float32 if W.dtype == torch.bfloat16 else W.dtype
+ DW = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
+ DB = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)

+ # Calculate optimal block size and warp configuration
  BLOCK_SIZE, num_warps = calculate_settings(n_cols)
  if n_cols > BLOCK_SIZE:
- raise RuntimeError(
- f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
- )
+ raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")

- rows_per_program = math.ceil(n_rows / sm_count)
- grid = (sm_count,)
+ # Determine dtype for triton operations
  triton_dtype = (
  tl.float32
  if X.dtype == torch.float32
@@ -212,41 +250,41 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
  if X.dtype == torch.bfloat16
  else tl.float16
  if X.dtype == torch.float16
- else tl.float32 # fallback to float32 for other types
+ else tl.float32 # fallback
  )

+ # Use float32 for atomic operations if bfloat16 is not supported
+ atomic_dtype = tl.float32 if triton_dtype == tl.bfloat16 else triton_dtype
+
  # XPU-specific optimization
  kernel_args = {}
  if X.device.type == "xpu":
  kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})

+ # Launch kernel with one thread block per row for optimal performance
+ grid = (n_rows,)
  _layer_norm_backward_kernel[grid](
  X,
  W,
  Mean,
  RSTD,
  DX,
- _DW,
- _DB,
+ DW,
+ DB,
  dY,
  X.stride(0),
  DX.stride(0),
- _DW.stride(0),
- _DB.stride(0),
  dY.stride(0),
- n_rows,
  n_cols,
- rows_per_program,
  BLOCK_SIZE=BLOCK_SIZE,
  dtype=triton_dtype,
- **kernel_args, # XPU-specific optimization
+ atomic_dtype=atomic_dtype,
+ num_warps=num_warps,
+ **kernel_args,
  )

- DW = _DW.sum(dim=0).to(W.dtype)
- DB = _DB.sum(dim=0).to(W.dtype)
-
  DX = DX.view(*shape)
- return DX, DW, DB
+ return DX, DW.to(W.dtype), DB.to(W.dtype)


  class LigerLayerNormFunction(torch.autograd.Function):
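
Taken together, the backward path now launches one program per row and accumulates DW/DB in fp32 when the weights are bf16, casting back to the weight dtype only on return. A rough end-to-end comparison against autograd, assuming a CUDA device and the function signatures shown in this diff (illustrative sketch only; shapes and dtypes are examples):

    import torch
    import torch.nn.functional as F

    from liger_kernel.ops.layer_norm import layer_norm_backward, layer_norm_forward

    n, d = 8, 1024
    x = torch.randn(n, d, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    w = torch.randn(d, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    b = torch.randn(d, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    dy = torch.randn(n, d, device="cuda", dtype=torch.bfloat16)

    # Triton path: the forward returns the reshaped input plus cached Mean/RSTD,
    # which the backward call consumes directly.
    y, x2d, mean, rstd, _, _ = layer_norm_forward(x, w, b, eps=1e-6)
    dx, dw, db = layer_norm_backward(dy, x2d, w, b, mean, rstd)

    # Eager reference via autograd.
    F.layer_norm(x, (d,), weight=w, bias=b, eps=1e-6).backward(dy)

    # Differences should be on the order of bf16 rounding noise.
    print((dx - x.grad).abs().max(), (dw - w.grad).abs().max(), (db - b.grad).abs().max())
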
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: liger_kernel_nightly
- Version: 0.6.0.dev20250718080702
+ Version: 0.6.0.dev20250719041256
  Summary: Efficient Triton kernels for LLM Training
  License: BSD 2-CLAUSE LICENSE
  Copyright 2024 LinkedIn Corporation
@@ -28,7 +28,7 @@ liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2wogg
  liger_kernel/ops/grpo_loss.py,sha256=anRnv7k1-AV3pCC6_TqP0GMg78YYUfRAJrbpx6PVhl0,9448
  liger_kernel/ops/jsd.py,sha256=onHp5T3MbvJaVz5Vup7Ww6EQp_HTaZeayTjJk6FgQMY,7042
  liger_kernel/ops/kl_div.py,sha256=ZjGdDLKWksHT9dZ0xF_TDgAkj5cuMTwwT5tr9E-_24o,8734
- liger_kernel/ops/layer_norm.py,sha256=vWCyOm-F2GMAilB-ozJcFeUQQLCJoTE_uiXq-_0uYuI,8356
+ liger_kernel/ops/layer_norm.py,sha256=g7TfVMCoUg_zAl4nb4xPr_FTRGjxxDhFXdKEnU9NYhE,9908
  liger_kernel/ops/multi_token_attention.py,sha256=Oz_RXDp-OSS_R_HuGmaETHdAJ7Toda_70OfE7TXMUlY,7645
  liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML5vBU,8481
  liger_kernel/ops/rms_norm.py,sha256=-rcgHwWCxlA-Syec2XhdW4jfOeCDt2r7qwjslgXFYDU,18865
@@ -92,9 +92,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
- liger_kernel_nightly-0.6.0.dev20250718080702.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
- liger_kernel_nightly-0.6.0.dev20250718080702.dist-info/METADATA,sha256=mNkIMGPTMPdmmjDsW54kDe0WhPi8Ep0Cpt4koWQQuaE,24672
- liger_kernel_nightly-0.6.0.dev20250718080702.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
- liger_kernel_nightly-0.6.0.dev20250718080702.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
- liger_kernel_nightly-0.6.0.dev20250718080702.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
- liger_kernel_nightly-0.6.0.dev20250718080702.dist-info/RECORD,,
+ liger_kernel_nightly-0.6.0.dev20250719041256.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+ liger_kernel_nightly-0.6.0.dev20250719041256.dist-info/METADATA,sha256=fHiX-8I-gf3rBfmNt7LRh5ZwZcXoddbQ5rJzYxohaTg,24672
+ liger_kernel_nightly-0.6.0.dev20250719041256.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+ liger_kernel_nightly-0.6.0.dev20250719041256.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+ liger_kernel_nightly-0.6.0.dev20250719041256.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+ liger_kernel_nightly-0.6.0.dev20250719041256.dist-info/RECORD,,