liger-kernel-nightly 0.4.2.dev20241121054604__tar.gz → 0.4.2.dev20241121225747__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. {liger_kernel_nightly-0.4.2.dev20241121054604/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.4.2.dev20241121225747}/PKG-INFO +3 -3
  2. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/README.md +2 -2
  3. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/pyproject.toml +1 -1
  4. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/ops/fused_linear_jsd.py +1 -1
  5. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/ops/jsd.py +19 -10
  6. liger_kernel_nightly-0.4.2.dev20241121225747/src/liger_kernel/transformers/functional.py +173 -0
  7. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/fused_linear_jsd.py +1 -4
  8. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/jsd.py +1 -4
  9. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747/src/liger_kernel_nightly.egg-info}/PKG-INFO +3 -3
  10. liger_kernel_nightly-0.4.2.dev20241121054604/src/liger_kernel/transformers/functional.py +0 -58
  11. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/LICENSE +0 -0
  12. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/NOTICE +0 -0
  13. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/setup.cfg +0 -0
  14. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  15. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  16. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  17. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/chunked_loss/functional.py +0 -0
  18. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  19. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  20. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  21. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/env_report.py +0 -0
  22. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/ops/__init__.py +0 -0
  23. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/ops/cross_entropy.py +0 -0
  24. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  25. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  26. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  27. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/ops/geglu.py +0 -0
  28. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/ops/group_norm.py +0 -0
  29. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/ops/kl_div.py +0 -0
  30. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/ops/layer_norm.py +0 -0
  31. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  32. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/ops/rms_norm.py +0 -0
  33. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/ops/rope.py +0 -0
  34. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/ops/swiglu.py +0 -0
  35. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/ops/utils.py +0 -0
  36. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/__init__.py +0 -0
  37. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/auto_model.py +0 -0
  38. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  39. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  40. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  41. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/geglu.py +0 -0
  42. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/group_norm.py +0 -0
  43. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/kl_div.py +0 -0
  44. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/layer_norm.py +0 -0
  45. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/model/__init__.py +0 -0
  46. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/model/gemma.py +0 -0
  47. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  48. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/model/llama.py +0 -0
  49. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/model/mistral.py +0 -0
  50. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  51. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/model/mllama.py +0 -0
  52. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/model/phi3.py +0 -0
  53. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  54. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  55. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  56. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  57. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/rms_norm.py +0 -0
  58. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/rope.py +0 -0
  59. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/swiglu.py +0 -0
  60. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  61. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/triton/__init__.py +0 -0
  62. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel/triton/monkey_patch.py +0 -0
  63. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
  64. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  65. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  66. {liger_kernel_nightly-0.4.2.dev20241121054604 → liger_kernel_nightly-0.4.2.dev20241121225747}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.2.dev20241121054604
3
+ Version: 0.4.2.dev20241121225747
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -303,8 +303,8 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
303
303
  <!-- TODO: verify vocab sizes are accurate -->
304
304
  - **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
305
305
  - **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size.
306
- - **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size.
307
- - **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192.
306
+ - **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.
307
+ - **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.
308
308
 
309
309
 
310
310
  ### Experimental Kernels
@@ -256,8 +256,8 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
256
256
  <!-- TODO: verify vocab sizes are accurate -->
257
257
  - **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
258
258
  - **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size.
259
- - **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size.
260
- - **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192.
259
+ - **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.
260
+ - **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.
261
261
 
262
262
 
263
263
  ### Experimental Kernels
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.4.2.dev20241121054604"
7
+ version = "0.4.2.dev20241121225747"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -202,7 +202,7 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
202
202
  teacher_input (torch.tensor): input of the last projection layer in teacher model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension.
203
203
  teacher_weight (torch.tensor): the last projection layer in teacher model, with shape (V, H), where V is vocab size
204
204
  shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
205
- jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
205
+ jsd_beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
206
206
  ignore_index (int): the index to ignore. Default: -100
