liger-kernel-nightly 0.5.5.dev20250322021112__py3-none-any.whl → 0.5.5.dev20250326012054__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic. Click here for more details.
- liger_kernel/ops/dyt.py +225 -0
- liger_kernel/transformers/__init__.py +1 -0
- liger_kernel/transformers/dyt.py +20 -0
- liger_kernel/transformers/functional.py +5 -0
- liger_kernel/transformers/monkey_patch.py +24 -12
- {liger_kernel_nightly-0.5.5.dev20250322021112.dist-info → liger_kernel_nightly-0.5.5.dev20250326012054.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.5.dev20250322021112.dist-info → liger_kernel_nightly-0.5.5.dev20250326012054.dist-info}/RECORD +11 -9
- {liger_kernel_nightly-0.5.5.dev20250322021112.dist-info → liger_kernel_nightly-0.5.5.dev20250326012054.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250322021112.dist-info → liger_kernel_nightly-0.5.5.dev20250326012054.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250322021112.dist-info → liger_kernel_nightly-0.5.5.dev20250326012054.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.5.dev20250322021112.dist-info → liger_kernel_nightly-0.5.5.dev20250326012054.dist-info}/top_level.txt +0 -0
liger_kernel/ops/dyt.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
import operator
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import triton
|
|
5
|
+
import triton.language as tl
|
|
6
|
+
|
|
7
|
+
from liger_kernel.ops.utils import calculate_settings
|
|
8
|
+
from liger_kernel.ops.utils import compare_version
|
|
9
|
+
from liger_kernel.ops.utils import ensure_contiguous
|
|
10
|
+
from liger_kernel.ops.utils import infer_device
|
|
11
|
+
|
|
12
|
+
if compare_version("triton", operator.ge, "3.0.0"):
|
|
13
|
+
try:
|
|
14
|
+
# typical import path with dispatch available
|
|
15
|
+
from triton.language.extra.libdevice import tanh
|
|
16
|
+
except ModuleNotFoundError:
|
|
17
|
+
# for working with NGC containers
|
|
18
|
+
from triton.language.extra.cuda.libdevice import tanh
|
|
19
|
+
else:
|
|
20
|
+
from triton.language.math import tanh
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@triton.jit
def _dyt_fwd_kernel(
    x_ptr,
    x_row_stride,
    alpha_ptr,
    gamma_ptr,
    beta_ptr,
    y_ptr,
    y_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Forward pass of Dynamic Tanh (DyT): y = gamma * tanh(alpha * x) + beta.

    Reference:
    https://arxiv.org/abs/2503.10622

    Program ``program_id(0)`` handles one row; the whole row is processed as a
    single BLOCK_SIZE-wide tile, lanes with index >= n_cols are masked off.

    Shapes:
        - x: (BT, C)
        - alpha: (1)
        - gamma: (C)
        - beta: (C)
        - y: (BT, C)
    """
    # Promote the row index to int64 before multiplying by the row stride:
    # with an int32 product the offset overflows once BT * C >= 2**31.
    row_idx = tl.program_id(0).to(tl.int64)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_cols

    x_ptr += row_idx * x_row_stride
    y_ptr += row_idx * y_row_stride

    # Upcast every operand to fp32 *before* multiplying, so alpha * x is not
    # first rounded in a low-precision input dtype (e.g. bf16).
    alpha = tl.load(alpha_ptr).to(tl.float32)
    gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    beta = tl.load(beta_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)

    y = gamma * tanh(alpha * x) + beta
    # tl.store converts the fp32 result to the element type of y_ptr.
    tl.store(y_ptr + offsets, y, mask=mask)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@triton.jit
def _dyt_bwd_kernel(
    x_ptr,
    x_row_stride,
    dy_ptr,
    dy_row_stride,
    dx_ptr,
    dx_row_stride,
    alpha_ptr,
    dalpha_ptr,
    gamma_ptr,
    dgamma_ptr,
    dgamma_row_stride,
    n_cols,
    n_rows,
    ROWS_PER_PROGRAM,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Backward pass of Dynamic Tanh (DyT).

    Reference:
    https://arxiv.org/abs/2503.10622

    Program ``pid`` walks the rows
    [pid * ROWS_PER_PROGRAM, min((pid + 1) * ROWS_PER_PROGRAM, n_rows)),
    writes dx for each of them, and stores its partial sums of dalpha and
    dgamma into slot ``pid`` of the scratch buffers (reduced by the caller).

    ROWS_PER_PROGRAM is a runtime argument rather than tl.constexpr: it is
    derived from the number of rows, and a constexpr would force a fresh
    kernel compilation for every new value.

    Shapes:
        - x: (BT, C)
        - alpha: (1)
        - gamma: (C)
        - dx: (BT, C)
        - dy: (BT, C)
        - dgamma: (num_programs, C)   (the host launches sm_count programs)
        - dalpha: (num_programs,)
    """
    # d(gamma * tanh(alpha * x) + beta) / dx
    #     = gamma * (1 - tanh^2(alpha * x)) * alpha
    # d(gamma * tanh(alpha * x) + beta) / dalpha
    #     = gamma * (1 - tanh^2(alpha * x)) * x
    # d(gamma * tanh(alpha * x) + beta) / dgamma
    #     = tanh(alpha * x)
    # d(gamma * tanh(alpha * x)) / dbeta = 1
    pid = tl.program_id(0)

    row_start = pid * ROWS_PER_PROGRAM
    row_end = min((pid + 1) * ROWS_PER_PROGRAM, n_rows)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_cols

    dalpha = 0.0
    dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    # int64 offset: row_start * stride can exceed the int32 range for large
    # activations. Later per-row increments accumulate in the 64-bit pointer.
    row_offset = row_start.to(tl.int64)
    x_ptr += row_offset * x_row_stride
    dx_ptr += row_offset * dx_row_stride
    dy_ptr += row_offset * dy_row_stride
    # Upcast before any arithmetic so products such as dy * gamma are not
    # rounded in a low-precision input dtype.
    alpha = tl.load(alpha_ptr).to(tl.float32)
    gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0).to(tl.float32)

    for _ in tl.range(row_start, row_end):
        dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
        x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
        tanh_ax = tanh(alpha * x)
        sech2_ax = 1 - tanh_ax * tanh_ax

        # Shared factor of the dx and dalpha terms.
        dy_gamma_sech2 = dy * gamma * sech2_ax
        dx = dy_gamma_sech2 * alpha
        dalpha += tl.sum(dy_gamma_sech2 * x)
        dgamma += dy * tanh_ax
        # tl.store converts dx to the element type of dx_ptr.
        tl.store(dx_ptr + offsets, dx, mask=mask)

        dy_ptr += dy_row_stride
        x_ptr += x_row_stride
        dx_ptr += dx_row_stride

    tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
    tl.store(dalpha_ptr + pid, dalpha)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def liger_dyt_fwd(x, alpha, gamma, beta):
    """
    Launch the DyT forward kernel.

    ``x`` is viewed as a 2-D (rows, C) matrix, C being its last dimension,
    and one Triton program is launched per row. The output is allocated with
    ``torch.empty_like``, so it has the same shape and dtype as ``x``.
    """
    in_shape = x.shape
    x_2d = x.view(-1, in_shape[-1])
    num_rows, num_cols = x_2d.shape
    out_2d = torch.empty_like(x_2d)
    block_size, warps = calculate_settings(num_cols)

    launch_grid = (num_rows,)
    _dyt_fwd_kernel[launch_grid](
        x_ptr=x_2d,
        x_row_stride=x_2d.stride(0),
        alpha_ptr=alpha,
        gamma_ptr=gamma,
        beta_ptr=beta,
        y_ptr=out_2d,
        y_row_stride=out_2d.stride(0),
        n_cols=num_cols,
        BLOCK_SIZE=block_size,
        num_warps=warps,
    )
    return out_2d.view(*in_shape)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def liger_dyt_bwd(dy, x, alpha, gamma):
    """
    Launch the DyT backward kernel and reduce its per-program partial sums.

    Args:
        dy: upstream gradient, same shape as ``x``.
        x: forward input; its last dimension is the feature dimension C.
        alpha: scalar scale parameter (LigerDyT creates it with shape (1,)).
        gamma: per-feature scale of shape (C,).

    Returns:
        (dx, dalpha, dgamma, dbeta): ``dx`` has the shape and dtype of ``x``;
        ``dalpha`` has shape (1,) and alpha's dtype; ``dgamma`` and ``dbeta``
        have shape (C,) and gamma's dtype.

    Raises:
        RuntimeError: if the feature dimension does not fit in one block.
    """
    shape = dy.shape
    dim = shape[-1]
    dy = dy.view(-1, dim)
    x = x.view(-1, dim)
    n_rows, n_cols = dy.shape
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    # One program per SM (or XPU subslice); each reduces a contiguous run of rows.
    sm_count = 1
    device = infer_device()
    if device == "cuda":
        sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    elif device == "xpu":
        sm_count = torch.xpu.get_device_properties(x.device).gpu_subslice_count
    if n_cols > BLOCK_SIZE:
        raise RuntimeError(
            f"Feature dimension {dim} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
        )

    # dx is written directly in x's dtype (the kernel computes in fp32 and the
    # store narrows); an fp32 buffer would be twice as large for bf16/fp16.
    dx = torch.empty_like(x)
    # fp32 scratch for per-program partial sums; every program writes its slot.
    _dalpha = torch.empty((sm_count,), dtype=torch.float32, device=x.device)
    _dgamma = torch.empty((sm_count, n_cols), dtype=torch.float32, device=x.device)

    grid = (sm_count,)
    rows_per_program = triton.cdiv(n_rows, sm_count)
    _dyt_bwd_kernel[grid](
        x_ptr=x,
        x_row_stride=x.stride(0),
        dy_ptr=dy,
        dy_row_stride=dy.stride(0),
        dx_ptr=dx,
        dx_row_stride=dx.stride(0),
        alpha_ptr=alpha,
        dalpha_ptr=_dalpha,
        gamma_ptr=gamma,
        dgamma_ptr=_dgamma,
        dgamma_row_stride=_dgamma.stride(0),
        n_cols=n_cols,
        n_rows=n_rows,
        ROWS_PER_PROGRAM=rows_per_program,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    # Reduce in fp32 and cast each gradient to the dtype of the parameter it
    # belongs to; casting to x.dtype would needlessly round the gradients of
    # higher-precision parameters when x is bf16/fp16.
    dalpha = _dalpha.sum(dim=0, keepdim=True).to(alpha.dtype)
    dgamma = _dgamma.sum(dim=0).to(gamma.dtype)
    # beta is not passed in; assumes beta shares gamma's dtype (LigerDyT
    # creates both the same way) — TODO confirm for other callers.
    dbeta = dy.sum(dim=0, dtype=torch.float32).to(gamma.dtype)
    return dx.view(*shape), dalpha, dgamma, dbeta
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class LigerDyTFunction(torch.autograd.Function):
    """
    Autograd function for Dynamic Tanh: ``gamma * tanh(alpha * x) + beta``.

    The forward and backward passes are delegated to ``liger_dyt_fwd`` and
    ``liger_dyt_bwd``; both are wrapped in ``ensure_contiguous``.
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, x, alpha, gamma, beta):
        out = liger_dyt_fwd(x, alpha, gamma, beta)
        # beta is not saved: liger_dyt_bwd derives dbeta from grad_output alone.
        ctx.save_for_backward(x, alpha, gamma)
        return out

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output):
        saved_x, saved_alpha, saved_gamma = ctx.saved_tensors
        grad_x, grad_alpha, grad_gamma, grad_beta = liger_dyt_bwd(
            grad_output,
            saved_x,
            saved_alpha,
            saved_gamma,
        )
        # One gradient per forward input, in the order (x, alpha, gamma, beta).
        return grad_x, grad_alpha, grad_gamma, grad_beta
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM # noqa: F401
|
|
2
2
|
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss # noqa: F401
|
|
3
|
+
from liger_kernel.transformers.dyt import LigerDyT # noqa: F401
|
|
3
4
|
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss # noqa: F401
|
|
4
5
|
from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD # noqa: F401
|
|
5
6
|
from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
from liger_kernel.ops.dyt import LigerDyTFunction
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class LigerDyT(nn.Module):
    """
    Dynamic Tanh (DyT) layer computing ``gamma * tanh(alpha * x) + beta``.

    Reference: https://arxiv.org/abs/2503.10622

    Args:
        hidden_size: size C of the input's last dimension; ``gamma`` and
            ``beta`` are created with shape (C,).
        init_alpha: initial value of the learnable scalar ``alpha``
            (stored as a 1-element parameter).
    """

    def __init__(self, hidden_size, init_alpha=0.5):
        super().__init__()
        self.hidden_size = hidden_size
        self.init_alpha = init_alpha
        # Registration order (alpha, gamma, beta) fixes the parameter/state_dict order.
        self.alpha = nn.Parameter(init_alpha * torch.ones(1))
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.beta = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x):
        dyt_apply = LigerDyTFunction.apply
        return dyt_apply(x, self.alpha, self.gamma, self.beta)

    def extra_repr(self):
        return f"{self.hidden_size}, init_alpha={self.init_alpha}"
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from typing import Optional
|
|
2
2
|
|
|
3
3
|
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
|
|
4
|
+
from liger_kernel.ops.dyt import LigerDyTFunction
|
|
4
5
|
from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
|
|
5
6
|
from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
|
|
6
7
|
from liger_kernel.ops.geglu import LigerGELUMulFunction
|
|
@@ -192,3 +193,7 @@ def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
|
|
192
193
|
|
|
193
194
|
def liger_swiglu(a, b):
    """Apply ``LigerSiLUMulFunction`` to ``a`` and ``b`` (presumably SiLU(a) * b — confirm in ops.swiglu)."""
    silu_mul = LigerSiLUMulFunction.apply
    return silu_mul(a, b)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def liger_dyt(x, alpha, gamma, beta):
    """Functional Dynamic Tanh: ``gamma * tanh(alpha * x) + beta`` via ``LigerDyTFunction``."""
    dyt_apply = LigerDyTFunction.apply
    return dyt_apply(x, alpha, gamma, beta)
|
|
@@ -52,6 +52,7 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
|
|
|
52
52
|
module.in_place = in_place
|
|
53
53
|
_bind_method_to_module(module, "forward", LigerRMSNorm.forward)
|
|
54
54
|
_bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
|
|
55
|
+
module.__class__.__name__ = LigerRMSNorm.__name__
|
|
55
56
|
|
|
56
57
|
|
|
57
58
|
def _patch_layer_norm_module(module, eps=1e-6):
|
|
@@ -59,6 +60,17 @@ def _patch_layer_norm_module(module, eps=1e-6):
|
|
|
59
60
|
module.hidden_size = module.normalized_shape
|
|
60
61
|
_bind_method_to_module(module, "forward", LigerLayerNorm.forward)
|
|
61
62
|
_bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
|
|
63
|
+
module.__class__.__name__ = LigerLayerNorm.__name__
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _patch_swiglu_module(module, liger_module):
    """
    Patch a single MLP instance to use ``liger_module.forward``.

    Args:
        module: the MLP instance to patch.
        liger_module: the Liger class whose ``forward`` and name are used.
    """
    _bind_method_to_module(module, "forward", liger_module.forward)
    # Override the display name on this instance only: nn.Module.__repr__
    # calls self._get_name(). Assigning module.__class__.__name__ would rename
    # the shared class and relabel every instance of it, patched or not.
    liger_name = liger_module.__name__
    _bind_method_to_module(module, "_get_name", lambda self: liger_name)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _patch_geglu_module(module):
    """
    Patch a single GEGLU MLP instance to use ``LigerGEGLUMLP.forward``.

    Args:
        module: the MLP instance to patch.
    """
    _bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
    # Override the display name on this instance only: nn.Module.__repr__
    # calls self._get_name(). Assigning module.__class__.__name__ would rename
    # the shared class and relabel every instance of it, patched or not.
    _bind_method_to_module(module, "_get_name", lambda self: LigerGEGLUMLP.__name__)
|
|
62
74
|
|
|
63
75
|
|
|
64
76
|
def apply_liger_kernel_to_granite(
|
|
@@ -134,7 +146,7 @@ def apply_liger_kernel_to_granite(
|
|
|
134
146
|
|
|
135
147
|
for decoder_layer in base_model.layers:
|
|
136
148
|
if swiglu:
|
|
137
|
-
|
|
149
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
|
|
138
150
|
if rms_norm:
|
|
139
151
|
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
140
152
|
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
@@ -206,7 +218,7 @@ def apply_liger_kernel_to_llama(
|
|
|
206
218
|
|
|
207
219
|
for decoder_layer in base_model.layers:
|
|
208
220
|
if swiglu:
|
|
209
|
-
|
|
221
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
|
|
210
222
|
if rms_norm:
|
|
211
223
|
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
212
224
|
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
@@ -296,7 +308,7 @@ def apply_liger_kernel_to_mllama(
|
|
|
296
308
|
_patch_rms_norm_module(text_model.norm)
|
|
297
309
|
for decoder_layer in text_model.layers:
|
|
298
310
|
if swiglu:
|
|
299
|
-
|
|
311
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
|
|
300
312
|
if rms_norm:
|
|
301
313
|
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
302
314
|
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
@@ -370,7 +382,7 @@ def apply_liger_kernel_to_mistral(
|
|
|
370
382
|
|
|
371
383
|
for decoder_layer in base_model.layers:
|
|
372
384
|
if swiglu:
|
|
373
|
-
|
|
385
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
|
|
374
386
|
if rms_norm:
|
|
375
387
|
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
376
388
|
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
@@ -442,7 +454,7 @@ def apply_liger_kernel_to_mixtral(
|
|
|
442
454
|
for decoder_layer in base_model.layers:
|
|
443
455
|
if swiglu:
|
|
444
456
|
for expert in decoder_layer.block_sparse_moe.experts:
|
|
445
|
-
|
|
457
|
+
_patch_swiglu_module(expert, LigerBlockSparseTop2MLP)
|
|
446
458
|
if rms_norm:
|
|
447
459
|
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
448
460
|
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
@@ -516,7 +528,7 @@ def apply_liger_kernel_to_gemma(
|
|
|
516
528
|
|
|
517
529
|
for decoder_layer in base_model.layers:
|
|
518
530
|
if geglu:
|
|
519
|
-
|
|
531
|
+
_patch_geglu_module(decoder_layer.mlp)
|
|
520
532
|
if rms_norm:
|
|
521
533
|
_patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
|
|
522
534
|
_patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)
|
|
@@ -592,7 +604,7 @@ def apply_liger_kernel_to_gemma2(
|
|
|
592
604
|
|
|
593
605
|
for decoder_layer in base_model.layers:
|
|
594
606
|
if geglu:
|
|
595
|
-
|
|
607
|
+
_patch_geglu_module(decoder_layer.mlp)
|
|
596
608
|
if rms_norm:
|
|
597
609
|
_patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
|
|
598
610
|
_patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm)
|
|
@@ -776,7 +788,7 @@ def apply_liger_kernel_to_qwen2(
|
|
|
776
788
|
|
|
777
789
|
for decoder_layer in base_model.layers:
|
|
778
790
|
if swiglu:
|
|
779
|
-
|
|
791
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
|
|
780
792
|
if rms_norm:
|
|
781
793
|
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
782
794
|
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
@@ -849,7 +861,7 @@ def apply_liger_kernel_to_qwen2_vl(
|
|
|
849
861
|
_patch_rms_norm_module(base_model.norm)
|
|
850
862
|
for decoder_layer in base_model.layers:
|
|
851
863
|
if swiglu:
|
|
852
|
-
|
|
864
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
|
|
853
865
|
if rms_norm:
|
|
854
866
|
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
855
867
|
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
@@ -916,7 +928,7 @@ def apply_liger_kernel_to_qwen2_5_vl(
|
|
|
916
928
|
_patch_rms_norm_module(base_model.norm)
|
|
917
929
|
for decoder_layer in base_model.layers:
|
|
918
930
|
if swiglu:
|
|
919
|
-
|
|
931
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
|
|
920
932
|
if rms_norm:
|
|
921
933
|
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
922
934
|
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
@@ -985,7 +997,7 @@ def apply_liger_kernel_to_phi3(
|
|
|
985
997
|
|
|
986
998
|
for decoder_layer in base_model.layers:
|
|
987
999
|
if swiglu:
|
|
988
|
-
|
|
1000
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
|
|
989
1001
|
if rms_norm:
|
|
990
1002
|
_patch_rms_norm_module(decoder_layer.input_layernorm)
|
|
991
1003
|
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
@@ -1048,7 +1060,7 @@ def apply_liger_kernel_to_olmo2(
|
|
|
1048
1060
|
|
|
1049
1061
|
for decoder_layer in base_model.layers:
|
|
1050
1062
|
if swiglu:
|
|
1051
|
-
|
|
1063
|
+
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
|
|
1052
1064
|
if rms_norm:
|
|
1053
1065
|
_patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
|
|
1054
1066
|
_patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
|
|
@@ -17,6 +17,7 @@ liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmM
|
|
|
17
17
|
liger_kernel/chunked_loss/simpo_loss.py,sha256=fy2w8KbhMrBv7b1jdIeH3bBFxY52bPQPZb3KwBvmurM,5385
|
|
18
18
|
liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
19
19
|
liger_kernel/ops/cross_entropy.py,sha256=yKKhN63I7r9NxJye4wTLBvvKAyrXQt6jf4nBo3lJyVg,18860
|
|
20
|
+
liger_kernel/ops/dyt.py,sha256=YD1-buHz9VmIX838VKzLc-lm5CeUQ4LAskGDWBUMQHA,6187
|
|
20
21
|
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=1Y3Uk_TCSjqKgoG2eot1ptnWXJXXQESqGvOmqAW1gsM,10912
|
|
21
22
|
liger_kernel/ops/fused_linear_jsd.py,sha256=Seshez2qaM6HiTQ8_HEqSwhaeVruNT1SvIM4ZrAPBEU,9602
|
|
22
23
|
liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
|
|
@@ -32,10 +33,11 @@ liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
|
|
|
32
33
|
liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
|
|
33
34
|
liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
|
|
34
35
|
liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
|
|
35
|
-
liger_kernel/transformers/__init__.py,sha256=
|
|
36
|
+
liger_kernel/transformers/__init__.py,sha256=eGCDpnvIBX7bhE_jGo5RRBipwT62WE_obzlniedNzt8,2525
|
|
36
37
|
liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
|
|
37
38
|
liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
|
|
38
|
-
liger_kernel/transformers/
|
|
39
|
+
liger_kernel/transformers/dyt.py,sha256=QMqqc14pkE0WhpRZvapfnNAun-6C0C_tHExL2ZJuCUA,648
|
|
40
|
+
liger_kernel/transformers/functional.py,sha256=4h9Pdx_iINBqfv2Zod_c27qOpYXDDwbdVgatQ9_XBmI,5089
|
|
39
41
|
liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=09Rt7FZzLH42VOcIbQ4dlQd0o3Rlb4vk6fqiOQ7WTD8,1778
|
|
40
42
|
liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
|
|
41
43
|
liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
|
|
@@ -43,7 +45,7 @@ liger_kernel/transformers/group_norm.py,sha256=6qMAWOprr4SzP0YhNVNGQIBpM5aUHplUD
|
|
|
43
45
|
liger_kernel/transformers/jsd.py,sha256=DGqRnxIZxsvxo0_tbbxX3b-sDbDjC_yKufyRIHCcScY,2979
|
|
44
46
|
liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZAF1Fj8,409
|
|
45
47
|
liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
|
|
46
|
-
liger_kernel/transformers/monkey_patch.py,sha256=
|
|
48
|
+
liger_kernel/transformers/monkey_patch.py,sha256=_-4oMqEq5mQCSWQ7PaNI9cbLdT_UPPobYaqboa1oN4I,52210
|
|
47
49
|
liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
|
|
48
50
|
liger_kernel/transformers/rms_norm.py,sha256=GqCEJuGt0YdqqlMcToE0Wp4A8YFquDa4UUSyH2uFW2A,1191
|
|
49
51
|
liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
|
|
@@ -69,9 +71,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
|
69
71
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
|
|
70
72
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
|
71
73
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
|
72
|
-
liger_kernel_nightly-0.5.5.
|
|
73
|
-
liger_kernel_nightly-0.5.5.
|
|
74
|
-
liger_kernel_nightly-0.5.5.
|
|
75
|
-
liger_kernel_nightly-0.5.5.
|
|
76
|
-
liger_kernel_nightly-0.5.5.
|
|
77
|
-
liger_kernel_nightly-0.5.5.
|
|
74
|
+
liger_kernel_nightly-0.5.5.dev20250326012054.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
|
75
|
+
liger_kernel_nightly-0.5.5.dev20250326012054.dist-info/METADATA,sha256=xBzfl6G44MrSL8itL5Fv8d4jNasaC-fRFgiDjaK-_W4,22959
|
|
76
|
+
liger_kernel_nightly-0.5.5.dev20250326012054.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
|
77
|
+
liger_kernel_nightly-0.5.5.dev20250326012054.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
|
78
|
+
liger_kernel_nightly-0.5.5.dev20250326012054.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
|
79
|
+
liger_kernel_nightly-0.5.5.dev20250326012054.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|