liger-kernel 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,9 +1,23 @@
+import inspect
+import logging
+from functools import partial
+
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
-from liger_kernel.transformers.model.llama import lce_forward
+from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
+from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
+from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
+from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
+from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
-from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP, LigerSwiGLUMLP
+from liger_kernel.transformers.swiglu import (
+    LigerBlockSparseTop2MLP,
+    LigerPhi3SwiGLUMLP,
+    LigerSwiGLUMLP,
+)
+
+logger = logging.getLogger(__name__)
 
 
 def apply_liger_kernel_to_llama(
@@ -20,7 +34,7 @@ def apply_liger_kernel_to_llama(
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
         fused_linear_cross_entropy (bool):
-            Whether to apply Liger's fused lienar cross entropy loss. Default is True.
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
@@ -42,12 +56,13 @@ def apply_liger_kernel_to_llama(
     if cross_entropy:
         modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        modeling_llama.LlamaForCausalLM.forward = lce_forward
+        modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
 
 
 def apply_liger_kernel_to_mistral(
     rope: bool = True,
-    cross_entropy: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
     swiglu: bool = True,
 ) -> None:
@@ -57,9 +72,17 @@ def apply_liger_kernel_to_mistral(
     Args:
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
         swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
     """
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
 
     from transformers.models.mistral import modeling_mistral
 
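Note: a minimal usage sketch of the guard added above (assuming the import path `liger_kernel.transformers.monkey_patch`, which this diff shows elsewhere):

    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral

    # New defaults: cross_entropy=False, fused_linear_cross_entropy=True.
    apply_liger_kernel_to_mistral()

    # Requesting both loss paths now fails fast instead of silently double-patching.
    apply_liger_kernel_to_mistral(
        cross_entropy=True, fused_linear_cross_entropy=True
    )  # AssertionError: cross_entropy and fused_linear_cross_entropy cannot both be True.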
@@ -69,6 +92,8 @@ def apply_liger_kernel_to_mistral(
         modeling_mistral.MistralRMSNorm = LigerRMSNorm
     if cross_entropy:
         modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
     if swiglu:
         modeling_mistral.MistralMLP = LigerSwiGLUMLP
 
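Note: the swaps above are module-level class replacements, so they only take effect for models constructed after patching. A hedged ordering sketch (the checkpoint name is illustrative only):

    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral
    from transformers import AutoModelForCausalLM

    apply_liger_kernel_to_mistral()  # patch first
    model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")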
@@ -94,7 +119,7 @@ def apply_liger_kernel_to_mixtral(
     if rope:
         modeling_mixtral.apply_rotary_pos_emb = liger_rotary_pos_emb
     if rms_norm:
-        modeling_mixtral.MistralRMSNorm = LigerRMSNorm
+        modeling_mixtral.MixtralRMSNorm = LigerRMSNorm
     if cross_entropy:
         modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
     if swiglu:
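Note: this hunk is a genuine bug fix, not a rename. 0.1.0 assigned `modeling_mixtral.MistralRMSNorm`, a name the Mixtral modeling code never reads, so the RMSNorm swap was silently a no-op. A quick sanity check against the corrected symbol:

    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral
    from liger_kernel.transformers.rms_norm import LigerRMSNorm
    from transformers.models.mixtral import modeling_mixtral

    apply_liger_kernel_to_mixtral()
    assert modeling_mixtral.MixtralRMSNorm is LigerRMSNorm  # would fail on 0.1.0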
@@ -104,12 +129,13 @@ def apply_liger_kernel_to_mixtral(
 def apply_liger_kernel_to_gemma(
     rope: bool = True,
     cross_entropy: bool = True,
+    fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
     geglu: bool = True,
 ) -> None:
     """
-    Apply Liger kernels to replace original implementation in HuggingFace Gemma2 models
-    to make GPU go burrr.
+    Apply Liger kernels to replace original implementation in HuggingFace Gemma
+    (Gemma 1 and 1.1 supported, for Gemma2 please use `apply_liger_kernel_to_gemma2` ) to make GPU go burrr.
 
     Args:
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
@@ -117,14 +143,181 @@ def apply_liger_kernel_to_gemma(
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
         geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
     """
-    # TODO(yundai424): add convergence test for gemma
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
     from transformers.models.gemma import modeling_gemma
 
     if rope:
         modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
     if rms_norm:
-        modeling_gemma.GemmaRMSNorm = LigerRMSNorm
+        # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
+        modeling_gemma.GemmaRMSNorm = partial(
+            LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
+        )
     if cross_entropy:
         modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
     if geglu:
         modeling_gemma.GemmaMLP = LigerGEGLUMLP
+    if fused_linear_cross_entropy:
+        modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+
+
+def apply_liger_kernel_to_gemma2(
+    rope: bool = True,
+    cross_entropy: bool = True,
+    rms_norm: bool = True,
+    geglu: bool = True,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Gemma2
+    (for Gemma1 please use `apply_liger_kernel_to_gemma`) to make GPU go burrr.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
+    """
+    from transformers.models.gemma2 import modeling_gemma2
+
+    if rope:
+        modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
+        modeling_gemma2.Gemma2RMSNorm = partial(
+            LigerRMSNorm, offset=1.0, init_fn="zeros"
+        )
+    if cross_entropy:
+        modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
+    if geglu:
+        modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
+
+
+def apply_liger_kernel_to_qwen2(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen2 models
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+    """
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
+    from transformers.models.qwen2 import modeling_qwen2
+
+    if rope:
+        modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
+    if cross_entropy:
+        modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
+    if swiglu:
+        modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
+
+
+def apply_liger_kernel_to_phi3(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
+    """
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
+    from transformers.models.phi3 import modeling_phi3
+
+    if rope:
+        modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb  # Same as Gemma
+    if rms_norm:
+        modeling_phi3.Phi3RMSNorm = LigerRMSNorm  # Same as Llama
+    if swiglu:
+        modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
+    if cross_entropy:
+        modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
+
+
+# Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
+MODEL_TYPE_TO_APPLY_LIGER_FN = {
+    "gemma": apply_liger_kernel_to_gemma,
+    "gemma2": apply_liger_kernel_to_gemma2,
+    "llama": apply_liger_kernel_to_llama,
+    "mistral": apply_liger_kernel_to_mistral,
+    "mixtral": apply_liger_kernel_to_mixtral,
+    "qwen2": apply_liger_kernel_to_qwen2,
+    "phi3": apply_liger_kernel_to_phi3,
+}
+
+
+def _apply_liger_kernel(model_type: str = "", **kwargs) -> None:
+    """
+    Applies Liger kernels based on the specified model type. The custom
+    kernels for the specified model type will be applied with the provided
+    keyword arguments, otherwise the default configuration will be used.
+
+    Args:
+        - model_type: the model types as defined in transformers/models/auto/modeling_auto.py
+          and specified in the model's config.json
+        - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
+    """
+
+    if not model_type:
+        logger.info("Model type was not provided. No Liger kernels will be applied.")
+        return
+
+    if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
+        logger.info(
+            f"There are currently no Liger kernels supported for model type: {model_type}."
+        )
+        return
+
+    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
+    apply_fn_signature = inspect.signature(apply_fn)
+
+    # Filter out the keyword arguments that are not supported by the apply function
+    applicable_kwargs = {
+        key: value
+        for key, value in kwargs.items()
+        if key in apply_fn_signature.parameters
+    }
+
+    logger.info(
+        f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}"
+    )
+
+    # Apply the default combination of liger kernels available for the model
+    apply_fn(**applicable_kwargs)
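Note: a sketch of the new dispatch behavior. Keyword arguments not accepted by the resolved apply function are filtered out via `inspect.signature`, so a trainer can pass one kwargs dict regardless of model type:

    from liger_kernel.transformers.monkey_patch import _apply_liger_kernel

    # "geglu" is not a parameter of apply_liger_kernel_to_qwen2, so it is dropped;
    # "swiglu" and "rms_norm" pass through unchanged.
    _apply_liger_kernel(model_type="qwen2", swiglu=True, rms_norm=True, geglu=True)

    # Unsupported model types log a message and return without patching anything.
    _apply_liger_kernel(model_type="bert")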
@@ -5,12 +5,28 @@ from liger_kernel.ops.rms_norm import LigerRMSNormFunction
 
 
 class LigerRMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones"
+    ):
         super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
+        assert init_fn in [
+            "ones",
+            "zeros",
+        ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
+        self.weight = nn.Parameter(
+            torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
+        )
+        self.variance_epsilon, self.offset, self.casting_mode = (
+            eps,
+            offset,
+            casting_mode,
+        )
 
     def forward(self, hidden_states):
         return LigerRMSNormFunction.apply(
-            hidden_states, self.weight, self.variance_epsilon
+            hidden_states,
+            self.weight,
+            self.variance_epsilon,
+            self.offset,
+            self.casting_mode,
         )
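Note: the new `offset` and `init_fn` parameters exist so the Gemma patch above can reproduce Gemma's `(1 + weight) * x_normed` convention: the weight is stored zero-initialized and the kernel adds the offset back. A minimal sketch mirroring the partial used in `apply_liger_kernel_to_gemma`:

    from functools import partial

    from liger_kernel.transformers.rms_norm import LigerRMSNorm

    GemmaStyleRMSNorm = partial(LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros")
    norm = GemmaStyleRMSNorm(hidden_size=2048)
    assert norm.weight.sum() == 0  # stored as zeros; effective scale is weight + 1.0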
@@ -38,3 +38,27 @@ class LigerBlockSparseTop2MLP(nn.Module):
     def forward(self, x):
 
         return self.w2(LigerSiLUMulFunction.apply(self.w1(x), self.w3(x)))
+
+
+class LigerPhi3SwiGLUMLP(nn.Module):
+    """
+    Patch Phi3MLP to use LigerSiLUMulFunction
+    https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/models/phi3/modeling_phi3.py#L241
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_up_proj = nn.Linear(
+            self.hidden_size, 2 * self.intermediate_size, bias=False
+        )
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        if config.hidden_act not in ["silu", "swish"]:
+            raise ValueError(f"Activation function {config.hidden_act} not supported.")
+
+    def forward(self, x):
+        up_states = self.gate_up_proj(x)
+        gate, up_states = up_states.chunk(2, dim=-1)
+        return self.down_proj(LigerSiLUMulFunction.apply(gate, up_states))
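Note: an eager-mode reference for what `LigerPhi3SwiGLUMLP.forward` computes; the Liger kernel fuses the final `silu(gate) * up` into a single op. Sizes below are toy values, not Phi-3's real config:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    hidden_size, intermediate_size = 64, 128  # toy sizes
    gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
    down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    x = torch.randn(2, 5, hidden_size)
    gate, up = gate_up_proj(x).chunk(2, dim=-1)  # unpack the fused [gate | up] projection
    out = down_proj(F.silu(gate) * up)           # eager equivalent of LigerSiLUMulFunction
    print(out.shape)  # torch.Size([2, 5, 64])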
@@ -1,45 +1,2 @@
-import logging
-
-from liger_kernel.transformers.monkey_patch import (
-    apply_liger_kernel_to_gemma,
-    apply_liger_kernel_to_llama,
-    apply_liger_kernel_to_mistral,
-    apply_liger_kernel_to_mixtral,
-)
-
-logger = logging.getLogger(__name__)
-
-# Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
-MODEL_TYPE_TO_APPLY_LIGER_FN = {
-    "gemma": apply_liger_kernel_to_gemma,
-    "llama": apply_liger_kernel_to_llama,
-    "mistral": apply_liger_kernel_to_mistral,
-    "mixtral": apply_liger_kernel_to_mixtral,
-}
-
-
-def _apply_liger_kernel(model_type: str = "", **kwargs) -> None:
-    """
-    Applies Liger kernels based on the specified model type. The custom
-    kernels for the specified model type will be applied with the provided
-    keyword arguments, otherwise the default configuration will be used.
-
-    Args:
-        - model_type: the model types as defined in transformers/models/auto/modeling_auto.py
-          and specified in the model's config.json
-        - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
-    """
-
-    if not model_type:
-        logger.info("Model type was not provided. No Liger kernels will be applied.")
-        return
-
-    if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
-        logger.info(
-            f"There are currently no Liger kernels supported for model type: {model_type}."
-        )
-        return
-
-    logger.info(f"Applying Liger kernels for model type: {model_type}.")
-    # Apply the default combination of liger kernels available for the model
-    MODEL_TYPE_TO_APPLY_LIGER_FN[model_type](**kwargs)
+# To not break HF Trainer integration
+from liger_kernel.transformers.monkey_patch import _apply_liger_kernel  # noqa: F401
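Note: this file's path is not shown in the diff, but it appears to be the old trainer-integration entry point, reduced to a re-export so that 0.1.0 import sites keep working. New code can import the implementation from its new home directly:

    from liger_kernel.transformers.monkey_patch import _apply_liger_kernel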