liger-kernel 0.5.5__py3-none-any.whl → 0.5.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +17 -2
- liger_kernel/chunked_loss/fused_linear_ppo.py +346 -0
- liger_kernel/chunked_loss/grpo_loss.py +134 -60
- liger_kernel/chunked_loss/jsd_loss.py +12 -7
- liger_kernel/ops/cross_entropy.py +3 -2
- liger_kernel/ops/dyt.py +225 -0
- liger_kernel/ops/fused_linear_jsd.py +2 -1
- liger_kernel/ops/jsd.py +32 -12
- liger_kernel/ops/kl_div.py +15 -8
- liger_kernel/ops/layer_norm.py +14 -1
- liger_kernel/ops/rms_norm.py +12 -1
- liger_kernel/transformers/__init__.py +133 -15
- liger_kernel/transformers/dyt.py +20 -0
- liger_kernel/transformers/functional.py +5 -0
- liger_kernel/transformers/gema3_rms.py +8 -0
- liger_kernel/transformers/model/gemma.py +17 -20
- liger_kernel/transformers/model/gemma2.py +17 -21
- liger_kernel/transformers/model/gemma3.py +335 -0
- liger_kernel/transformers/model/llama.py +17 -19
- liger_kernel/transformers/model/llava.py +369 -0
- liger_kernel/transformers/model/loss_utils.py +64 -0
- liger_kernel/transformers/model/mistral.py +28 -25
- liger_kernel/transformers/model/mixtral.py +20 -26
- liger_kernel/transformers/model/mllama.py +17 -19
- liger_kernel/transformers/model/olmo2.py +17 -20
- liger_kernel/transformers/model/paligemma.py +397 -0
- liger_kernel/transformers/model/phi3.py +17 -19
- liger_kernel/transformers/model/qwen2.py +17 -19
- liger_kernel/transformers/model/qwen2_5_vl.py +9 -10
- liger_kernel/transformers/model/qwen2_vl.py +9 -10
- liger_kernel/transformers/monkey_patch.py +392 -13
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/METADATA +11 -6
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/RECORD +38 -31
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/WHEEL +1 -1
- liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -240
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info/licenses}/LICENSE +0 -0
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info/licenses}/NOTICE +0 -0
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/top_level.txt +0 -0
liger_kernel/ops/kl_div.py
CHANGED

@@ -6,6 +6,7 @@ import triton.language as tl
 
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device
 
 
 def get_num_warps(BLOCK_SIZE):

@@ -115,9 +116,12 @@ def _kldiv_kernel_backward(
 
 def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps):  # [BT, V]
     BT, V = y_pred.shape
-
-
-
+    BLOCK_SIZE = (
+        min(8192, triton.next_power_of_2(V))
+        if infer_device() == "xpu"
+        else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+    )
+    num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
 
     grid = (BT,)
     reduction = _str_to_reduction_mode[reduction]

@@ -155,9 +159,12 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
 
 def kldiv_backward_triton(target, grad_output, new_grads, log_target):
     BT, V = target.shape
-
-
-
+    BLOCK_SIZE = (
+        min(8192, triton.next_power_of_2(V))
+        if infer_device() == "xpu"
+        else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+    )
+    num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
 
     grid = (BT,)
 

@@ -185,9 +192,9 @@ class LigerKLDivLossFunction(torch.autograd.Function):
     Class implementing the forward and backward pass for the KL Divergence Loss using Triton, as defined by the following formula:
     ```python
     if log_target:
-        loss = target * (target.log() - input)
-    else:
         loss = target.exp() * (target - input)
+    else:
+        loss = target * (target.log() - input)
     ```,
     then the loss is reduced according to the `reduction` parameter.
     as defined in the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
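The new launch logic above picks the Triton block size and warp count per device: on XPU the block is capped at 8192 columns with a fixed 32 warps, otherwise it falls back to `MAX_FUSED_SIZE` and `get_num_warps`. A minimal sketch of that selection, with a stand-in constant and warp heuristic (assumptions, re-implemented so the snippet runs without Triton or an XPU):

```python
# Sketch of the device-dependent launch configuration added in this release.
# MAX_FUSED_SIZE and the warp heuristic mirror the kernel code above but are
# stand-ins here; the real values live in liger_kernel/ops/kl_div.py.

MAX_FUSED_SIZE = 65536 // 4  # assumption: constant used by the CUDA path


def next_power_of_2(n: int) -> int:
    # Equivalent of triton.next_power_of_2, for illustration only.
    return 1 << (n - 1).bit_length()


def get_num_warps(block_size: int) -> int:
    # Simplified stand-in for the kernel's warp heuristic.
    if block_size >= 32768:
        return 32
    if block_size >= 8192:
        return 16
    if block_size >= 2048:
        return 8
    return 4


def launch_config(vocab_size: int, device: str) -> tuple[int, int]:
    """Return (BLOCK_SIZE, num_warps) the way the updated kl_div kernels do."""
    if device == "xpu":
        return min(8192, next_power_of_2(vocab_size)), 32
    block_size = min(MAX_FUSED_SIZE, next_power_of_2(vocab_size))
    return block_size, get_num_warps(block_size)


print(launch_config(32000, "cuda"))  # (16384, 16) with these stand-in heuristics
print(launch_config(32000, "xpu"))   # (8192, 32)
```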
liger_kernel/ops/layer_norm.py
CHANGED

@@ -154,6 +154,11 @@ def layer_norm_forward(X, W, B, eps):
         f"must match weight size (W.shape[0]={W.shape[0]})"
     )
 
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args["grf_mode"] = "large"
+
     _layer_norm_forward_kernel[(n_rows,)](
         Y,
         Y.stride(0),

@@ -171,6 +176,7 @@ def layer_norm_forward(X, W, B, eps):
         eps,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
+        **kernel_args,  # XPU-specific optimization
     )
     return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
 

@@ -185,7 +191,7 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
     if X.device.type == "cuda":
         sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
     elif X.device.type == "xpu":
-        sm_count = torch.xpu.get_device_properties(X.device).
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
 
     DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
     _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)

@@ -208,6 +214,12 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
         if X.dtype == torch.float16
         else tl.float32  # fallback to float32 for other types
     )
+
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
+
     _layer_norm_backward_kernel[grid](
         X,
         W,

@@ -227,6 +239,7 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
         rows_per_program,
         BLOCK_SIZE=BLOCK_SIZE,
         dtype=triton_dtype,
+        **kernel_args,  # XPU-specific optimization
     )
 
     DW = _DW.sum(dim=0).to(W.dtype)
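Both launches now build an optional `kernel_args` dict and splat it into the Triton call, so XPU-only options such as `grf_mode` (plus `num_warps`/`num_stages` overrides in the backward) are only passed when the tensors live on an XPU. A small sketch of the pattern; the `launch` callable below is a hypothetical stand-in for a Triton kernel launch:

```python
import torch


def xpu_launch_overrides(x: torch.Tensor, backward: bool = False) -> dict:
    """Build the extra launch kwargs the way the updated layer_norm code does.

    On non-XPU devices the dict stays empty, so the launch is unchanged.
    """
    kernel_args: dict = {}
    if x.device.type == "xpu":
        kernel_args["grf_mode"] = "large"  # request the large register file
        if backward:
            kernel_args.update({"num_warps": 32, "num_stages": 4})
    return kernel_args


# Usage sketch: `launch` stands in for something like
# _layer_norm_forward_kernel[(n_rows,)](...); overrides are splatted in last.
def launch(**kwargs):
    print(sorted(kwargs))


x = torch.randn(4, 8)  # CPU tensor here, so no overrides are added
launch(BLOCK_SIZE=1024, num_warps=8, **xpu_launch_overrides(x))
```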
liger_kernel/ops/rms_norm.py
CHANGED

@@ -223,6 +223,10 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
     # Check constraints.
     assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
 
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args["grf_mode"] = "large"
     _rms_norm_forward_kernel[(n_rows,)](
         Y,
         Y.stride(0),

@@ -238,6 +242,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
         casting_mode,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
+        **kernel_args,  # XPU-specific optimization
     )
     return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
 

@@ -252,7 +257,7 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
     if X.device.type == "cuda":
         sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
     elif X.device.type == "xpu":
-        sm_count = torch.xpu.get_device_properties(X.device).
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
 
     # fp32 for numerical stability especially.
     _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)

@@ -267,6 +272,11 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
     else:
         dX = torch.zeros_like(dY)
 
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args["grf_mode"] = "large"
+
     _rms_norm_backward_kernel[grid](
         dY,
         dY.stride(0),

@@ -288,6 +298,7 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
         casting_mode,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
+        **kernel_args,  # XPU-specific optimization
     )
     dX = dX.view(*shape)
     dW = _dW.sum(dim=0).to(W.dtype)
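As in the layer norm backward, the weight gradient is accumulated into one partial row per streaming multiprocessor (CUDA, `multi_processor_count`) or execution unit (XPU, now `gpu_eu_count`) and then reduced with `_dW.sum(dim=0)`. A toy illustration of that partial-reduction step, with made-up sizes standing in for the device property:

```python
import torch

# Each program (one per SM or EU) owns one row of the fp32 partial buffer,
# and the final weight gradient is the row-sum cast back to the weight dtype.


def reduce_partial_dw(per_program_grads: torch.Tensor, weight_dtype: torch.dtype) -> torch.Tensor:
    """per_program_grads: (sm_count, n_cols) fp32 partials -> (n_cols,) dW."""
    return per_program_grads.sum(dim=0).to(weight_dtype)


sm_count, n_cols = 8, 16                       # stand-in for the real device property
_dW = torch.randn(sm_count, n_cols)            # what the Triton kernel would have filled
dW = reduce_partial_dw(_dW, torch.bfloat16)
print(dW.shape)  # torch.Size([16])
```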
liger_kernel/transformers/__init__.py
CHANGED

@@ -1,27 +1,145 @@
-
+import importlib
+
+from typing import TYPE_CHECKING
+
+# Always-safe imports (independent of 'transformers')
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss  # noqa: F401
+from liger_kernel.transformers.dyt import LigerDyT  # noqa: F401
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss  # noqa: F401
 from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD  # noqa: F401
 from liger_kernel.transformers.geglu import LigerGEGLUMLP  # noqa: F401
 from liger_kernel.transformers.jsd import LigerJSD  # noqa: F401
 from liger_kernel.transformers.layer_norm import LigerLayerNorm  # noqa: F401
-from liger_kernel.transformers.monkey_patch import _apply_liger_kernel  # noqa: F401
-from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance  # noqa: F401
-from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma  # noqa: F401
-from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2  # noqa: F401
-from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite  # noqa: F401
-from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama  # noqa: F401
-from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral  # noqa: F401
-from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral  # noqa: F401
-from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama  # noqa: F401
-from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2  # noqa: F401
-from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3  # noqa: F401
-from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2  # noqa: F401
-from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl  # noqa: F401
-from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl  # noqa: F401
 from liger_kernel.transformers.rms_norm import LigerRMSNorm  # noqa: F401
 from liger_kernel.transformers.rope import liger_rotary_pos_emb  # noqa: F401
 from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP  # noqa: F401
 from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP  # noqa: F401
 from liger_kernel.transformers.swiglu import LigerSwiGLUMLP  # noqa: F401
 from liger_kernel.transformers.tvd import LigerTVDLoss  # noqa: F401
+
+# Static-only imports for IDEs and type checkers
+if TYPE_CHECKING:
+    from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import _apply_liger_kernel  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_paligemma  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl  # noqa: F401
+
+
+# Check if 'transformers' is installed
+try:
+    import transformers  # noqa: F401
+
+    _TRANSFORMERS_AVAILABLE = True
+except ImportError:
+    _TRANSFORMERS_AVAILABLE = False
+
+
+def is_transformers_available() -> bool:
+    """
+    Returns True if the 'transformers' package is available.
+    Useful for conditional logic in downstream code.
+    """
+    return _TRANSFORMERS_AVAILABLE
+
+
+def __getattr__(name: str):
+    """
+    Handles lazy access to transformer-dependent attributes.
+    If 'transformers' is not installed, raises a user-friendly ImportError.
+    """
+    if not _TRANSFORMERS_AVAILABLE:
+        raise ImportError(
+            f"The attribute '{name}' requires the 'transformers' library, which is not installed.\n"
+            f"Please install it with `pip install transformers` to use this functionality."
+        )
+
+    if name == "AutoLigerKernelForCausalLM":
+        module = importlib.import_module("liger_kernel.transformers.auto_model")
+        return getattr(module, name)
+
+    monkey_patch_symbols = {
+        "_apply_liger_kernel",
+        "_apply_liger_kernel_to_instance",
+        "apply_liger_kernel_to_gemma",
+        "apply_liger_kernel_to_gemma2",
+        "apply_liger_kernel_to_gemma3",
+        "apply_liger_kernel_to_gemma3_text",
+        "apply_liger_kernel_to_granite",
+        "apply_liger_kernel_to_llama",
+        "apply_liger_kernel_to_llava",
+        "apply_liger_kernel_to_mistral",
+        "apply_liger_kernel_to_mixtral",
+        "apply_liger_kernel_to_mllama",
+        "apply_liger_kernel_to_olmo2",
+        "apply_liger_kernel_to_paligemma",
+        "apply_liger_kernel_to_phi3",
+        "apply_liger_kernel_to_qwen2",
+        "apply_liger_kernel_to_qwen2_5_vl",
+        "apply_liger_kernel_to_qwen2_vl",
+    }
+
+    if name in monkey_patch_symbols:
+        module = importlib.import_module("liger_kernel.transformers.monkey_patch")
+        return getattr(module, name)
+
+    raise AttributeError(f"module {__name__} has no attribute {name}")
+
+
+# Shared symbols in all environments
+__all__ = [
+    "is_transformers_available",
+    "LigerCrossEntropyLoss",
+    "LigerDyT",
+    "LigerFusedLinearCrossEntropyLoss",
+    "LigerFusedLinearJSD",
+    "LigerGEGLUMLP",
+    "LigerJSD",
+    "LigerLayerNorm",
+    "LigerRMSNorm",
+    "liger_rotary_pos_emb",
+    "LigerBlockSparseTop2MLP",
+    "LigerPhi3SwiGLUMLP",
+    "LigerSwiGLUMLP",
+    "LigerTVDLoss",
+]
+
+# Add transformer-dependent symbols only if available
+if _TRANSFORMERS_AVAILABLE:
+    __all__.extend(
+        [
+            "AutoLigerKernelForCausalLM",
+            "_apply_liger_kernel",
+            "_apply_liger_kernel_to_instance",
+            "apply_liger_kernel_to_gemma",
+            "apply_liger_kernel_to_gemma2",
+            "apply_liger_kernel_to_gemma3",
+            "apply_liger_kernel_to_gemma3_text",
+            "apply_liger_kernel_to_granite",
+            "apply_liger_kernel_to_llama",
+            "apply_liger_kernel_to_llava",
+            "apply_liger_kernel_to_mistral",
+            "apply_liger_kernel_to_mixtral",
+            "apply_liger_kernel_to_mllama",
+            "apply_liger_kernel_to_olmo2",
+            "apply_liger_kernel_to_paligemma",
+            "apply_liger_kernel_to_phi3",
+            "apply_liger_kernel_to_qwen2",
+            "apply_liger_kernel_to_qwen2_5_vl",
+            "apply_liger_kernel_to_qwen2_vl",
+        ]
+    )
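The rewritten `__init__.py` keeps the pure-Triton kernels as eager imports and defers everything that needs `transformers` (the `apply_liger_kernel_to_*` patches and `AutoLigerKernelForCausalLM`) behind a module-level `__getattr__` (PEP 562). A usage sketch of the resulting behaviour; the patch flags shown are the library's usual keyword arguments, assumed unchanged here:

```python
# Lazy-import behaviour of the new __init__.py: core kernels import eagerly,
# anything that needs `transformers` is resolved on first attribute access.

import liger_kernel.transformers as lkt

print(lkt.is_transformers_available())

if lkt.is_transformers_available():
    # Resolved lazily via __getattr__ -> liger_kernel.transformers.monkey_patch
    lkt.apply_liger_kernel_to_llama(rope=True, rms_norm=True, swiglu=True)
else:
    # Without `transformers`, touching these names raises the friendly
    # ImportError from the diff instead of failing at import time.
    pass
```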
liger_kernel/transformers/dyt.py
ADDED

@@ -0,0 +1,20 @@
+import torch
+import torch.nn as nn
+
+from liger_kernel.ops.dyt import LigerDyTFunction
+
+
+class LigerDyT(nn.Module):
+    def __init__(self, hidden_size, init_alpha=0.5):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.init_alpha = init_alpha
+        self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
+        self.gamma = nn.Parameter(torch.ones(hidden_size))
+        self.beta = nn.Parameter(torch.zeros(hidden_size))
+
+    def forward(self, x):
+        return LigerDyTFunction.apply(x, self.alpha, self.gamma, self.beta)
+
+    def extra_repr(self):
+        return f"{self.hidden_size}, init_alpha={self.init_alpha}"
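`LigerDyT` exposes the fused Triton op through a module whose parameters follow Dynamic Tanh, i.e. `gamma * tanh(alpha * x) + beta`. A plain-PyTorch reference of that formula (a sketch for comparison, not the fused kernel) that the module's forward is expected to match:

```python
import torch


def dyt_reference(x: torch.Tensor,
                  alpha: torch.Tensor,
                  gamma: torch.Tensor,
                  beta: torch.Tensor) -> torch.Tensor:
    # Dynamic Tanh: an elementwise, normalization-free transform
    #   y = gamma * tanh(alpha * x) + beta
    # LigerDyTFunction is the fused Triton version of this computation.
    return gamma * torch.tanh(alpha * x) + beta


hidden_size = 64
x = torch.randn(2, 10, hidden_size)
alpha = torch.full((1,), 0.5)   # matches init_alpha=0.5 above
gamma = torch.ones(hidden_size)
beta = torch.zeros(hidden_size)
print(dyt_reference(x, alpha, gamma, beta).shape)  # torch.Size([2, 10, 64])
```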
liger_kernel/transformers/functional.py
CHANGED

@@ -1,6 +1,7 @@
 from typing import Optional
 
 from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
+from liger_kernel.ops.dyt import LigerDyTFunction
 from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
 from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
 from liger_kernel.ops.geglu import LigerGELUMulFunction

@@ -192,3 +193,7 @@ def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 
 def liger_swiglu(a, b):
     return LigerSiLUMulFunction.apply(a, b)
+
+
+def liger_dyt(x, alpha, gamma, beta):
+    return LigerDyTFunction.apply(x, alpha, gamma, beta)
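`liger_dyt` is the functional counterpart of the `LigerDyT` module: the caller owns the four tensors and no module state is involved. A short usage sketch, assuming the installed wheel and a CUDA or XPU device, since the underlying op is a Triton kernel:

```python
import torch

from liger_kernel.transformers.functional import liger_dyt

# Functional form of LigerDyT: the caller owns the parameters.
hidden_size = 64
device = "cuda"  # the Triton op needs a GPU/XPU device
x = torch.randn(2, 10, hidden_size, device=device, requires_grad=True)
alpha = torch.full((1,), 0.5, device=device, requires_grad=True)
gamma = torch.ones(hidden_size, device=device, requires_grad=True)
beta = torch.zeros(hidden_size, device=device, requires_grad=True)

y = liger_dyt(x, alpha, gamma, beta)
y.sum().backward()  # gradients flow back to x, alpha, gamma, beta
print(y.shape)
```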
liger_kernel/transformers/gema3_rms.py
ADDED

@@ -0,0 +1,8 @@
+from .rms_norm import LigerRMSNorm
+
+
+class LigerRMSNormForGemma3(LigerRMSNorm):
+    """Gemma3RMSNorm has a dim argument not hidden_size used in q_norm and k_norm."""
+
+    def __init__(self, dim, eps=0.000001, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False):
+        super().__init__(dim, eps, offset, casting_mode, init_fn, in_place)
liger_kernel/transformers/model/gemma.py
CHANGED

@@ -12,8 +12,10 @@ from transformers.models.gemma.modeling_gemma import _CONFIG_FOR_DOC
 from transformers.models.gemma.modeling_gemma import GEMMA_INPUTS_DOCSTRING
 from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
 @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)

@@ -126,6 +128,7 @@ def lce_forward_deprecated(
     )
 
 
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
 @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(

@@ -141,7 +144,7 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-
+    logits_to_keep: Union[int, torch.Tensor] = 0,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""

@@ -151,10 +154,12 @@
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
-
-
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
 
         Returns:
 

@@ -200,24 +205,16 @@
     loss = None
     # if in training mode, don't materialize logits
     if self.training and (labels is not None):
-
-
-
-
-
-
-
-        shift_labels = shift_labels.view(-1)
-
-        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
-        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
-
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-        if reduction == "sum":
-            loss /= loss_kwargs["num_items_in_batch"]
-
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            hidden_size=self.config.hidden_size,
+            **loss_kwargs,
+        )
     else:  # if in inference mode materialize logits
-
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
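The inference branch now honours `logits_to_keep`: an `int` becomes a tail slice over the sequence dimension (0 keeps everything, because `slice(-0, None)` equals `slice(0, None)`), while a 1D tensor selects explicit positions. A self-contained sketch of just that slicing logic:

```python
import torch


def select_positions(hidden_states: torch.Tensor, logits_to_keep) -> torch.Tensor:
    """Mirror of the new inference-branch slicing for logits_to_keep.

    An int keeps the last `logits_to_keep` positions (0 keeps all); a 1D tensor
    selects explicit indices, e.g. the final token of each packed sequence.
    """
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return hidden_states[:, slice_indices, :]


h = torch.randn(2, 7, 16)                                # (batch, seq, hidden)
print(select_positions(h, 0).shape)                      # torch.Size([2, 7, 16]) - all positions
print(select_positions(h, 1).shape)                      # torch.Size([2, 1, 16]) - last token only
print(select_positions(h, torch.tensor([2, 5])).shape)   # torch.Size([2, 2, 16])
```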
liger_kernel/transformers/model/gemma2.py
CHANGED

@@ -13,8 +13,10 @@ from transformers.models.gemma2.modeling_gemma2 import _CONFIG_FOR_DOC
 from transformers.models.gemma2.modeling_gemma2 import GEMMA2_INPUTS_DOCSTRING
 from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 logger = logging.getLogger(__name__)
 

@@ -133,6 +135,7 @@ def lce_forward_deprecated(
     )
 
 
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
 @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(

@@ -148,7 +151,7 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-
+    logits_to_keep: Union[int, torch.Tensor] = 0,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""

@@ -158,10 +161,12 @@
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
-
-
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
 
         Returns:
 

@@ -212,27 +217,18 @@
     loss = None
     # if in training mode, don't materialize logits
    if self.training and (labels is not None):
-
-
-
-
-
-
-
-        shift_labels = shift_labels.view(-1)
-
-        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
-        lce = LigerFusedLinearCrossEntropyLoss(
-            softcap=self.config.final_logit_softcapping,
-            reduction=reduction,
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            hidden_size=self.config.hidden_size,
+            final_logit_softcapping=self.config.final_logit_softcapping,
+            **loss_kwargs,
         )
 
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-        if reduction == "sum":
-            loss /= loss_kwargs["num_items_in_batch"]
-
     else:  # if in inference mode materialize logits
-
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
         if self.config.final_logit_softcapping is not None:
             logits = logits / self.config.final_logit_softcapping
             logits = torch.tanh(logits)