207
207
  temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0`
208
208
 
@@ -18,7 +18,7 @@ def _jsd_kernel(
18
18
  dX_ptr,
19
19
  dX_stride,
20
20
  label_ptr,
21
- beta,
21
+ beta: tl.constexpr,
22
22
  n_non_ignore: int,
23
23
  ignore_index: tl.constexpr,
24
24
  n_cols,
@@ -50,17 +50,26 @@ def _jsd_kernel(
50
50
  X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
51
51
  Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
52
52
 
53
- Q = tl.exp(X)
54
- P = tl.exp(Y)
55
- M = beta * P + (1 - beta) * Q
56
- log_M = tl.log(M)
53
+ if beta == 0.0: # forward KL
54
+ Y_prob = tl.exp(Y)
55
+ loss = Y_prob * (Y - X)
56
+ dX = -Y_prob
57
+ elif beta == 1.0:
58
+ X_prob = tl.exp(X)
59
+ loss = X_prob * (X - Y)
60
+ dX = loss + X_prob
61
+ else:
62
+ Q = tl.exp(X)
63
+ P = tl.exp(Y)
64
+ M = beta * P + (1 - beta) * Q
65
+ log_M = tl.log(M)
66
+
67
+ loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
68
+ dX = (1 - beta) * Q * (X - log_M)
57
69
 
58
- loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
59
- # reduction == "batchmean"
60
70
  loss = loss / n_non_ignore
71
+ dX = dX / n_non_ignore
61
72
  tl.store(loss_ptr + offsets, loss, mask=mask)
62
-
63
- dX = (1 - beta) * Q * (X - log_M) / n_non_ignore
64
73
  tl.store(dX_ptr + offsets, dX, mask=mask)
65
74
 
66
75
 
@@ -142,7 +151,7 @@ class LigerJSDFunction(torch.autograd.Function):
142
151
  _input (torch.Tensor): predict values with shape (BT, V) in logspace
143
152
  target (torch.Tensor): ground truth values with shape (BT, V) in logspace
144
153
  shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
145
- beta (float): coefficient beta of generalized JSD in the open interval (0, 1)
154
+ beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
146
155
  ignore_index (int): the index to ignore. Default: -100
147
156
 
148
157
  Returns:
@@ -0,0 +1,173 @@
1
+ from typing import Optional
2
+
3
+ from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
4
+ from liger_kernel.ops.fused_linear_cross_entropy import (
5
+ LigerFusedLinearCrossEntropyFunction,
6
+ )
7
+ from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
8
+ from liger_kernel.ops.geglu import LigerGELUMulFunction
9
+ from liger_kernel.ops.group_norm import LigerGroupNormFunction
10
+ from liger_kernel.ops.jsd import LigerJSDFunction
11
+ from liger_kernel.ops.kl_div import LigerKLDivLossFunction
12
+ from liger_kernel.ops.layer_norm import LigerLayerNormFunction
13
+ from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
14
+ from liger_kernel.ops.rms_norm import LigerRMSNormFunction
15
+ from liger_kernel.ops.rope import LigerRopeFunction
16
+ from liger_kernel.ops.swiglu import LigerSiLUMulFunction
17
+
18
+
19
+ # conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
20
+ # `weight` and `size_average` are placeholders and not implemented yet
21
+ def liger_cross_entropy(
22
+ input,
23
+ target,
24
+ weight=None,
25
+ size_average=None,
26
+ ignore_index: int = -100,
27
+ reduce=None,
28
+ reduction: str = "mean",
29
+ label_smoothing: float = 0.0,
30
+ lse_square_scale: float = 0.0,
31
+ softcap: Optional[float] = None,
32
+ return_z_loss: bool = False,
33
+ ):
34
+ loss, z_loss = LigerCrossEntropyFunction.apply(
35
+ input,
36
+ target,
37
+ ignore_index,
38
+ lse_square_scale,
39
+ label_smoothing,
40
+ reduction,
41
+ softcap,
42
+ return_z_loss,
43
+ )
44
+ if not return_z_loss:
45
+ return loss
46
+ return loss, z_loss
47
+
48
+
49
+ def liger_fused_linear_cross_entropy(
50
+ input,
51
+ weight,
52
+ target,
53
+ bias=None,
54
+ ignore_index: int = -100,
55
+ lse_square_scale: float = 0.0,
56
+ label_smoothing: float = 0.0,
57
+ reduction: str = "mean",
58
+ softcap: Optional[float] = None,
59
+ ):
60
+ return LigerFusedLinearCrossEntropyFunction.apply(
61
+ input,
62
+ weight,
63
+ target,
64
+ bias,
65
+ ignore_index,
66
+ lse_square_scale,
67
+ label_smoothing,
68
+ reduction,
69
+ softcap,
70
+ )
71
+
72
+
73
+ def liger_fused_linear_jsd(
74
+ student_input,
75
+ student_weight,
76
+ teacher_input,
77
+ teacher_weight,
78
+ shift_labels=None,
79
+ jsd_beta: float = 0.5,
80
+ ignore_index: int = -100,
81
+ temperature: float = 1.0,
82
+ ):
83
+ return LigerFusedLinearJSDFunction.apply(
84
+ student_input,
85
+ student_weight,
86
+ teacher_input,
87
+ teacher_weight,
88
+ shift_labels,
89
+ jsd_beta,
90
+ ignore_index,
91
+ temperature,
92
+ )
93
+
94
+
95
+ def liger_geglu(a, b):
96
+ return LigerGELUMulFunction.apply(a, b)
97
+
98
+
99
+ def liger_group_norm(
100
+ X,
101
+ affine_scaling_weight,
102
+ affine_shifting_bias,
103
+ num_channels,
104
+ num_groups,
105
+ eps,
106
+ ):
107
+ return LigerGroupNormFunction.apply(
108
+ X,
109
+ affine_scaling_weight,
110
+ affine_shifting_bias,
111
+ num_channels,
112
+ num_groups,
113
+ eps,
114
+ )
115
+
116
+
117
+ def liger_jsd(
118
+ input,
119
+ target,
120
+ shift_labels=None,
121
+ beta: float = 0.5,
122
+ ignore_index: int = -100,
123
+ ):
124
+ return LigerJSDFunction.apply(
125
+ input,
126
+ target,
127
+ shift_labels,
128
+ beta,
129
+ ignore_index,
130
+ )
131
+
132
+
133
+ # conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.kl_div.html#torch.nn.functional.kl_div
134
+ # `size_average` and `mean` are being deprecated in torch API and are placeholders here
135
+ def liger_kl_div(
136
+ input,
137
+ target,
138
+ size_average: bool = True,
139
+ reduce: bool = True,
140
+ reduction: str = "mean",
141
+ log_target: bool = False,
142
+ eps: float = 1e-10,
143
+ ):
144
+ # Note: the default reduction in torch is `mean`, but being `batchmean` in Liger
145
+ return LigerKLDivLossFunction.apply(
146
+ input,
147
+ target,
148
+ reduction,
149
+ log_target,
150
+ eps,
151
+ )
152
+
153
+
154
+ def liger_layer_norm(X, W, B, eps):
155
+ return LigerLayerNormFunction.apply(X, W, B, eps)
156
+
157
+
158
+ def liger_qwen2vl_mrope(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
159
+ return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)
160
+
161
+
162
+ def liger_rms_norm(
163
+ X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True
164
+ ):
165
+ return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)
166
+
167
+
168
+ def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
169
+ return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)
170
+
171
+
172
+ def liger_swiglu(a, b):
173
+ return LigerSiLUMulFunction.apply(a, b)
@@ -12,7 +12,7 @@ class LigerFusedLinearJSD(torch.nn.Module):
12
12
  the materialization of the large logits tensor.
13
13
 
14
14
  Args:
15
- jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
15
+ jsd_beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
16
16
  ignore_index (int): The index to ignore in the target. Default: `-100`
17
17
  temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0`
