liger-kernel-nightly 0.4.0.dev20241112204954__tar.gz → 0.4.1.dev20241112233746__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic. Click here for more details.

Files changed (56)
  1. {liger_kernel_nightly-0.4.0.dev20241112204954/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.4.1.dev20241112233746}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/pyproject.toml +1 -1
  3. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/ops/rms_norm.py +27 -6
  4. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/monkey_patch.py +1 -1
  5. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/rms_norm.py +11 -3
  6. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746/src/liger_kernel_nightly.egg-info}/PKG-INFO +1 -1
  7. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/LICENSE +0 -0
  8. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/NOTICE +0 -0
  9. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/README.md +0 -0
  10. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/setup.cfg +0 -0
  11. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/env_report.py +0 -0
  12. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/ops/__init__.py +0 -0
  13. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/ops/cross_entropy.py +0 -0
  14. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  15. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  16. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  17. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  18. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/ops/geglu.py +0 -0
  19. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/ops/group_norm.py +0 -0
  20. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/ops/jsd.py +0 -0
  21. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/ops/kl_div.py +0 -0
  22. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/ops/layer_norm.py +0 -0
  23. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/ops/rope.py +0 -0
  24. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/ops/swiglu.py +0 -0
  25. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/ops/utils.py +0 -0
  26. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/__init__.py +0 -0
  27. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/auto_model.py +0 -0
  28. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  29. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  30. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/functional.py +0 -0
  31. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  32. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  33. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/geglu.py +0 -0
  34. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/group_norm.py +0 -0
  35. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/jsd.py +0 -0
  36. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/kl_div.py +0 -0
  37. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/layer_norm.py +0 -0
  38. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/model/__init__.py +0 -0
  39. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/model/gemma.py +0 -0
  40. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  41. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/model/llama.py +0 -0
  42. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/model/mistral.py +0 -0
  43. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  44. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/model/mllama.py +0 -0
  45. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/model/phi3.py +0 -0
  46. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  47. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  48. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/rope.py +0 -0
  49. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/swiglu.py +0 -0
  50. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  51. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/triton/__init__.py +0 -0
  52. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel/triton/monkey_patch.py +0 -0
  53. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
  54. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  55. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  56. {liger_kernel_nightly-0.4.0.dev20241112204954 → liger_kernel_nightly-0.4.1.dev20241112233746}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.0.dev20241112204954
