liger-kernel 0.3.0__tar.gz → 0.3.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. {liger_kernel-0.3.0/src/liger_kernel.egg-info → liger_kernel-0.3.1}/PKG-INFO +15 -8
  2. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/README.md +11 -5
  3. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/pyproject.toml +6 -3
  4. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/ops/fused_linear_cross_entropy.py +1 -1
  5. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/ops/geglu.py +2 -2
  6. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/ops/kl_div.py +43 -32
  7. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/ops/swiglu.py +2 -2
  8. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/transformers/auto_model.py +18 -6
  9. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/transformers/kl_div.py +3 -2
  10. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/transformers/monkey_patch.py +96 -122
  11. {liger_kernel-0.3.0 → liger_kernel-0.3.1/src/liger_kernel.egg-info}/PKG-INFO +15 -8
  12. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel.egg-info/requires.txt +4 -2
  13. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/LICENSE +0 -0
  14. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/NOTICE +0 -0
  15. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/setup.cfg +0 -0
  16. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/env_report.py +0 -0
  17. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/ops/__init__.py +0 -0
  18. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/ops/cross_entropy.py +0 -0
  19. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  20. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/ops/layer_norm.py +0 -0
  21. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/ops/rms_norm.py +0 -0
  22. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/ops/rope.py +0 -0
  23. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/ops/utils.py +0 -0
  24. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/transformers/__init__.py +0 -0
  25. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  26. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  27. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/transformers/functional.py +0 -0
  28. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  29. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/transformers/geglu.py +0 -0
  30. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/transformers/layer_norm.py +0 -0
  31. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/transformers/model/__init__.py +0 -0
  32. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/transformers/model/gemma.py +0 -0
  33. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/transformers/model/llama.py +0 -0
  34. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/transformers/model/mistral.py +0 -0
  35. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  36. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/transformers/model/phi3.py +0 -0
  37. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  38. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  39. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/transformers/rms_norm.py +0 -0
  40. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/transformers/rope.py +0 -0
  41. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/transformers/swiglu.py +0 -0
  42. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  43. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/triton/__init__.py +0 -0
  44. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/triton/monkey_patch.py +0 -0
  45. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel.egg-info/SOURCES.txt +0 -0
  46. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel.egg-info/dependency_links.txt +0 -0
  47. {liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel.egg-info/top_level.txt +0 -0
{liger_kernel-0.3.0/src/liger_kernel.egg-info → liger_kernel-0.3.1}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: liger_kernel
- Version: 0.3.0
+ Version: 0.3.1
  Summary: Efficient Triton kernels for LLM Training
  License: BSD 2-CLAUSE LICENSE
  Copyright 2024 LinkedIn Corporation
@@ -32,15 +32,16 @@ License-File: LICENSE
  License-File: NOTICE
  Requires-Dist: torch>=2.1.2
  Requires-Dist: triton>=2.3.0
- Requires-Dist: transformers>=4.42.0
+ Provides-Extra: transformers
+ Requires-Dist: transformers~=4.0; extra == "transformers"
  Provides-Extra: dev
+ Requires-Dist: transformers>=4.44.2; extra == "dev"
  Requires-Dist: matplotlib>=3.7.2; extra == "dev"
  Requires-Dist: flake8>=4.0.1.1; extra == "dev"
  Requires-Dist: black>=24.4.2; extra == "dev"
  Requires-Dist: isort>=5.13.2; extra == "dev"
  Requires-Dist: pytest>=7.1.2; extra == "dev"
  Requires-Dist: datasets>=2.19.2; extra == "dev"
- Requires-Dist: jupyter==1.0.0; extra == "dev"
  Requires-Dist: seaborn; extra == "dev"

  # Liger Kernel: Efficient Triton Kernels for LLM Training
@@ -74,8 +75,8 @@ Requires-Dist: seaborn; extra == "dev"
  </a>
  </td>
  <td style="padding: 10px;">
- <a href="https://discord.gg/CX2YmNmn">
- <img src="https://dcbadge.vercel.app/api/server/cudamode?style=flat" alt="Join Our Discord">
+ <a href="https://discord.gg/gpumode">
+ <img src="https://dcbadge.vercel.app/api/server/gpumode?style=flat" alt="Join Our Discord">
  </a>
  </td>
  </tr>
@@ -151,7 +152,10 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and

  - `torch >= 2.1.2`
  - `triton >= 2.3.0`
- - `transformers >= 4.42.0`
+
+ ### Optional Dependencies
+
+ - `transformers >= 4.x`: Required if you plan to use the transformers models patching APIs. The specific model you are working will dictate the minimum version of transformers.

  > **Note:**
  > Our kernels inherit the full spectrum of hardware compatibility offered by [Triton](https://github.com/triton-lang/triton).
@@ -174,7 +178,10 @@ To install from source:
  git clone https://github.com/linkedin/Liger-Kernel.git
  cd Liger-Kernel
  pip install -e .
+ # or if using transformers
+ pip install -e .[transformers]
  ```
+
  ## Getting Started

  There are a couple of ways to apply Liger kernels, depending on the level of customization required.
@@ -271,9 +278,9 @@ loss.backward()
  | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss |
- | Qwen2 | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | Qwen2 & Qwen2.5 | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Qwen2-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
- | Phi3 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |

{liger_kernel-0.3.0 → liger_kernel-0.3.1}/README.md

@@ -29,8 +29,8 @@
  </a>
  </td>
  <td style="padding: 10px;">
- <a href="https://discord.gg/CX2YmNmn">
- <img src="https://dcbadge.vercel.app/api/server/cudamode?style=flat" alt="Join Our Discord">
+ <a href="https://discord.gg/gpumode">
+ <img src="https://dcbadge.vercel.app/api/server/gpumode?style=flat" alt="Join Our Discord">
  </a>
  </td>
  </tr>
@@ -106,7 +106,10 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and

  - `torch >= 2.1.2`
  - `triton >= 2.3.0`
- - `transformers >= 4.42.0`
+
+ ### Optional Dependencies
+
+ - `transformers >= 4.x`: Required if you plan to use the transformers models patching APIs. The specific model you are working will dictate the minimum version of transformers.

  > **Note:**
  > Our kernels inherit the full spectrum of hardware compatibility offered by [Triton](https://github.com/triton-lang/triton).
@@ -129,7 +132,10 @@ To install from source:
  git clone https://github.com/linkedin/Liger-Kernel.git
  cd Liger-Kernel
  pip install -e .
+ # or if using transformers
+ pip install -e .[transformers]
  ```
+
  ## Getting Started

  There are a couple of ways to apply Liger kernels, depending on the level of customization required.
@@ -226,9 +232,9 @@ loss.backward()
  | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss |
- | Qwen2 | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | Qwen2 & Qwen2.5 | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Qwen2-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
- | Phi3 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |

{liger_kernel-0.3.0 → liger_kernel-0.3.1}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "liger_kernel"
- version = "0.3.0"
+ version = "0.3.1"
  description = "Efficient Triton kernels for LLM Training"
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -12,18 +12,21 @@ license = { file = "LICENSE" }
  dependencies = [
  "torch>=2.1.2",
  "triton>=2.3.0",
- "transformers>=4.42.0"
  ]

  [project.optional-dependencies]
+ transformers = [
+ "transformers~=4.0"
+ ]
+
  dev = [
+ "transformers>=4.44.2",
  "matplotlib>=3.7.2",
  "flake8>=4.0.1.1",
  "black>=24.4.2",
  "isort>=5.13.2",
  "pytest>=7.1.2",
  "datasets>=2.19.2",
- "jupyter==1.0.0",
  "seaborn",
  ]

{liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/ops/fused_linear_cross_entropy.py

@@ -97,7 +97,7 @@ def fused_linear_cross_entropy_forward(

  # gradient of logits_chunk is computed in-place by the above triton kernel.
  # Following HuggingFace model source code, we do the forward and backward
- # w.r.t. logits in fp32 for numerical stability especially as the num classes (vocab size) os huge.
+ # w.r.t. logits in fp32 for numerical stability especially as the num classes (vocab size) is huge.
  # (reference: https://github.com/huggingface/transformers/blob/v4.42.4/src/transformers/models/llama/modeling_llama.py#L1194)
  # Propagating to lm_head's backward, we'll switch back to the original dtype.
  logits_chunk = logits_chunk.to(dtype)
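
The corrected comment describes the same pattern HuggingFace's Llama modeling code uses: do the loss math on the logits in fp32, then switch back to the original dtype for the rest of the backward. A rough, self-contained PyTorch sketch of that idea (illustrative only, not the fused Triton path; the sizes are made up):

```python
import torch
import torch.nn.functional as F

hidden = torch.randn(8, 4096, dtype=torch.bfloat16, requires_grad=True)
lm_head = torch.nn.Linear(4096, 128256, bias=False, dtype=torch.bfloat16)
labels = torch.randint(0, 128256, (8,))

logits = lm_head(hidden)                         # bf16 logits over a huge vocab
loss = F.cross_entropy(logits.float(), labels)   # loss computed in fp32 for stability
loss.backward()                                  # gradients flow back to the bf16 weights
```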
{liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/ops/geglu.py

@@ -25,7 +25,7 @@ else:
  def _geglu_tanh_forward_kernel(
  a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
  ):
- program_id = tl.program_id(0)
+ program_id = tl.program_id(0).cast(tl.int64)

  # locate start index
  a += program_id * stride
@@ -52,7 +52,7 @@ def _geglu_tanh_forward_kernel(
  def _geglu_tanh_backward_kernel(
  dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
  ):
- program_id = tl.program_id(0)
+ program_id = tl.program_id(0).cast(tl.int64)

  # locate start index
  dc += program_id * stride
{liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/ops/kl_div.py

@@ -45,6 +45,7 @@ def _kldiv_kernel_forward(
  loss_ptr, # [B] or [B, S] if reduction == _REDUCTION_MODE_NONE, output ptr
  loss_stride, # int, output stride
  n_cols, # int, number of columns in the input tensor
+ eps,
  BLOCK_SIZE: tl.constexpr,
  log_target: tl.constexpr = False,
  reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
@@ -56,6 +57,7 @@ def _kldiv_kernel_forward(

  base_offsets = tl.arange(0, BLOCK_SIZE)

+ loss_sum = 0.0
  for i in range(0, n_cols, BLOCK_SIZE):
  offsets = i + base_offsets
  mask = offsets < n_cols
@@ -65,32 +67,33 @@ def _kldiv_kernel_forward(
  # KL(y_true || y) = y_true * (log(y_true) - log(y))
  # We compute KL(y_true || y) with y in the log-space
  if not log_target:
- loss = y_true * (tl.log(y_true) - y)
+ loss = y_true * (tl.log(tl.maximum(y_true, eps)) - y)
  else:
  loss = tl.exp(y_true) * (y_true - y)

  if reduction == _REDUCTION_MODE_NONE:
  tl.store(loss_ptr + offsets, loss, mask=mask)
  else:
- loss = tl.sum(loss, axis=0)
- tl.store(loss_ptr, loss)
- loss_ptr += 1 # in case of reduction, the output tensor has dimensions [B,], therefore stride is always 1
+ loss_sum += tl.sum(loss, axis=0)
+
+ if reduction != _REDUCTION_MODE_NONE:
+ tl.store(loss_ptr, loss_sum)


  @triton.jit
  def _kldiv_kernel_backward(
- input_ptr,
- input_stride,
  target_ptr,
  target_stride,
+ new_grads_ptr,
+ new_grads_stride,
  n_cols,
  BLOCK_SIZE: tl.constexpr,
  log_target: tl.constexpr = False,
  ):
  pid = tl.program_id(0).to(tl.int64)

- input_ptr += pid * input_stride
  target_ptr += pid * target_stride
+ new_grads_ptr += pid * new_grads_stride

  offsets = tl.arange(0, BLOCK_SIZE)
  mask = offsets < n_cols
@@ -106,19 +109,19 @@ def _kldiv_kernel_backward(
  else:
  res = -tl.exp(target)

- tl.store(input_ptr + offsets, res, mask=mask)
+ tl.store(new_grads_ptr + offsets, res, mask=mask)


- def kldiv_forward_triton(y_pred, y_true, log_target, reduction): # [B, S] # [B, S]
- B, S = y_pred.shape
+ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
+ BT, V = y_pred.shape

- BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(S))
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
  num_warps = get_num_warps(BLOCK_SIZE)

- grid = (B,)
+ grid = (BT,)
  reduction = _str_to_reduction_mode[reduction]

- out_size = (B, S) if reduction == _REDUCTION_MODE_NONE.value else (B,)
+ out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
  output_tensor = torch.zeros(out_size, device=y_pred.device, dtype=torch.float32)

  _kldiv_kernel_forward[grid](
@@ -128,7 +131,8 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction): # [B, S] # [B
  y_true.stride(0),
  output_tensor,
  output_tensor.stride(0),
- S,
+ V,
+ eps=eps,
  BLOCK_SIZE=BLOCK_SIZE,
  num_warps=num_warps,
  log_target=log_target,
@@ -139,30 +143,30 @@
  # https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
  # https://github.com/pytorch/pytorch/blob/d7b57c4d63edb42e1deeeba9497fcb5f1f748ff2/torch/nn/functional.py#L3372
  if reduction == _REDUCTION_MODE_BATCHMEAN.value:
- return output_tensor.sum() / B
+ return output_tensor.sum() / BT
  elif reduction == _REDUCTION_MODE_SUM.value:
  return output_tensor.sum(dim=0)
  elif reduction == _REDUCTION_MODE_MEAN.value:
- return output_tensor.mean(dim=0)
+ return output_tensor.sum() / (BT * V)
  else:
  return output_tensor


- def kldiv_backward_triton(input, target, grad_output, log_target):
- B, S = input.shape
+ def kldiv_backward_triton(target, grad_output, new_grads, log_target):
+ BT, V = target.shape

- BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(S))
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
  num_warps = get_num_warps(BLOCK_SIZE)

- grid = (B,)
+ grid = (BT,)

  # We store the gradients in-place in the input tensor
  _kldiv_kernel_backward[grid](
- input,
- input.stride(0),
  target,
  target.stride(0),
- S,
+ new_grads,
+ new_grads.stride(0),
+ V,
  BLOCK_SIZE=BLOCK_SIZE,
  num_warps=num_warps,
  log_target=log_target,
@@ -170,9 +174,9 @@ def kldiv_backward_triton(input, target, grad_output, log_target):

  # If cross entropy is the last layer, grad_output is 1.0. Skip the mul then.
  if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
- return input
+ return new_grads

- return input * grad_output
+ return new_grads * grad_output


  class LigerKLDivLossFunction(torch.autograd.Function):
@@ -196,6 +200,7 @@ class LigerKLDivLossFunction(torch.autograd.Function):
  y_true: torch.Tensor,
  reduction: REDUCTION_LITERAL = "batchmean",
  log_target: bool = False,
+ eps: float = 1e-10,
  ) -> torch.Tensor:
  """A forward pass for the KL Divergence Loss.

@@ -205,15 +210,16 @@
  y_true (torch.Tensor): A tensor of shape (BT, V) containing the target values, expected to be either probabilities or log-probabilities, depending on the value of `log_target`.
  reduction (REDUCTION_LITERAL, optional): Reduction to be used. Defaults to "batchmean".
  log_target (bool, optional): If set to true, expects the ground truth to already be log-probabilities. Defaults to False.
+ eps: (float, optional): A small value to avoid division by zero. Defaults to 1e-10.

  Returns:
  torch.Tensor: The computed KL Divergence Loss, with shape (BT, V) if `reduction` is "none", else a scalar.
  """
- ctx.save_for_backward(y_pred, y_true)
+ ctx.save_for_backward(y_true)
  ctx.reduction = reduction
  ctx.log_target = log_target
  return kldiv_forward_triton(
- y_pred, y_true, log_target=log_target, reduction=reduction
+ y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps
  )

  @staticmethod
@@ -226,22 +232,27 @@ class LigerKLDivLossFunction(torch.autograd.Function):
  grad_output (torch.Tensor): The gradient of the loss with respect to the output.

  Returns:
- tuple[torch.Tensor, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method.
+ tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method.
  """
- y_pred, y_true = ctx.saved_tensors
+ (y_true,) = ctx.saved_tensors
+
+ new_grads = torch.empty_like(y_true)

- derivative = kldiv_backward_triton(y_pred, y_true, grad_output, ctx.log_target)
+ derivative = kldiv_backward_triton(
+ y_true, grad_output, new_grads, ctx.log_target
+ )

  if ctx.reduction == "batchmean":
- derivative = derivative / y_pred.shape[0]
+ derivative = derivative / y_true.shape[0]
  elif ctx.reduction == "sum" or ctx.reduction == "none":
  pass
  elif ctx.reduction == "mean":
- derivative = derivative / (y_pred.shape[0] * y_pred.shape[1])
+ derivative = derivative / (y_true.shape[0] * y_true.shape[1])

  return (
  derivative,
  None,
  None,
  None,
+ None,
  )
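
The forward kernel now clamps the target with `eps` before taking its log, so `0 * log(0)` terms from exact zeros in the target no longer poison the sum, and the per-row total is accumulated in `loss_sum` instead of repeatedly overwriting `loss_ptr`. A plain-PyTorch reference of the non-log-target, batchmean path the kernel computes (a sketch for checking the math, not the Triton code):

```python
import torch

eps = 1e-10
y_pred = torch.log_softmax(torch.randn(4, 32), dim=-1)  # log-probabilities, shape (BT, V)
y_true = torch.softmax(torch.randn(4, 32), dim=-1)      # probabilities, shape (BT, V)

# KL(y_true || y_pred) with y_pred already in log-space, as in the kernel comment:
#   loss = y_true * (log(max(y_true, eps)) - y_pred)
loss = y_true * (torch.log(torch.clamp(y_true, min=eps)) - y_pred)

batchmean = loss.sum() / y_true.shape[0]  # "batchmean" divides the total by BT
```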
{liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/ops/swiglu.py

@@ -14,7 +14,7 @@ def silu(x):
  def _swiglu_forward_kernel(
  a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
  ):
- program_id = tl.program_id(0)
+ program_id = tl.program_id(0).cast(tl.int64)

  # locate start index
  a_ptr += program_id * stride
@@ -35,7 +35,7 @@ def _swiglu_forward_kernel(
  def _swiglu_backward_kernel(
  dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
  ):
- program_id = tl.program_id(0)
+ program_id = tl.program_id(0).cast(tl.int64)

  # locate start index
  dc_ptr += program_id * stride
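
Both the geglu and swiglu kernels now cast `tl.program_id(0)` to int64 before multiplying by the row stride. Without the cast, the offset arithmetic can wrap around 32-bit range once `program_id * stride` passes 2**31 - 1, which is reachable with long sequences and wide intermediate sizes. A quick back-of-the-envelope check in plain Python (the sizes are hypothetical):

```python
INT32_MAX = 2**31 - 1  # 2,147,483,647

# Hypothetical workload: two 128k-token sequences (262,144 rows total) through an
# MLP whose per-row stride is 14,336 elements.
n_rows, stride = 262_144, 14_336
largest_offset = (n_rows - 1) * stride

print(f"{largest_offset:,}", largest_offset > INT32_MAX)  # 3,758,082,048 True
```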
{liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/transformers/auto_model.py

@@ -1,6 +1,11 @@
+ import inspect
+
  from transformers import AutoConfig, AutoModelForCausalLM

- from liger_kernel.transformers.monkey_patch import _apply_liger_kernel
+ from liger_kernel.transformers.monkey_patch import (
+ MODEL_TYPE_TO_APPLY_LIGER_FN,
+ _apply_liger_kernel,
+ )


  def _get_model_config(model_dir, **model_init_kwargs):
@@ -21,13 +26,20 @@ class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
  # Determine the model type and apply the Liger Kernel if applicable
  # Note: _apply_liger_kernel will only pass relevant kwargs to the apply_liger_kernel_to_* function
  model_type = model_config.model_type
+
  _apply_liger_kernel(model_type, **kwargs)

- # Retain only the keyword args present in the model configuration
- for k in list(kwargs.keys()):
- if k not in model_config.__dict__:
- del kwargs[k]
+ # Filter out kwargs that were passed to the apply_liger_* function, which will cause
+ # model initialization errors otherwise
+ apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
+ apply_fn_signature = inspect.signature(apply_fn)
+
+ applicable_kwargs = {
+ key: value
+ for key, value in kwargs.items()
+ if key not in apply_fn_signature.parameters
+ }

  return super().from_pretrained(
- pretrained_model_name_or_path, *model_args, **kwargs
+ pretrained_model_name_or_path, *model_args, **applicable_kwargs
  )
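
With this change, kwargs that match parameters of the resolved `apply_liger_kernel_to_*` function are consumed by the patching step and stripped before `from_pretrained` runs, instead of being filtered against the model config. A usage sketch (the checkpoint name is illustrative):

```python
from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM

# "rope" and "rms_norm" are parameters of apply_liger_kernel_to_llama, so they are
# routed to the patching function and dropped from the kwargs given to transformers.
model = AutoLigerKernelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",  # illustrative checkpoint name
    rope=True,
    rms_norm=True,
)
```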
{liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/transformers/kl_div.py

@@ -4,10 +4,11 @@ from liger_kernel.ops.kl_div import LigerKLDivLossFunction


  class LigerKLDIVLoss(nn.KLDivLoss):
- def __init__(self, *args, **kwargs):
+ def __init__(self, eps: float = 1e-10, *args, **kwargs):
  super(LigerKLDIVLoss, self).__init__(*args, **kwargs)
+ self.eps = eps

  def forward(self, y_pred, y_true):
  return LigerKLDivLossFunction.apply(
- y_pred, y_true, self.reduction, self.log_target
+ y_pred, y_true, self.reduction, self.log_target, self.eps
  )
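
The module wrapper now takes the same `eps` and forwards it along with `reduction` and `log_target` to the autograd function. A minimal usage sketch, assuming a CUDA device since the loss runs a Triton kernel:

```python
import torch
from liger_kernel.transformers.kl_div import LigerKLDIVLoss

loss_fn = LigerKLDIVLoss(reduction="batchmean", eps=1e-10)

logits = torch.randn(4, 128, device="cuda", requires_grad=True)
y_pred = torch.log_softmax(logits, dim=-1)                          # log-probabilities
y_true = torch.softmax(torch.randn(4, 128, device="cuda"), dim=-1)  # probabilities

loss = loss_fn(y_pred, y_true)  # scalar for "batchmean"
loss.backward()
```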
{liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel/transformers/monkey_patch.py

@@ -1,9 +1,9 @@
  import inspect
  import logging
  from functools import partial
+ from typing import Callable

- from torch import nn
- from transformers import PretrainedConfig, PreTrainedModel
+ from transformers import PreTrainedModel

  from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
  from liger_kernel.transformers.geglu import LigerGEGLUMLP
@@ -25,6 +25,30 @@ from liger_kernel.transformers.swiglu import (
  logger = logging.getLogger(__name__)


+ def _bind_method_to_module(module, method_name: str, new_method: Callable):
+ # Binds a new method to a module instance so that self is passed as the first argument
+ module.__dict__[method_name] = new_method.__get__(module, module.__class__)
+
+
+ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama"):
+ module.offset = offset
+ module.casting_mode = casting_mode
+ module.variance_epsilon = (
+ getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+ )
+ _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
+ _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
+
+
+ def _patch_layer_norm_module(module, eps=1e-6):
+ module.variance_epsilon = (
+ getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+ )
+ module.hidden_size = module.normalized_shape
+ _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
+ _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+
+
  def apply_liger_kernel_to_llama(
  rope: bool = True,
  cross_entropy: bool = False,
@@ -69,7 +93,6 @@ def apply_liger_kernel_to_llama(
  if model is not None:
  # The model instance already exists, so we need to additionally patch the
  # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
- config: PretrainedConfig = model.config

  if hasattr(model, "model"):
  # The case for LlamaForCausalLM or LlamaForSequenceClassification, for example
@@ -81,22 +104,17 @@ def apply_liger_kernel_to_llama(
  # Direct LlamaModel
  base_model = model

- torch_dtype = config.torch_dtype
  if rms_norm:
- base_model.norm = LigerRMSNorm(
- config.hidden_size, eps=config.rms_norm_eps
- ).to(torch_dtype)
+ _patch_rms_norm_module(base_model.norm)

  for decoder_layer in base_model.layers:
  if swiglu:
- decoder_layer.mlp = LigerSwiGLUMLP(config).to(torch_dtype)
+ _bind_method_to_module(
+ decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
+ )
  if rms_norm:
- decoder_layer.input_layernorm = LigerRMSNorm(
- config.hidden_size, eps=config.rms_norm_eps
- ).to(torch_dtype)
- decoder_layer.post_attention_layernorm = LigerRMSNorm(
- config.hidden_size, eps=config.rms_norm_eps
- ).to(torch_dtype)
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


  def apply_liger_kernel_to_mistral(
@@ -143,7 +161,6 @@ def apply_liger_kernel_to_mistral(
  if model is not None:
  # The model instance already exists, so we need to additionally patch the
  # instance variables that reference already-instantiated modules
- config: PretrainedConfig = model.config

  if hasattr(model, "model"):
  # The case for MistralForCausalLM, MistralForTokenClassification for example
@@ -152,22 +169,17 @@ def apply_liger_kernel_to_mistral(
  # Direct MistralModel
  base_model = model

- torch_dtype = config.torch_dtype
  if rms_norm:
- base_model.norm = LigerRMSNorm(
- config.hidden_size, eps=config.rms_norm_eps
- ).to(torch_dtype)
+ _patch_rms_norm_module(base_model.norm)

  for decoder_layer in base_model.layers:
  if swiglu:
- decoder_layer.mlp = LigerSwiGLUMLP(config).to(torch_dtype)
+ _bind_method_to_module(
+ decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
+ )
  if rms_norm:
- decoder_layer.input_layernorm = LigerRMSNorm(
- config.hidden_size, eps=config.rms_norm_eps
- ).to(torch_dtype)
- decoder_layer.post_attention_layernorm = LigerRMSNorm(
- config.hidden_size, eps=config.rms_norm_eps
- ).to(torch_dtype)
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


  def apply_liger_kernel_to_mixtral(
@@ -214,7 +226,6 @@ def apply_liger_kernel_to_mixtral(
  if model is not None:
  # The model instance already exists, so we need to additionally patch the
  # instance variables that reference already-instantiated modules
- config: PretrainedConfig = model.config

  if hasattr(model, "model"):
  # The case for MixtralForCausalLM, MixtralForTokenClassification for example
@@ -223,29 +234,18 @@ def apply_liger_kernel_to_mixtral(
  # Direct MixtralModel
  base_model = model

- torch_dtype = config.torch_dtype
  if rms_norm:
- base_model.norm = LigerRMSNorm(
- config.hidden_size, eps=config.rms_norm_eps
- ).to(torch_dtype)
+ _patch_rms_norm_module(base_model.norm)

  for decoder_layer in base_model.layers:
  if swiglu:
- block_sparse_moe = decoder_layer.block_sparse_moe
- patched_experts = nn.ModuleList(
- [
- LigerBlockSparseTop2MLP(config)
- for _ in range(block_sparse_moe.num_experts)
- ]
- )
- decoder_layer.block_sparse_moe.experts = patched_experts.to(torch_dtype)
+ for expert in decoder_layer.block_sparse_moe.experts:
+ _bind_method_to_module(
+ expert, "forward", LigerBlockSparseTop2MLP.forward
+ )
  if rms_norm:
- decoder_layer.input_layernorm = LigerRMSNorm(
- config.hidden_size, eps=config.rms_norm_eps
- ).to(torch_dtype)
- decoder_layer.post_attention_layernorm = LigerRMSNorm(
- config.hidden_size, eps=config.rms_norm_eps
- ).to(torch_dtype)
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


  def apply_liger_kernel_to_gemma(
@@ -282,6 +282,9 @@ def apply_liger_kernel_to_gemma(
  LigerRMSNormForGemma = partial(
  LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
  )
+ _patch_rms_norm_module_for_gemma = partial(
+ _patch_rms_norm_module, casting_mode="gemma", offset=1.0
+ )

  if rope:
  modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -297,7 +300,6 @@ def apply_liger_kernel_to_gemma(
  if model is not None:
  # The model instance already exists, so we need to additionally patch the
  # instance variables that reference already-instantiated modules
- config: PretrainedConfig = model.config

  if hasattr(model, "model"):
  # The case for GemmaForCausalLM, GemmaForTokenClassification for example
@@ -306,22 +308,17 @@ def apply_liger_kernel_to_gemma(
  # Direct GemmaModel
  base_model = model

- torch_dtype = config.torch_dtype
  if rms_norm:
- base_model.norm = LigerRMSNormForGemma(
- config.hidden_size, eps=config.rms_norm_eps
- ).to(torch_dtype)
+ _patch_rms_norm_module_for_gemma(base_model.norm)

  for decoder_layer in base_model.layers:
  if geglu:
- decoder_layer.mlp = LigerGEGLUMLP(config).to(torch_dtype)
+ _bind_method_to_module(
+ decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
+ )
  if rms_norm:
- decoder_layer.input_layernorm = LigerRMSNormForGemma(
- config.hidden_size, eps=config.rms_norm_eps
- ).to(torch_dtype)
- decoder_layer.post_attention_layernorm = LigerRMSNormForGemma(
- config.hidden_size, eps=config.rms_norm_eps
- ).to(torch_dtype)
+ _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
+ _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)


  def apply_liger_kernel_to_gemma2(
@@ -343,10 +340,15 @@ def apply_liger_kernel_to_gemma2(
  model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
  loaded. Default is None.
  """
- print("Got here!")
  from transformers.models.gemma2 import modeling_gemma2

- LigerRMSNormForGemma2 = partial(LigerRMSNorm, offset=1.0, init_fn="zeros")
+ LigerRMSNormForGemma2 = partial(
+ LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros"
+ )
+ _patch_rms_norm_module_for_gemma2 = partial(
+ _patch_rms_norm_module, offset=1.0, casting_mode="gemma"
+ )
+
  if rope:
  modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
  if rms_norm:
@@ -360,7 +362,6 @@ def apply_liger_kernel_to_gemma2(
  if model is not None:
  # The model instance already exists, so we need to additionally patch the
  # instance variables that reference already-instantiated modules
- config: PretrainedConfig = model.config

  if hasattr(model, "model"):
  # The case for Gemma2ForCausalLM, Gemma2ForTokenClassification for example
@@ -369,28 +370,25 @@ def apply_liger_kernel_to_gemma2(
  # Direct Gemma2Model
  base_model = model

- torch_dtype = config.torch_dtype
  if rms_norm:
- base_model.norm = LigerRMSNormForGemma2(
- config.hidden_size, eps=config.rms_norm_eps
- ).to(torch_dtype)
+ _patch_rms_norm_module_for_gemma2(base_model.norm)

  for decoder_layer in base_model.layers:
  if geglu:
- decoder_layer.mlp = LigerGEGLUMLP(config).to(torch_dtype)
+ _bind_method_to_module(
+ decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
+ )
  if rms_norm:
- decoder_layer.input_layernorm = LigerRMSNormForGemma2(
- config.hidden_size, eps=config.rms_norm_eps
- ).to(torch_dtype)
- decoder_layer.post_attention_layernorm = LigerRMSNormForGemma2(
- config.hidden_size, eps=config.rms_norm_eps
- ).to(torch_dtype)
- decoder_layer.pre_feedforward_layernorm = LigerRMSNormForGemma2(
- config.hidden_size, eps=config.rms_norm_eps
- ).to(torch_dtype)
- decoder_layer.post_feedforward_layernorm = LigerRMSNormForGemma2(
- config.hidden_size, eps=config.rms_norm_eps
- ).to(torch_dtype)
+ _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
+ _patch_rms_norm_module_for_gemma2(
+ decoder_layer.post_attention_layernorm
+ )
+ _patch_rms_norm_module_for_gemma2(
+ decoder_layer.pre_feedforward_layernorm
+ )
+ _patch_rms_norm_module_for_gemma2(
+ decoder_layer.post_feedforward_layernorm
+ )


  def apply_liger_kernel_to_qwen2(
@@ -436,7 +434,6 @@ def apply_liger_kernel_to_qwen2(
  if model is not None:
  # The model instance already exists, so we need to additionally patch the
  # instance variables that reference already-instantiated modules
- config: PretrainedConfig = model.config

  if hasattr(model, "model"):
  # The case for Qwen2ForCausalLM, Qwen2ForTokenClassification for example
@@ -445,22 +442,17 @@ def apply_liger_kernel_to_qwen2(
  # Direct Qwen2Model
  base_model = model

- torch_dtype = config.torch_dtype
  if rms_norm:
- base_model.norm = LigerRMSNorm(
- config.hidden_size, eps=config.rms_norm_eps
- ).to(torch_dtype)
+ _patch_rms_norm_module(base_model.norm)

  for decoder_layer in base_model.layers:
  if swiglu:
- decoder_layer.mlp = LigerSwiGLUMLP(config).to(torch_dtype)
+ _bind_method_to_module(
+ decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
+ )
  if rms_norm:
- decoder_layer.input_layernorm = LigerRMSNorm(
- config.hidden_size, eps=config.rms_norm_eps
- ).to(torch_dtype)
- decoder_layer.post_attention_layernorm = LigerRMSNorm(
- config.hidden_size, eps=config.rms_norm_eps
- ).to(torch_dtype)
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


  def apply_liger_kernel_to_qwen2_vl(
@@ -499,10 +491,9 @@ def apply_liger_kernel_to_qwen2_vl(

  # TODO: Support Qwen2-VL's multimodal RoPE implementation

- LigerRMSNormForQwen2VL = partial(LigerRMSNorm, init_fn="ones", casting_mode="gemma")
  if rms_norm:
  # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
- modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNormForQwen2VL
+ modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
  if layer_norm:
  modeling_qwen2_vl.LayerNorm = LigerLayerNorm
  if cross_entropy:
@@ -515,9 +506,6 @@ def apply_liger_kernel_to_qwen2_vl(
  if model is not None:
  # The model instance already exists, so we need to additionally patch the
  # instance variables that reference already-instantiated modules
- config: PretrainedConfig = model.config
-
- torch_dtype = config.torch_dtype

  if hasattr(model, "model"):
  # The case for Qwen2VLForConditionalGeneration.
@@ -530,27 +518,19 @@ def apply_liger_kernel_to_qwen2_vl(
  # Patch Qwen2VisionTransformerPretrainedModel
  for vision_block in model.visual.blocks:
  if layer_norm:
- vision_block.norm1 = LigerLayerNorm(config.embed_dim, eps=1e-6).to(
- torch_dtype
- )
- vision_block.norm2 = LigerLayerNorm(config.embed_dim, eps=1e-6).to(
- torch_dtype
- )
+ _patch_layer_norm_module(vision_block.norm1)
+ _patch_layer_norm_module(vision_block.norm2)

  if rms_norm:
- base_model.norm = LigerRMSNormForQwen2VL(
- config.hidden_size, eps=config.rms_norm_eps
- ).to(torch_dtype)
+ _patch_rms_norm_module(base_model.norm)
  for decoder_layer in base_model.layers:
  if swiglu:
- decoder_layer.mlp = LigerSwiGLUMLP(config).to(torch_dtype)
+ _bind_method_to_module(
+ decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
+ )
  if rms_norm:
- decoder_layer.input_layernorm = LigerRMSNormForQwen2VL(
- config.hidden_size, eps=config.rms_norm_eps
- ).to(torch_dtype)
- decoder_layer.post_attention_layernorm = LigerRMSNormForQwen2VL(
- config.hidden_size, eps=config.rms_norm_eps
- ).to(torch_dtype)
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


  def apply_liger_kernel_to_phi3(
@@ -596,7 +576,6 @@ def apply_liger_kernel_to_phi3(
  if model is not None:
  # The model instance already exists, so we need to additionally patch the
  # instance variables that reference already-instantiated modules
- config: PretrainedConfig = model.config

  if hasattr(model, "model"):
  # The case for Phi3ForCausalLM, Phi3ForTokenClassification for example
@@ -605,22 +584,17 @@ def apply_liger_kernel_to_phi3(
  # Direct Phi3Model
  base_model = model

- torch_dtype = config.torch_dtype
  if rms_norm:
- base_model.norm = LigerRMSNorm(
- config.hidden_size, eps=config.rms_norm_eps
- ).to(torch_dtype)
+ _patch_rms_norm_module(base_model.norm)

  for decoder_layer in base_model.layers:
  if swiglu:
- decoder_layer.mlp = LigerPhi3SwiGLUMLP(config).to(torch_dtype)
+ _bind_method_to_module(
+ decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward
+ )
  if rms_norm:
- decoder_layer.input_layernorm = LigerRMSNorm(
- config.hidden_size, eps=config.rms_norm_eps
- ).to(torch_dtype)
- decoder_layer.post_attention_layernorm = LigerRMSNorm(
- config.hidden_size, eps=config.rms_norm_eps
- ).to(torch_dtype)
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)

  # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
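
The new in-place patching relies on the descriptor protocol: `function.__get__(instance, type)` returns a method bound to that one instance, and writing it into the instance `__dict__` overrides the class-level `forward` for that module only, so existing weights (and their dtype and device) are left untouched instead of being re-created as in the old `LigerRMSNorm(...).to(torch_dtype)` path. A stand-alone sketch of the mechanism with a toy module (not Liger code):

```python
import torch
import torch.nn as nn


def doubled_forward(self, x):
    # Stand-in for e.g. LigerSwiGLUMLP.forward; "self" is the patched instance.
    return self.linear(x) * 2


def bind_method_to_module(module, method_name, new_method):
    # Same idea as _bind_method_to_module: bind the plain function to this instance only.
    module.__dict__[method_name] = new_method.__get__(module, module.__class__)


class ToyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


mlp = ToyMLP()
bind_method_to_module(mlp, "forward", doubled_forward)  # only this instance is affected
out = mlp(torch.randn(2, 4))  # dispatches to doubled_forward via the instance __dict__
```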
{liger_kernel-0.3.0 → liger_kernel-0.3.1/src/liger_kernel.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: liger_kernel
- Version: 0.3.0
+ Version: 0.3.1
  Summary: Efficient Triton kernels for LLM Training
  License: BSD 2-CLAUSE LICENSE
  Copyright 2024 LinkedIn Corporation
@@ -32,15 +32,16 @@ License-File: LICENSE
  License-File: NOTICE
  Requires-Dist: torch>=2.1.2
  Requires-Dist: triton>=2.3.0
- Requires-Dist: transformers>=4.42.0
+ Provides-Extra: transformers
+ Requires-Dist: transformers~=4.0; extra == "transformers"
  Provides-Extra: dev
+ Requires-Dist: transformers>=4.44.2; extra == "dev"
  Requires-Dist: matplotlib>=3.7.2; extra == "dev"
  Requires-Dist: flake8>=4.0.1.1; extra == "dev"
  Requires-Dist: black>=24.4.2; extra == "dev"
  Requires-Dist: isort>=5.13.2; extra == "dev"
  Requires-Dist: pytest>=7.1.2; extra == "dev"
  Requires-Dist: datasets>=2.19.2; extra == "dev"
- Requires-Dist: jupyter==1.0.0; extra == "dev"
  Requires-Dist: seaborn; extra == "dev"

  # Liger Kernel: Efficient Triton Kernels for LLM Training
@@ -74,8 +75,8 @@ Requires-Dist: seaborn; extra == "dev"
  </a>
  </td>
  <td style="padding: 10px;">
- <a href="https://discord.gg/CX2YmNmn">
- <img src="https://dcbadge.vercel.app/api/server/cudamode?style=flat" alt="Join Our Discord">
+ <a href="https://discord.gg/gpumode">
+ <img src="https://dcbadge.vercel.app/api/server/gpumode?style=flat" alt="Join Our Discord">
  </a>
  </td>
  </tr>
@@ -151,7 +152,10 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and

  - `torch >= 2.1.2`
  - `triton >= 2.3.0`
- - `transformers >= 4.42.0`
+
+ ### Optional Dependencies
+
+ - `transformers >= 4.x`: Required if you plan to use the transformers models patching APIs. The specific model you are working will dictate the minimum version of transformers.

  > **Note:**
  > Our kernels inherit the full spectrum of hardware compatibility offered by [Triton](https://github.com/triton-lang/triton).
@@ -174,7 +178,10 @@ To install from source:
  git clone https://github.com/linkedin/Liger-Kernel.git
  cd Liger-Kernel
  pip install -e .
+ # or if using transformers
+ pip install -e .[transformers]
  ```
+
  ## Getting Started

  There are a couple of ways to apply Liger kernels, depending on the level of customization required.
@@ -271,9 +278,9 @@ loss.backward()
  | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss |
- | Qwen2 | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | Qwen2 & Qwen2.5 | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Qwen2-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
- | Phi3 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |

{liger_kernel-0.3.0 → liger_kernel-0.3.1}/src/liger_kernel.egg-info/requires.txt

@@ -1,13 +1,15 @@
  torch>=2.1.2
  triton>=2.3.0
- transformers>=4.42.0

  [dev]
+ transformers>=4.44.2
  matplotlib>=3.7.2
  flake8>=4.0.1.1
  black>=24.4.2
  isort>=5.13.2
  pytest>=7.1.2
  datasets>=2.19.2
- jupyter==1.0.0
  seaborn
+
+ [transformers]
+ transformers~=4.0
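
Since `transformers` is now an optional extra rather than a hard dependency, the model patching APIs from the supported-models table above only work once it is installed (for example via the new `[transformers]` extra); the kernels themselves need only torch and triton. A sketch of the one-line patching flow, assuming transformers is present and using an illustrative Qwen2 checkpoint:

```python
import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_qwen2

# Patch the Qwen2/Qwen2.5 modeling code (RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, ...)
# before the model is instantiated so the Liger kernels are picked up.
apply_liger_kernel_to_qwen2()

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-1.5B-Instruct",  # illustrative checkpoint name
    torch_dtype=torch.bfloat16,
)
```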