18
18
 
@@ -70,9 +70,6 @@ class LigerFusedLinearJSD(torch.nn.Module):
70
70
 
71
71
  def __init__(self, jsd_beta=0.5, ignore_index=-100, temperature=1.0):
72
72
  super().__init__()
73
- assert (
74
- jsd_beta > 0 and jsd_beta < 1
75
- ), f"beta must be greater than 0 and less than 1. Got: {jsd_beta}"
76
73
  assert temperature != 0, "temperature cannot be 0."
77
74
  self.jsd_beta = jsd_beta
78
75
  self.temperature = temperature
@@ -18,7 +18,7 @@ class LigerJSD(torch.nn.Module):
18
18
  :math:`P` denotes the teacher model and :math:`Q` denotes the student model.
19
19
 
20
20
  Args:
21
- beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
21
+ beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
22
22
  ignore_index (int): The index to ignore in the target. Default: `-100`
23
23
 
24
24
  Shape:
@@ -58,9 +58,6 @@ class LigerJSD(torch.nn.Module):
58
58
 
59
59
  def __init__(self, beta: float = 0.5, ignore_index: int = -100):
60
60
  super().__init__()
61
- assert (
62
- beta > 0 and beta < 1
63
- ), f"beta must be greater than 0 and less than 1. Got: {beta}"
64
61
  self.beta = beta
65
62
  self.ignore_index = ignore_index
66
63
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.2.dev20241121054604
3
+ Version: 0.4.2.dev20241121225747
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -303,8 +303,8 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
303
303
  <!-- TODO: verify vocab sizes are accurate -->
304
304
  - **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
305
305
  - **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size.
306
- - **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size.
307
- - **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192.
306
+ - **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.
307
+ - **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.
308
308
 
309
309
 
310
310
  ### Experimental Kernels
@@ -1,58 +0,0 @@
1
- from typing import Optional
2
-
3
- from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
4
- from liger_kernel.ops.fused_linear_cross_entropy import (
5
- LigerFusedLinearCrossEntropyFunction,
6
- )
7
- from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
8
- from liger_kernel.ops.geglu import LigerGELUMulFunction
9
- from liger_kernel.ops.group_norm import LigerGroupNormFunction
10
- from liger_kernel.ops.jsd import LigerJSDFunction
11
- from liger_kernel.ops.kl_div import LigerKLDivLossFunction
12
- from liger_kernel.ops.layer_norm import LigerLayerNormFunction
13
- from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
14
- from liger_kernel.ops.rms_norm import LigerRMSNormFunction
15
- from liger_kernel.ops.rope import LigerRopeFunction
16
- from liger_kernel.ops.swiglu import LigerSiLUMulFunction
17
-
18
- liger_swiglu = LigerSiLUMulFunction.apply
19
- liger_fused_linear_cross_entropy = LigerFusedLinearCrossEntropyFunction.apply
20
- liger_geglu = LigerGELUMulFunction.apply
21
- liger_rms_norm = LigerRMSNormFunction.apply
22
- liger_rope = LigerRopeFunction.apply
23
- liger_qwen2vl_mrope = LigerQwen2VLMRopeFunction.apply
24
- liger_layer_norm = LigerLayerNormFunction.apply
25
- liger_kl_div = LigerKLDivLossFunction.apply
26
- liger_jsd = LigerJSDFunction.apply
27
- liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
28
- liger_group_norm = LigerGroupNormFunction.apply
29
-
30
-
31
- # conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
32
- # `weight` and `size_average` are placeholders and not implemented yet
33
- def liger_cross_entropy(
34
- input,
35
- target,
36
- weight=None,
37
- size_average=None,
38
- ignore_index: int = -100,
39
- reduce=None,
40
- reduction: str = "mean",
41
- label_smoothing: float = 0.0,
42
- lse_square_scale: float = 0.0,
43
- softcap: Optional[float] = None,
44
- return_z_loss: bool = False,
45
- ):
46
- loss, z_loss = LigerCrossEntropyFunction.apply(
47
- input,
48
- target,
49
- ignore_index,
50
- lse_square_scale,
51
- label_smoothing,
52
- reduction,
53
- softcap,
54
- return_z_loss,
55
- )
56
- if not return_z_loss:
57
- return loss
58
- return loss, z_loss