liger-kernel-nightly 0.5.10.dev20250527002824__py3-none-any.whl → 0.5.10.dev20250528223524__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/ops/rms_norm.py +243 -45
- liger_kernel/transformers/monkey_patch.py +3 -4
- liger_kernel/transformers/rms_norm.py +4 -1
- {liger_kernel_nightly-0.5.10.dev20250527002824.dist-info → liger_kernel_nightly-0.5.10.dev20250528223524.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.10.dev20250527002824.dist-info → liger_kernel_nightly-0.5.10.dev20250528223524.dist-info}/RECORD +9 -9
- {liger_kernel_nightly-0.5.10.dev20250527002824.dist-info → liger_kernel_nightly-0.5.10.dev20250528223524.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250527002824.dist-info → liger_kernel_nightly-0.5.10.dev20250528223524.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250527002824.dist-info → liger_kernel_nightly-0.5.10.dev20250528223524.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.10.dev20250527002824.dist-info → liger_kernel_nightly-0.5.10.dev20250528223524.dist-info}/top_level.txt +0 -0
liger_kernel/ops/rms_norm.py
CHANGED
@@ -193,6 +193,153 @@ def _rms_norm_backward_kernel(
 
     tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
 
+
+@triton.jit
+def _block_rms_norm_forward_kernel(
+    Y_ptr,
+    Y_row_stride,
+    X_ptr,
+    X_row_stride,
+    W_ptr,
+    W_row_stride,
+    RSTD_ptr,
+    RSTD_row_stride,
+    n_rows,
+    n_cols,
+    eps,
+    offset,
+    casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+    BLOCK_SIZE: tl.constexpr,
+    BLOCK_ROW: tl.constexpr,
+):
+    """
+    y_i = (x_i / RMS) * (offset + w_i), RMS = sqrt(sum(x_i^2) / N)
+
+    Reference:
+    1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+    2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
+    3. https://arxiv.org/pdf/1910.07467
+    """
+
+    row_idx = tl.program_id(0) * BLOCK_ROW + tl.arange(0, BLOCK_ROW)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    row_mask = row_idx < n_rows
+    col_mask = col_offsets < n_cols
+
+    X_row = tl.load(
+        X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
+        mask=row_mask[:, None] & col_mask[None, :],
+        other=0,
+    )
+    X_row_dtype = X_row.dtype
+    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
+
+    # On Llama, only rstd is computed in fp32
+    if casting_mode == _CASTING_MODE_LLAMA:
+        X_row = X_row.to(tl.float32)
+
+    # Gemma computes everything in fp32, then casts the output back to the original dtype
+    if casting_mode == _CASTING_MODE_GEMMA:
+        W_row = W_row.to(tl.float32)
+        X_row = X_row.to(tl.float32)
+
+    if casting_mode == _CASTING_MODE_NONE:
+        eps = eps.to(X_row_dtype)
+        offset = offset.to(X_row_dtype)
+
+    mean_square = tl.sum(X_row * X_row, axis=1) / n_cols
+    rstd = rsqrt(mean_square + eps)
+
+    # We can save time by caching rstd with minimal memory overhead,
+    # because rstd is much smaller than X_row (one value per row).
+    # On the computation side, this saves 4 operations (*, sum, /, sqrt).
+    tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd, row_mask)
+
+    X_row = X_row * rstd[:, None]
+
+    # On Llama, the multiplication with the weight is done in the original dtype
+    if casting_mode == _CASTING_MODE_LLAMA:
+        X_row = X_row.to(X_row_dtype)
+
+    Y_row = X_row * (offset + W_row)[None, :]
+
+    if casting_mode == _CASTING_MODE_GEMMA:
+        Y_row = Y_row.to(X_row_dtype)
+
+    tl.store(
+        Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :],
+        Y_row,
+        mask=row_mask[:, None] & col_mask[None, :],
+    )
+
+
+@triton.jit
+def _block_rms_norm_backward_kernel(
+    dY_ptr,
+    dY_row_stride,
+    dX_ptr,
+    dX_row_stride,
+    X_ptr,
+    X_row_stride,
+    X_dtype: tl.constexpr,
+    W_ptr,
+    W_row_stride,
+    RSTD_ptr,
+    RSTD_row_stride,
+    dW_ptr,
+    dW_row_stride,
+    n_rows,
+    n_cols,
+    offset,
+    rows_per_program: tl.constexpr,
+    casting_mode: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    BLOCK_ROW: tl.constexpr,
+):
+    """
+    dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x].
+    * means element-wise multiplication, whereas dot means dot product.
+    dw = sum(dy * (x / RMS)), summed over the BxT dimension.
+    """
+
+    pid = tl.program_id(0).cast(tl.int64)
+    NUM_SMS = tl.num_programs(0)
+
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    col_mask = col_offsets < n_cols
+
+    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
+    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
+    W_row = W_row + offset
+
+    for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
+        row_idx = start + tl.arange(0, BLOCK_ROW)
+        row_mask = row_idx < n_rows
+        dY_row = tl.load(
+            dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :],
+            mask=row_mask[:, None] & col_mask[None, :],
+            other=0.0,
+        )
+        X_row = tl.load(
+            X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
+            mask=row_mask[:, None] & col_mask[None, :],
+            other=0.0,
+        )
+
+        # Get cached rstd
+        rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride, row_mask)
+
+        X_row = X_row.to(tl.float32)
+
+        # Different backward graphs for different casting modes
+        if casting_mode == _CASTING_MODE_LLAMA:
+            m = (dY_row * W_row[None, :]).to(tl.float32)
+        elif casting_mode == _CASTING_MODE_GEMMA:
+            dY_row = dY_row.to(tl.float32)
+            m = dY_row * W_row[None, :]
+        else:
+            m = dY_row * W_row[None, :]
+
+        dX_row = rstd_row[:, None] * m
+        dX_row += rstd_row[:, None] * (
+            -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
+        )
+
+        # calculate the gradient of W
+        if casting_mode == _CASTING_MODE_LLAMA:
+            dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]).to(X_dtype), 0)
+        else:
+            # here X_row is already in fp32 (see the previous if blocks)
+            dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
+
+        tl.store(
+            dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
+            dX_row,
+            mask=row_mask[:, None] & col_mask[None, :],
+        )
+
+    tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
+
 
 _str_to_casting_mode = {
     "llama": _CASTING_MODE_LLAMA.value,
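To make the kernel docstring concrete, here is a minimal PyTorch sketch of the same forward math in plain fp32, ignoring the per-model casting modes; rms_norm_reference is an illustrative name, not a Liger API.

import torch

def rms_norm_reference(X: torch.Tensor, W: torch.Tensor, eps: float = 1e-6, offset: float = 0.0) -> torch.Tensor:
    # rstd = 1 / RMS = rsqrt(mean(x_i^2) + eps); y_i = x_i * rstd * (offset + w_i)
    rstd = torch.rsqrt(X.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (X * rstd) * (offset + W)

Y = rms_norm_reference(torch.randn(4, 8), torch.randn(8))
assert Y.shape == (4, 8)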
@@ -201,7 +348,7 @@ _str_to_casting_mode = {
 }
 
 
-def rms_norm_forward(X, W, eps, offset, casting_mode):
+def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
     if not isinstance(casting_mode, int):
         assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
         casting_mode = _str_to_casting_mode[casting_mode]
@@ -227,27 +374,49 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
     kernel_args = {}
     if X.device.type == "xpu":
         kernel_args["grf_mode"] = "large"
-    _rms_norm_forward_kernel[(n_rows,)](
-        Y,
-        Y.stride(0),
-        X,
-        X.stride(0),
-        W,
-        W.stride(0),
-        RSTD,
-        RSTD.stride(0),
-        n_cols,
-        eps,
-        offset,
-        casting_mode,
-        BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=num_warps,
-        **kernel_args,  # XPU-specific optimization
-    )
+    if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
+        _rms_norm_forward_kernel[(n_rows,)](
+            Y,
+            Y.stride(0),
+            X,
+            X.stride(0),
+            W,
+            W.stride(0),
+            RSTD,
+            RSTD.stride(0),
+            n_cols,
+            eps,
+            offset,
+            casting_mode,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
+    else:
+        BLOCK_ROW = 16
+        kernel_args["BLOCK_ROW"] = BLOCK_ROW
+        _block_rms_norm_forward_kernel[(triton.cdiv(n_rows, BLOCK_ROW),)](
+            Y,
+            Y.stride(0),
+            X,
+            X.stride(0),
+            W,
+            W.stride(0),
+            RSTD,
+            RSTD.stride(0),
+            n_rows,
+            n_cols,
+            eps,
+            offset,
+            casting_mode,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
     return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
 
 
-def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place):
+def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, row_mode):
     shape = dY.shape
     dim = shape[-1]
     dY = dY.view(-1, dim)
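The dispatch condition above decides between the original one-row-per-program kernel and the new blocked kernel. As a standalone sketch of the intent (the thresholds are taken directly from the diff; _uses_row_kernel is a hypothetical name, not part of the package):

from typing import Optional

def _uses_row_kernel(block_size: int, n_rows: int, row_mode: Optional[bool]) -> bool:
    # Wide rows (BLOCK_SIZE > 256), small inputs (fewer than 4096 * 8 = 32768
    # rows), or an explicit row_mode request keep the one-row-per-program
    # kernel; everything else goes to the blocked kernel (16 rows per program).
    return block_size > 256 or n_rows < 4096 * 8 or bool(row_mode)

assert _uses_row_kernel(1024, 65536, None)      # hidden dimension too wide
assert _uses_row_kernel(256, 1024, None)        # too few rows to amortize blocking
assert not _uses_row_kernel(256, 65536, None)   # blocked kernel path
assert _uses_row_kernel(256, 65536, True)       # row_mode forces the row-wise path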
@@ -277,29 +446,56 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
     if X.device.type == "xpu":
         kernel_args["grf_mode"] = "large"
 
-    _rms_norm_backward_kernel[grid](
-        dY,
-        dY.stride(0),
-        dX,
-        dX.stride(0),
-        X,
-        X.stride(0),
-        torch_to_triton_dtype[X.dtype],
-        W,
-        W.stride(0),
-        RSTD,
-        RSTD.stride(0),
-        _dW,
-        _dW.stride(0),
-        n_rows,
-        n_cols,
-        offset,
-        rows_per_program,
-        casting_mode,
-        BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=num_warps,
-        **kernel_args,  # XPU-specific optimization
-    )
+    if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
+        _rms_norm_backward_kernel[grid](
+            dY,
+            dY.stride(0),
+            dX,
+            dX.stride(0),
+            X,
+            X.stride(0),
+            torch_to_triton_dtype[X.dtype],
+            W,
+            W.stride(0),
+            RSTD,
+            RSTD.stride(0),
+            _dW,
+            _dW.stride(0),
+            n_rows,
+            n_cols,
+            offset,
+            rows_per_program,
+            casting_mode,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
+    else:
+        BLOCK_ROW = 16
+        kernel_args["BLOCK_ROW"] = BLOCK_ROW
+        _block_rms_norm_backward_kernel[grid](
+            dY,
+            dY.stride(0),
+            dX,
+            dX.stride(0),
+            X,
+            X.stride(0),
+            torch_to_triton_dtype[X.dtype],
+            W,
+            W.stride(0),
+            RSTD,
+            RSTD.stride(0),
+            _dW,
+            _dW.stride(0),
+            n_rows,
+            n_cols,
+            offset,
+            rows_per_program,
+            casting_mode,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
     dX = dX.view(*shape)
     dW = _dW.sum(dim=0).to(W.dtype)
 
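A small self-contained check (not part of the diff) that the closed-form gradient in the backward docstring matches autograd; fp64 for tight tolerances, offset = 0.0:

import torch

X = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
W = torch.randn(8, dtype=torch.float64)
eps = 1e-6

rstd = torch.rsqrt(X.pow(2).mean(dim=-1, keepdim=True) + eps)
Y = (X * rstd) * W
dY = torch.randn_like(Y)
Y.backward(dY)

# dx = (1 / RMS) * [dy * w - (1 / N) * (1 / RMS^2) * ((dy * w) dot x) * x]
m = dY * W
N = X.shape[-1]
dX = rstd * (m - (1.0 / N) * rstd.pow(2) * (m * X).sum(dim=-1, keepdim=True) * X)
torch.testing.assert_close(X.grad, dX)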
@@ -330,15 +526,16 @@ class LigerRMSNormFunction(torch.autograd.Function):
 
     @staticmethod
     @ensure_contiguous
-    def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True):
+    def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row_mode=None):
         """
         X: (B, T, H) or (BxT, H)
         W: (H,)
         """
-        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode)
+        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
         ctx.offset = offset
         ctx.casting_mode = casting_mode
         ctx.in_place = in_place
+        ctx.row_mode = row_mode
         ctx.BLOCK_SIZE = BLOCK_SIZE
         ctx.num_warps = num_warps
         ctx.save_for_backward(X, W, RSTD)
@@ -361,5 +558,6 @@ class LigerRMSNormFunction(torch.autograd.Function):
             ctx.BLOCK_SIZE,
             ctx.num_warps,
             ctx.in_place,
+            ctx.row_mode,
         )
-        return dX, dW, None, None, None, None
+        return dX, dW, None, None, None, None, None
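With row_mode threaded through as a seventh input, backward correspondingly returns one more None for the new non-tensor argument. A usage sketch of the autograd function (a CUDA device is assumed, since these are Triton kernels):

import torch
from liger_kernel.ops.rms_norm import LigerRMSNormFunction

X = torch.randn(2, 16, 64, device="cuda", requires_grad=True)
W = torch.randn(64, device="cuda", requires_grad=True)

# Positional order: X, W, eps, offset, casting_mode, in_place, row_mode.
# row_mode=None lets rms_norm_forward pick a kernel; a truthy value pins
# the row-wise kernel.
Y = LigerRMSNormFunction.apply(X, W, 1e-6, 0.0, "llama", False, None)
Y.sum().backward()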
liger_kernel/transformers/monkey_patch.py
CHANGED
@@ -776,7 +776,7 @@ def apply_liger_kernel_to_gemma3_text(
 
     from transformers.models.gemma3 import modeling_gemma3
     from transformers.models.gemma3.modeling_gemma3 import Gemma3DecoderLayer
-    from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM, Gemma3TextModel
 
     from liger_kernel.transformers.gema3_rms import LigerRMSNormForGemma3
     from liger_kernel.transformers.model.gemma3 import causal_forward
@@ -807,9 +807,9 @@ def apply_liger_kernel_to_gemma3_text(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
-        if isinstance(model, Gemma3ForCausalLM):
+        if isinstance(model, Gemma3ForCausalLM) or isinstance(model, Gemma3TextModel):
             # get the base model from the model instance
-            base_model = model.model
+            base_model = model.model if isinstance(model, Gemma3ForCausalLM) else model
 
             if rms_norm:
                 _patch_rms_norm_module_for_gemma3(base_model.norm)
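The widened check means an already-instantiated bare Gemma3TextModel can now be patched in place, not only the CausalLM wrapper. A sketch, assuming a transformers version that ships the Gemma3 classes (the tiny config values are arbitrary):

from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig
from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text

config = Gemma3TextConfig(
    hidden_size=64, intermediate_size=128, num_hidden_layers=2, num_attention_heads=4
)
model = Gemma3TextModel(config)

# Previously only Gemma3ForCausalLM instances took this path; now the model
# itself serves as base_model when it is a Gemma3TextModel.
apply_liger_kernel_to_gemma3_text(model=model)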
@@ -1625,7 +1625,6 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
         return
 
     apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
-
     apply_fn_signature = inspect.signature(apply_fn)
 
     # Filter out the keyword arguments that are not supported by the apply function
liger_kernel/transformers/rms_norm.py
CHANGED
@@ -13,6 +13,7 @@ class LigerRMSNorm(nn.Module):
         casting_mode="llama",
         init_fn="ones",
         in_place=True,
+        row_mode=None,
     ):
         super().__init__()
         assert init_fn in [
@@ -20,11 +21,12 @@ class LigerRMSNorm(nn.Module):
             "zeros",
         ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
         self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
-        self.variance_epsilon, self.offset, self.casting_mode, self.in_place = (
+        self.variance_epsilon, self.offset, self.casting_mode, self.in_place, self.row_mode = (
            eps,
            offset,
            casting_mode,
            in_place,
+           row_mode,
        )
 
     def forward(self, hidden_states):
@@ -35,6 +37,7 @@ class LigerRMSNorm(nn.Module):
             self.offset,
             self.casting_mode,
             self.in_place,
+            self.row_mode,
         )
 
     def extra_repr(self):
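At the module level, row_mode is simply stored and forwarded to LigerRMSNormFunction.apply. A short usage sketch (CUDA assumed, since the underlying kernels are Triton):

import torch
from liger_kernel.transformers.rms_norm import LigerRMSNorm

# row_mode=None keeps the automatic kernel choice; a truthy value pins the
# row-wise kernel.
norm = LigerRMSNorm(hidden_size=64, eps=1e-6, row_mode=None).to("cuda")
out = norm(torch.randn(2, 16, 64, device="cuda"))
assert out.shape == (2, 16, 64)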
{liger_kernel_nightly-0.5.10.dev20250527002824.dist-info → liger_kernel_nightly-0.5.10.dev20250528223524.dist-info}/RECORD
CHANGED
@@ -28,7 +28,7 @@ liger_kernel/ops/kl_div.py,sha256=ZjGdDLKWksHT9dZ0xF_TDgAkj5cuMTwwT5tr9E-_24o,87
 liger_kernel/ops/layer_norm.py,sha256=vWCyOm-F2GMAilB-ozJcFeUQQLCJoTE_uiXq-_0uYuI,8356
 liger_kernel/ops/multi_token_attention.py,sha256=Oz_RXDp-OSS_R_HuGmaETHdAJ7Toda_70OfE7TXMUlY,7645
 liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML5vBU,8481
-liger_kernel/ops/rms_norm.py,sha256=
+liger_kernel/ops/rms_norm.py,sha256=IDj_V3hwo6tm3FijVbRh6ebUj2A3591MNkMer_gncdM,18749
 liger_kernel/ops/rope.py,sha256=ofmBOkUpZZO-Q8Z5B_LOFYYLD-YT-8WnJ4vGOrDYouI,8943
 liger_kernel/ops/softmax.py,sha256=tgORx6MK1IDDtZKqGarj0IPIVjqAIEUXXYPiinhRdtI,5864
 liger_kernel/ops/sparsemax.py,sha256=AeWe1xgkHJFEKWTj2vu_0hj7LztGvjqXAps-QTpCY0U,5087
@@ -52,10 +52,10 @@ liger_kernel/transformers/grpo_loss.py,sha256=uAkUNKSnUGEOqa82L9w2e6AI1kcmG8K45-
 liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
 liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
 liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
-liger_kernel/transformers/monkey_patch.py,sha256=
+liger_kernel/transformers/monkey_patch.py,sha256=a0CXSC8BwZg3vok-ns0udZLUOBkegGQgPDod3H8ilP4,74610
 liger_kernel/transformers/multi_token_attention.py,sha256=l9VDICK0dfmifUDW668hGscP8AHq2rYcM2oGUa3baRQ,1751
 liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
-liger_kernel/transformers/rms_norm.py,sha256=
+liger_kernel/transformers/rms_norm.py,sha256=srMS4jdkMCjY4Yqj9jjsy_IkY8KlHdTPLOx4069ZACA,1277
 liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
 liger_kernel/transformers/softmax.py,sha256=u7bFo35-cjaAm9of6-DLzmkaNFELOM-9AgyrcvUPifw,270
 liger_kernel/transformers/sparsemax.py,sha256=0lQA0UEOs4mu8CMruZ3VLhImxQVXJWhPsAKUsYA7vj8,403
@@ -86,9 +86,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
+liger_kernel_nightly-0.5.10.dev20250528223524.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.10.dev20250528223524.dist-info/METADATA,sha256=XqzBAk8PxwjhEYwf_3Xw0sssbGSM3IWW9z3NWlsZ7ZU,24113
+liger_kernel_nightly-0.5.10.dev20250528223524.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.10.dev20250528223524.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+liger_kernel_nightly-0.5.10.dev20250528223524.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.10.dev20250528223524.dist-info/RECORD,,