3
+ Version: 0.4.1.dev20241112233746
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.4.0.dev20241112204954"
7
+ version = "0.4.1.dev20241112233746"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -116,6 +116,8 @@ def _rms_norm_forward_kernel(
116
116
  def _rms_norm_backward_kernel(
117
117
  dY_ptr,
118
118
  dY_row_stride,
119
+ dX_ptr,
120
+ dX_row_stride,
119
121
  X_ptr,
120
122
  X_row_stride,
121
123
  X_dtype: tl.constexpr,
@@ -146,6 +148,8 @@ def _rms_norm_backward_kernel(
146
148
  dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
147
149
 
148
150
  dY_ptr += row_start * dY_row_stride
151
+ dX_ptr += row_start * dX_row_stride
152
+
149
153
  X_ptr += row_start * X_row_stride
150
154
  RSTD_ptr += row_start
151
155
 
@@ -184,9 +188,10 @@ def _rms_norm_backward_kernel(
184
188
  # here X_row is already in fp32 (see previous if block)
185
189
  dW_row += dY_row * (X_row * rstd_row)
186
190
 
187
- tl.store(dY_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
191
+ tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
188
192
 
189
193
  dY_ptr += dY_row_stride
194
+ dX_ptr += dX_row_stride
190
195
  X_ptr += X_row_stride
191
196
  RSTD_ptr += RSTD_row_stride
192
197
 
@@ -251,7 +256,9 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
251
256
  return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
252
257
 
253
258
 
254
- def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps):
259
+ def rms_norm_backward(
260
+ dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place
261
+ ):
255
262
  shape = dY.shape
256
263
  dim = shape[-1]
257
264
  dY = dY.view(-1, dim)
@@ -265,10 +272,17 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
265
272
  raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
266
273
  rows_per_program = math.ceil(n_rows / sm_count)
267
274
  grid = (sm_count,)
268
- # Here we use dY to store the value of dX to save memory
275
+
276
+ if in_place is True:
277
+ dX = dY
278
+ else:
279
+ dX = torch.zeros_like(dY)
280
+
269
281
  _rms_norm_backward_kernel[grid](
270
282
  dY,
271
283
  dY.stride(0),
284
+ dX,
285
+ dX.stride(0),
272
286
  X,
273
287
  X.stride(0),
274
288
  torch_to_triton_dtype[X.dtype],
@@ -286,8 +300,9 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
286
300
  BLOCK_SIZE=BLOCK_SIZE,
287
301
  num_warps=num_warps,
288
302
  )
289
- dX = dY.view(*shape)
303
+ dX = dX.view(*shape)
290
304
  dW = _dW.sum(dim=0).to(W.dtype)
305
+
291
306
  return dX, dW
292
307
 
293
308
 
@@ -307,11 +322,15 @@ class LigerRMSNormFunction(torch.autograd.Function):
307
322
  - 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
308
323
  - 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
309
324
  - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.
325
+
326
+ `in_place` option means whether to in-place modify dY to store dX. This defaults to `True` to save memory. However, under certain cases, it can produce incorrect inputs.
327
+ For example, gemma2 uses two rmsnorm sequentially with residual in between. The residual part needs dY so it cannot be modified in-place.
328
+ Therefore, for the patching of RMSNorm in gemma2, we set `in_place` to `False`
310
329
  """
311
330
 
312
331
  @staticmethod
313
332
  @ensure_contiguous
314
- def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama"):
333
+ def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True):
315
334
  """
316
335
  X: (B, T, H) or (BxT, H)
317
336
  W: (H,)
@@ -321,6 +340,7 @@ class LigerRMSNormFunction(torch.autograd.Function):
321
340
  )
322
341
  ctx.offset = offset
323
342
  ctx.casting_mode = casting_mode
343
+ ctx.in_place = in_place
324
344
  ctx.BLOCK_SIZE = BLOCK_SIZE
325
345
  ctx.num_warps = num_warps
326
346
  ctx.save_for_backward(X, W, RSTD)
@@ -342,5 +362,6 @@ class LigerRMSNormFunction(torch.autograd.Function):
342
362
  ctx.casting_mode,
343
363
  ctx.BLOCK_SIZE,
344
364
  ctx.num_warps,
365
+ ctx.in_place,
345
366
  )
346
- return dX, dW, None, None, None
367
+ return dX, dW, None, None, None, None
@@ -507,7 +507,7 @@ def apply_liger_kernel_to_gemma2(
507
507
  from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
508
508
 
509
509
  LigerRMSNormForGemma2 = partial(
510
- LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros"
510
+ LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False
511
511
  )
512
512
  _patch_rms_norm_module_for_gemma2 = partial(
513
513
  _patch_rms_norm_module, offset=1.0, casting_mode="gemma"
@@ -6,7 +6,13 @@ from liger_kernel.ops.rms_norm import LigerRMSNormFunction
6
6
 
7
7
  class LigerRMSNorm(nn.Module):
8
8
  def __init__(
9
- self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones"
9
+ self,
10
+ hidden_size,
11
+ eps=1e-6,
12
+ offset=0.0,
13
+ casting_mode="llama",
14
+ init_fn="ones",
15
+ in_place=True,
10
16
  ):
11
17
  super().__init__()
12
18
  assert init_fn in [
@@ -16,10 +22,11 @@ class LigerRMSNorm(nn.Module):
16
22
  self.weight = nn.Parameter(
17
23
  torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
18
24
  )
19
- self.variance_epsilon, self.offset, self.casting_mode = (
25
+ self.variance_epsilon, self.offset, self.casting_mode, self.in_place = (
20
26
  eps,
21
27
  offset,
22
28
  casting_mode,
29
+ in_place,
23
30
  )
24
31
 
25
32
  def forward(self, hidden_states):
@@ -29,7 +36,8 @@ class LigerRMSNorm(nn.Module):
29
36
  self.variance_epsilon,
30
37
  self.offset,
31
38
  self.casting_mode,
39
+ self.in_place,
32
40
  )
33
41
 
34
42
  def extra_repr(self):
35
- return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}"
43
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.0.dev20241112204954
3
+ Version: 0.4.1.dev20241112233746
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation