liger-kernel-nightly 0.4.0.dev20241108173943__tar.gz → 0.4.0.dev20241108174843__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic.
- {liger_kernel_nightly-0.4.0.dev20241108173943/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.4.0.dev20241108174843}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/pyproject.toml +1 -1
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/cross_entropy.py +46 -17
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/fused_linear_cross_entropy.py +6 -1
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/cross_entropy.py +27 -17
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +23 -10
- liger_kernel_nightly-0.4.0.dev20241108174843/src/liger_kernel/transformers/model/gemma2.py +277 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/monkey_patch.py +21 -3
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843/src/liger_kernel_nightly.egg-info}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel_nightly.egg-info/SOURCES.txt +1 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/LICENSE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/NOTICE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/README.md +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/setup.cfg +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/__init__.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/functional.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
{liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "liger_kernel_nightly"
-version = "0.4.0.dev20241108173943"
+version = "0.4.0.dev20241108174843"
 description = "Efficient Triton kernels for LLM Training"
 urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
 readme = { file = "README.md", content-type = "text/markdown" }
{liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/cross_entropy.py
@@ -1,8 +1,21 @@
+import operator
+from typing import Optional
+
 import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import element_mul_kernel, is_hip
+from liger_kernel.ops.utils import compare_version, element_mul_kernel, is_hip
+
+if compare_version("triton", operator.ge, "3.0.0"):
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import tanh
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import tanh
+else:
+    from triton.language.math import tanh
 
 _TRUE = tl.constexpr(1)
 _FALSE = tl.constexpr(0)
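The version-gated import above exists because the libdevice `tanh` moved between Triton releases; the kernel needs it to implement Gemma-2 style logit soft-capping, which squashes every logit into (-softcap, +softcap) before the usual cross-entropy math. A minimal eager-PyTorch sketch of the same transform (the `soft_cap` helper below is illustrative, not part of the package):

import torch

def soft_cap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Squash every logit into (-cap, +cap) while staying differentiable,
    # mirroring the kernel's `softcap * tanh(x / softcap)` expression.
    return cap * torch.tanh(logits / cap)

x = torch.randn(4, 8)
capped = soft_cap(x, cap=30.0)
assert capped.abs().max() < 30.0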
@@ -23,8 +36,10 @@ def liger_cross_entropy_kernel(
     lse_square_scale: tl.constexpr,
     label_smoothing: tl.constexpr,
     reduction: tl.constexpr,  # set it as constexpr since reduction is always known at compile time
+    softcap,
     RETURN_Z_LOSS: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
+    HAS_SOFTCAPPING: tl.constexpr,
 ):
     """
     This kernel computes both cross entropy loss and the gradient of the input.
@@ -45,7 +60,9 @@ def liger_cross_entropy_kernel(
     lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
     RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
     reduction (str): The string for the reduction to apply
+    softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
     BLOCK_SIZE (int): The block size for Triton operations.
+    HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
     """
 
     # https://github.com/triton-lang/triton/issues/1058
@@ -78,6 +95,8 @@ def liger_cross_entropy_kernel(
     ori_X_y = tl.load(
         X_ptr + y
     )  # we need to store the original value of X_y for the loss calculation
+    if HAS_SOFTCAPPING:
+        ori_X_y = softcap * tanh(ori_X_y / softcap)
 
     # Label smoothing is a general case of normal cross entropy
     # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
@@ -89,6 +108,8 @@ def liger_cross_entropy_kernel(
         X_block = tl.load(
             X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
         )
+        if HAS_SOFTCAPPING:
+            X_block = softcap * tanh(X_block / softcap)
         block_max = tl.max(X_block)
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
@@ -122,15 +143,24 @@ def liger_cross_entropy_kernel(
         X_block = tl.load(
             X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
         )
+        if HAS_SOFTCAPPING:
+            intermediate = tanh(X_block / softcap)
+            X_block = softcap * intermediate
         # softmax(x_i)
         X_block = tl.exp(X_block - m) / d
         # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
         X_block += 2 * lse_square_scale * lse * X_block
         # smoothing term
         X_block += -eps
+        # special handle dx_y
+        X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
         # reduction scale
         if reduction == "mean":
             X_block = X_block / (n_non_ignore)
+        # chain rule
+        # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+        if HAS_SOFTCAPPING:
+            X_block = X_block * (1 - intermediate * intermediate)
 
         tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
 
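The second pass keeps the `tanh` result in `intermediate` so the chain-rule factor can be applied without recomputing it: d/dx [softcap * tanh(x / softcap)] = 1 - tanh^2(x / softcap). A small sanity check of that identity against autograd (illustrative sketch, not Liger code):

import torch

cap = 30.0
x = torch.randn(16, dtype=torch.float64, requires_grad=True)

y = cap * torch.tanh(x / cap)
(grad_autograd,) = torch.autograd.grad(y.sum(), x)

# Factor applied by the kernel: d/dx [cap * tanh(x / cap)] = 1 - tanh^2(x / cap)
t = torch.tanh(x / cap)
grad_manual = 1 - t * t

assert torch.allclose(grad_autograd, grad_manual)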
@@ -151,7 +181,7 @@ def liger_cross_entropy_kernel(
     # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
     # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
     # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
-    # = (1 - label_smoothing) * H(q, p) + (
+    # = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd))
     # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
     # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
     # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
@@ -168,17 +198,9 @@ def liger_cross_entropy_kernel(
         z_loss = z_loss / n_non_ignore
         loss = loss / n_non_ignore
 
-    # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N`
-    X_y = tl.load(X_ptr + y)
-    if reduction == "mean":
-        X_y += -(1 - label_smoothing) / (n_non_ignore)
-    else:
-        X_y += -(1 - label_smoothing)
-
     tl.store(loss_ptr, loss)
     if RETURN_Z_LOSS == _TRUE:
         tl.store(z_loss_ptr, z_loss)
-    tl.store(X_ptr + y, X_y)
 
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
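Folding the target-index correction into the main loop with `tl.where` means the gradient buffer is written once per block: every position gets softmax(x_i) minus the smoothing term eps = label_smoothing / V, and the target index additionally loses (1 - label_smoothing). A hedged check of that closed form against PyTorch's label-smoothed cross entropy (single sample, mean reduction, no z-loss):

import torch
import torch.nn.functional as F

V, s = 10, 0.1                     # vocab size and label_smoothing
x = torch.randn(V, requires_grad=True)
y = torch.tensor(3)

loss = F.cross_entropy(x.unsqueeze(0), y.unsqueeze(0), label_smoothing=s)
loss.backward()

# Gradient the kernel writes back in place: softmax(x) - eps everywhere,
# with an extra -(1 - label_smoothing) at the target index.
p = torch.softmax(x.detach(), dim=-1)
grad = p - s / V
grad[y] -= 1 - s

assert torch.allclose(x.grad, grad, atol=1e-6)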
@@ -200,6 +222,7 @@ def cross_entropy_forward(
     lse_square_scale,
     label_smoothing,
     reduction,
+    softcap,
     return_z_loss,
 ):
     if not isinstance(return_z_loss, int):
@@ -247,8 +270,10 @@ def cross_entropy_forward(
         lse_square_scale=lse_square_scale,
         label_smoothing=label_smoothing,
         reduction=reduction,
+        softcap=softcap if softcap is not None else 0.0,
         RETURN_Z_LOSS=return_z_loss,
         BLOCK_SIZE=BLOCK_SIZE,
+        HAS_SOFTCAPPING=True if softcap is not None else False,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
         num_warps=32 if not is_hip() else 16,
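Because `HAS_SOFTCAPPING` is a `tl.constexpr`, the launch specializes the kernel and the capping branches are compiled out entirely when no cap is requested; the Python wrapper only has to translate `softcap=None` into a dummy `0.0` plus a false flag. A tiny sketch of that mapping (hypothetical helper, mirrors the two launch arguments above):

def softcap_launch_args(softcap):
    # Mirror of the wrapper logic: the kernel always receives a float value,
    # and a separate constexpr flag decides whether the capping path is compiled in.
    return {
        "softcap": softcap if softcap is not None else 0.0,
        "HAS_SOFTCAPPING": softcap is not None,
    }

assert softcap_launch_args(None) == {"softcap": 0.0, "HAS_SOFTCAPPING": False}
assert softcap_launch_args(30.0) == {"softcap": 30.0, "HAS_SOFTCAPPING": True}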
@@ -296,13 +321,14 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
     @staticmethod
     def forward(
         ctx,
-        _input,
-        target,
-        ignore_index
-        lse_square_scale=0.0,
-        label_smoothing=0.0,
-        reduction="mean",
-
+        _input: torch.Tensor,
+        target: torch.Tensor,
+        ignore_index: int = -100,
+        lse_square_scale: float = 0.0,
+        label_smoothing: float = 0.0,
+        reduction: str = "mean",
+        softcap: Optional[float] = None,
+        return_z_loss: bool = False,
     ):
         """
         The forward pass of the Liger Cross Entropy loss.
@@ -315,6 +341,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
+        softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
         return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
 
         Returns:
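With the new signature, `softcap` sits between `reduction` and `return_z_loss` when the autograd function is invoked positionally. A hedged usage sketch (assumes liger-kernel is installed and a CUDA device is available; the shapes are arbitrary):

import torch
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

logits = torch.randn(8, 32000, device="cuda", requires_grad=True)
targets = torch.randint(0, 32000, (8,), device="cuda")

# Positional order matches the new forward signature:
# (_input, target, ignore_index, lse_square_scale, label_smoothing, reduction, softcap, return_z_loss)
loss, _ = LigerCrossEntropyFunction.apply(
    logits, targets, -100, 0.0, 0.0, "mean", 30.0, False
)
loss.backward()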
@@ -327,6 +354,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             lse_square_scale,
             label_smoothing,
             reduction,
+            softcap,
             return_z_loss,
         )
         # TODO: investigation
@@ -362,4 +390,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
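`torch.autograd.Function.backward` has to return one gradient (or `None`) per argument of `forward`, so adding `softcap` to the forward signature is what forces the extra trailing `None` here and in the fused-linear variant below. A toy illustration of that contract (not Liger code):

import torch

class ScaledIdentity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):            # two forward arguments ...
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * ctx.scale, None  # ... so backward returns two values

x = torch.randn(3, requires_grad=True)
ScaledIdentity.apply(x, 2.0).sum().backward()
assert torch.allclose(x.grad, torch.full((3,), 2.0))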
{liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -24,6 +24,7 @@ def fused_linear_cross_entropy_forward(
     lse_square_scale=0.0,
     label_smoothing=0.0,
     reduction="mean",
+    softcap=None,
 ):
     dtype = _input.dtype
     device = _input.device
@@ -95,7 +96,9 @@ def fused_linear_cross_entropy_forward(
             lse_square_scale=lse_square_scale,
             label_smoothing=label_smoothing,
             reduction=reduction,
+            softcap=softcap if softcap is not None else 0.0,
             RETURN_Z_LOSS=0,  # False
+            HAS_SOFTCAPPING=True if softcap is not None else False,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=32 if not is_hip() else 16,
         )
@@ -207,6 +210,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         lse_square_scale=0.0,
         label_smoothing=0.0,
         reduction="mean",
+        softcap=None,
     ):
         """
         Fusing the last linear layer with cross-entropy loss
@@ -234,6 +238,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             lse_square_scale,
             label_smoothing,
             reduction,
+            softcap,
         )
         # downcast to dtype and store for backward
         ctx.save_for_backward(
@@ -250,4 +255,4 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
             grad_output, grad_input, grad_weight, grad_bias
         )
-        return (grad_input, grad_weight, None, grad_bias, None, None, None, None)
+        return (grad_input, grad_weight, None, grad_bias, None, None, None, None, None)
{liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/cross_entropy.py
@@ -1,34 +1,43 @@
-
+from typing import Optional
+
+import torch
 
 from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
 
 
-class LigerCrossEntropyLoss(nn.Module):
+class LigerCrossEntropyLoss(torch.nn.Module):
     def __init__(
         self,
-        ignore_index
-        lse_square_scale=0.0,
-        label_smoothing=0.0,
-        reduction="mean",
-
+        ignore_index: int = -100,
+        lse_square_scale: float = 0.0,
+        label_smoothing: float = 0.0,
+        reduction: str = "mean",
+        softcap: Optional[float] = None,
+        return_z_loss: bool = False,
     ):
         super().__init__()
+        assert (label_smoothing >= 0) and (
+            label_smoothing <= 1
+        ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
+        assert (label_smoothing >= 0) and (
+            label_smoothing <= 1
+        ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
+        assert reduction in {
+            "mean",
+            "sum",
+            "none",
+        }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
+        assert (
+            softcap is None or softcap > 0
+        ), f"softcap must greater than 0.0 or None. Got: {softcap}"
         self.ignore_index = ignore_index
         self.lse_square_scale = lse_square_scale
         self.label_smoothing = label_smoothing
         self.reduction = reduction
+        self.softcap = softcap
         self.return_z_loss = return_z_loss
 
-
-            self.label_smoothing <= 1
-        ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"
-        assert self.reduction in {
-            "mean",
-            "sum",
-            "none",
-        }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {self.reduction}"
-
-    def forward(self, _input, target):
+    def forward(self, _input: torch.Tensor, target: torch.Tensor):
         loss, z_loss = LigerCrossEntropyFunction.apply(
             _input,
             target,
@@ -36,6 +45,7 @@ class LigerCrossEntropyLoss(nn.Module):
             self.lse_square_scale,
             self.label_smoothing,
             self.reduction,
+            self.softcap,
             self.return_z_loss,
         )
         if not self.return_z_loss:
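At the module level the change is purely additive: `softcap` (and the stricter constructor checks) ride along with the existing arguments. A hedged usage sketch of the updated class (assumes a CUDA device; values are arbitrary):

import torch
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

loss_fn = LigerCrossEntropyLoss(softcap=30.0)  # reduction="mean" by default

logits = torch.randn(4, 32000, device="cuda", requires_grad=True)
labels = torch.randint(0, 32000, (4,), device="cuda")

loss = loss_fn(logits, labels)
loss.backward()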
{liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/fused_linear_cross_entropy.py
@@ -1,26 +1,38 @@
-
+from typing import Optional
+
+import torch
 
 from liger_kernel.ops.fused_linear_cross_entropy import (
     LigerFusedLinearCrossEntropyFunction,
 )
 
 
-class LigerFusedLinearCrossEntropyLoss(nn.Module):
+class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
     def __init__(
         self,
-        ignore_index
-
-
-
+        ignore_index: int = -100,
+        lse_square_scale: float = 0.0,
+        label_smoothing: float = 0.0,
+        reduction: str = "mean",
+        softcap: Optional[float] = None,
     ):
         super().__init__()
+        assert (label_smoothing >= 0) and (
+            label_smoothing <= 1
+        ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
+        assert reduction in {
+            "mean",
+            "sum",
+            "none",
+        }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
+        assert (
+            softcap is None or softcap > 0
+        ), f"softcap must greater than 0.0 or None. Got: {softcap}"
         self.ignore_index = ignore_index
+        self.lse_square_scale = lse_square_scale
         self.label_smoothing = label_smoothing
         self.reduction = reduction
-        self.
-        assert (self.label_smoothing >= 0) and (
-            self.label_smoothing <= 1
-        ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"
+        self.softcap = softcap
 
     def forward(self, lin_weight, _input, target, bias=None):
         return LigerFusedLinearCrossEntropyFunction.apply(
@@ -32,4 +44,5 @@ class LigerFusedLinearCrossEntropyLoss(nn.Module):
             self.lse_square_scale,
             self.label_smoothing,
             self.reduction,
+            self.softcap,
         )
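The fused module takes the `lm_head` weight and the pre-projection hidden states, so the full logits matrix is never materialized; with this change the soft-capping happens inside the chunked kernel as well. A hedged usage sketch (assumes a CUDA device; sizes are arbitrary):

import torch
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

hidden_size, vocab_size = 4096, 256000
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False, device="cuda")

hidden_states = torch.randn(16, hidden_size, device="cuda", requires_grad=True)
labels = torch.randint(0, vocab_size, (16,), device="cuda")

# Gemma-2 style capping is applied inside the fused kernel rather than on materialized logits
flce = LigerFusedLinearCrossEntropyLoss(softcap=30.0)
loss = flce(lm_head.weight, hidden_states, labels)
loss.backward()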
liger_kernel_nightly-0.4.0.dev20241108174843/src/liger_kernel/transformers/model/gemma2.py
@@ -0,0 +1,277 @@
+import logging
+from typing import Optional, Tuple, Union
+
+import torch
+from torch.nn import CrossEntropyLoss
+from transformers.cache_utils import HybridCache
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.models.gemma2.modeling_gemma2 import (
+    _CONFIG_FOR_DOC,
+    GEMMA2_INPUTS_DOCSTRING,
+)
+from transformers.utils import (
+    add_start_docstrings_to_model_forward,
+    replace_return_docstrings,
+)
+
+from liger_kernel.transformers.fused_linear_cross_entropy import (
+    LigerFusedLinearCrossEntropyLoss,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def lce_forward_deprecated(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[HybridCache] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, GemmaForCausalLM
+    >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b")
+    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
+    >>> prompt = "What is your favorite condiment?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "What is your favorite condiment?"
+    ```"""
+
+    if self.training and self.config._attn_implementation != "eager":
+        logger.warning_once(
+            "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
+            f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
+        )
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+
+    loss = None
+    logits = None
+
+    if self.training and (labels is not None):
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten
+
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        lce = LigerFusedLinearCrossEntropyLoss(
+            softcap=self.config.final_logit_softcapping
+        )
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+
+    else:
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        logits = self.lm_head(hidden_states)
+        if self.config.final_logit_softcapping is not None:
+            logits = logits / self.config.final_logit_softcapping
+            logits = torch.tanh(logits)
+            logits = logits * self.config.final_logit_softcapping
+
+        loss = None
+        if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
+
+
+@add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
+@replace_return_docstrings(
+    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+)
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[HybridCache] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    num_logits_to_keep: int = 0,
+    **loss_kwargs,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        num_logits_to_keep (`int`, *optional*):
+            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, GemmaForCausalLM
+
+    >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b")
+    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
+
+    >>> prompt = "What is your favorite condiment?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "What is your favorite condiment?"
+    ```"""
+
+    if self.training and self.config._attn_implementation != "eager":
+        logger.warning_once(
+            "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
+            f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
+        )
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+
+    logits = None
+    loss = None
+    # if in training mode, don't materialize logits
+    if self.training and (labels is not None):
+        # We do the same thing as ForCausalLMLoss but using Liger FLCE
+
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten tokens
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
+        lce = LigerFusedLinearCrossEntropyLoss(
+            softcap=self.config.final_logit_softcapping,
+            reduction=reduction,
+        )
+
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+        if reduction == "sum":
+            loss /= loss_kwargs["num_items_in_batch"]
+
+    else:  # if in inference mode materialize logits
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        if self.config.final_logit_softcapping is not None:
+            logits = logits / self.config.final_logit_softcapping
+            logits = torch.tanh(logits)
+            logits = logits * self.config.final_logit_softcapping
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
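In training mode `lce_forward` never calls `self.lm_head` on the whole sequence: it hands the hidden states, the head weight, and `config.final_logit_softcapping` to the fused loss, while the inference branch still materializes and soft-caps logits the way upstream Gemma-2 does. The two paths should agree numerically; a hedged equivalence sketch on toy shapes (standalone, does not need transformers, assumes a CUDA device):

import torch
import torch.nn.functional as F
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

torch.manual_seed(0)
hidden_size, vocab_size, cap = 64, 128, 30.0
weight = torch.randn(vocab_size, hidden_size, device="cuda")
hidden = torch.randn(8, hidden_size, device="cuda")
labels = torch.randint(0, vocab_size, (8,), device="cuda")

# Reference path: materialize logits, soft-cap them, then plain cross entropy
logits = cap * torch.tanh((hidden @ weight.T) / cap)
ref = F.cross_entropy(logits, labels)

# Fused path: no logits tensor, capping happens inside the kernel
fused = LigerFusedLinearCrossEntropyLoss(softcap=cap)(weight, hidden, labels)

assert torch.allclose(ref, fused, atol=1e-4)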
{liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/monkey_patch.py
@@ -14,6 +14,10 @@ from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
 from liger_kernel.transformers.model.gemma import (
     lce_forward_deprecated as gemma_lce_forward_deprecated,
 )
+from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
+from liger_kernel.transformers.model.gemma2 import (
+    lce_forward_deprecated as gemma2_lce_forward_deprected,
+)
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
 from liger_kernel.transformers.model.llama import (
     lce_forward_deprecated as llama_lce_forward_deprecated,
@@ -252,7 +256,7 @@ def apply_liger_kernel_to_mistral(
     Apply Liger kernels to replace original implementation in HuggingFace Mistral models
 
     Args:
-        rope (bool): Whether to apply Liger's rotary position embedding. Default is
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
         fused_linear_cross_entropy (bool):
             Whether to apply Liger's fused linear cross entropy loss. Default is True.
@@ -445,7 +449,8 @@ def apply_liger_kernel_to_gemma(
 
 def apply_liger_kernel_to_gemma2(
     rope: bool = True,
-    cross_entropy: bool =
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
     geglu: bool = True,
     model: PreTrainedModel = None,
@@ -456,12 +461,19 @@ def apply_liger_kernel_to_gemma2(
 
     Args:
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
-        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
         geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Default is None.
     """
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
     from transformers.models.gemma2 import modeling_gemma2
     from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
 
@@ -479,6 +491,12 @@ def apply_liger_kernel_to_gemma2(
         modeling_gemma2.Gemma2RMSNorm = LigerRMSNormForGemma2
     if cross_entropy:
         modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
     if geglu:
         modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
 
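Wiring in monkey_patch.py follows the existing pattern: call the patcher before building the model so `Gemma2ForCausalLM.forward` is swapped for `lce_forward` (or the deprecated variant on older transformers releases). A hedged usage sketch (assumes the public re-export from `liger_kernel.transformers`, as with the other `apply_liger_kernel_to_*` helpers):

import transformers
from liger_kernel.transformers import apply_liger_kernel_to_gemma2

# Patch the Gemma2 modules in place; with the defaults, RoPE, RMSNorm, GeGLU and the
# fused linear cross entropy forward are swapped in, and plain cross entropy is left alone.
apply_liger_kernel_to_gemma2()

model = transformers.AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b", attn_implementation="eager"
)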
{liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel_nightly.egg-info/SOURCES.txt
@@ -37,6 +37,7 @@ src/liger_kernel/transformers/trainer_integration.py
 src/liger_kernel/transformers/experimental/embedding.py
 src/liger_kernel/transformers/model/__init__.py
 src/liger_kernel/transformers/model/gemma.py
+src/liger_kernel/transformers/model/gemma2.py
 src/liger_kernel/transformers/model/llama.py
 src/liger_kernel/transformers/model/mistral.py
 src/liger_kernel/transformers/model/mixtral.py