liger-kernel 0.5.5__py3-none-any.whl → 0.5.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +17 -2
- liger_kernel/chunked_loss/fused_linear_ppo.py +331 -0
- liger_kernel/chunked_loss/grpo_loss.py +103 -61
- liger_kernel/chunked_loss/jsd_loss.py +12 -7
- liger_kernel/ops/cross_entropy.py +3 -2
- liger_kernel/ops/dyt.py +225 -0
- liger_kernel/ops/fused_linear_jsd.py +2 -1
- liger_kernel/ops/jsd.py +30 -11
- liger_kernel/ops/kl_div.py +2 -2
- liger_kernel/transformers/__init__.py +3 -0
- liger_kernel/transformers/dyt.py +20 -0
- liger_kernel/transformers/functional.py +5 -0
- liger_kernel/transformers/model/gemma.py +8 -16
- liger_kernel/transformers/model/gemma2.py +7 -16
- liger_kernel/transformers/model/llama.py +8 -15
- liger_kernel/transformers/model/llava.py +369 -0
- liger_kernel/transformers/model/loss_utils.py +57 -0
- liger_kernel/transformers/model/mistral.py +9 -10
- liger_kernel/transformers/model/mixtral.py +8 -15
- liger_kernel/transformers/model/mllama.py +8 -15
- liger_kernel/transformers/model/olmo2.py +8 -16
- liger_kernel/transformers/model/paligemma.py +397 -0
- liger_kernel/transformers/model/phi3.py +8 -15
- liger_kernel/transformers/model/qwen2.py +8 -15
- liger_kernel/transformers/model/qwen2_5_vl.py +9 -10
- liger_kernel/transformers/model/qwen2_vl.py +9 -10
- liger_kernel/transformers/monkey_patch.py +219 -13
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/METADATA +9 -6
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/RECORD +34 -29
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/WHEEL +1 -1
- liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -240
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info/licenses}/LICENSE +0 -0
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info/licenses}/NOTICE +0 -0
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/top_level.txt +0 -0

liger_kernel/chunked_loss/jsd_loss.py CHANGED

@@ -19,15 +19,20 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
         student_log_probs = F.log_softmax(student_logits, dim=-1)
         teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
 
-
-
-
+        if beta == 0:
+            jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="sum", log_target=True)
+        elif beta == 1:
+            jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
+        else:
+            # Compute probabilities (only required for mean calculation)
+            mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
+            log_mean_probs = mean_probs.log()
 
-
-
+            student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
+            teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
 
-
-
+            # JSD is the weighted average of the KL divergences
+            jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
         return jsd_loss
 
     @classmethod
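
The rewritten `distillation_loss_fn` handles the degenerate cases explicitly: `beta == 0` reduces to KL(teacher ‖ student), `beta == 1` to KL(student ‖ teacher), and anything in between is the beta-weighted Jensen-Shannon divergence against the mixture M = (1 − beta)·P_student + beta·P_teacher. A small unfused reference of the same formula (a sanity-check sketch, not part of the wheel):

```python
import torch
import torch.nn.functional as F

def naive_weighted_jsd(student_logits, teacher_logits, beta=0.5):
    # Unfused reference for the branch above:
    #   M = (1 - beta) * P_student + beta * P_teacher
    #   JSD_beta = beta * KL(P_teacher || M) + (1 - beta) * KL(P_student || M)
    p_student = F.softmax(student_logits, dim=-1)
    p_teacher = F.softmax(teacher_logits, dim=-1)
    m = (1 - beta) * p_student + beta * p_teacher
    kl_student = (p_student * (p_student.log() - m.log())).sum()
    kl_teacher = (p_teacher * (p_teacher.log() - m.log())).sum()
    return beta * kl_teacher + (1 - beta) * kl_student
```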

liger_kernel/ops/cross_entropy.py CHANGED

@@ -9,6 +9,7 @@ import triton.language as tl
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device
 
 if compare_version("triton", operator.ge, "3.0.0"):
     try:

@@ -59,7 +60,7 @@ def liger_cross_entropy_kernel(
     z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
     loss_stride (int): The stride of the loss tensor.
     n_cols (int): The number of columns in the input tensor.
-    n_non_ignore (
+    n_non_ignore (float): The number of non-ignored elements in the batch.
     sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
     weight_sum (float): The sum of weight tensor.
     ignore_index (int): The index to ignore in the target.

@@ -258,7 +259,7 @@ def liger_cross_entropy_kernel(
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2  # the best size we found by manually tuning
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2  # the best size we found by manually tuning
 
 
 def cross_entropy_forward(
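
The functional change here is the device-dependent cap on the fused block size: XPU now uses 4096 instead of 32768 (the same cap is applied in fused_linear_jsd.py further down). Assuming the forward pass keeps picking `BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))` as in earlier releases, the effect looks like this:

```python
import triton

def fused_block_size(vocab_size: int, device: str) -> int:
    # Sketch of the new cap; `device` stands in for infer_device()'s return value.
    max_fused_size = 4096 if device == "xpu" else 65536 // 2
    return min(max_fused_size, triton.next_power_of_2(vocab_size))

# A 128k-token vocabulary: 32768-wide blocks on CUDA/ROCm, 4096-wide blocks on XPU.
```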

liger_kernel/ops/dyt.py ADDED

@@ -0,0 +1,225 @@
+import operator
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import infer_device
+
+if compare_version("triton", operator.ge, "3.0.0"):
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import tanh
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import tanh
+else:
+    from triton.language.math import tanh
+
+
+@triton.jit
+def _dyt_fwd_kernel(
+    x_ptr,
+    x_row_stride,
+    alpha_ptr,
+    gamma_ptr,
+    beta_ptr,
+    y_ptr,
+    y_row_stride,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Reference:
+    https://arxiv.org/abs/2503.10622
+
+    Shapes:
+        - x: (BT, C)
+        - alpha: (1)
+        - gamma: (C)
+        - beta: (C)
+    """
+    row_idx = tl.program_id(0)
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_cols
+
+    x_ptr += row_idx * x_row_stride
+    y_ptr += row_idx * y_row_stride
+
+    alpha = tl.load(alpha_ptr)
+    gamma = tl.load(gamma_ptr + offsets, mask=mask)
+    beta = tl.load(beta_ptr + offsets, mask=mask)
+    x = tl.load(x_ptr + offsets, mask=mask)
+    y = gamma * tanh((alpha * x).cast(tl.float32)) + beta
+    tl.store(y_ptr + offsets, y, mask=mask)
+
+
+@triton.jit
+def _dyt_bwd_kernel(
+    x_ptr,
+    x_row_stride,
+    dy_ptr,
+    dy_row_stride,
+    dx_ptr,
+    dx_row_stride,
+    alpha_ptr,
+    dalpha_ptr,
+    gamma_ptr,
+    dgamma_ptr,
+    dgamma_row_stride,
+    n_cols,
+    n_rows,
+    ROWS_PER_PROGRAM: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Reference:
+    https://arxiv.org/abs/2503.10622
+
+    Shapes:
+        - x: (BT, C)
+        - alpha: (1)
+        - gamma: (C)
+        - dx: (BT, C)
+        - dy: (BT, C)
+        - dgamma: (sm_count, C)
+        - dalpha: (sm_count,)
+    """
+    # d(gamma * tanh(alpha * x) + beta) / dx
+    # = gamma * (1 - tanh^2(alpha * x)) * alpha
+    # d(gamma * tanh(alpha * x) + beta) / dalpha
+    # = gamma * (1 - tanh^2(alpha * x)) * x
+    # d(gamma * tanh(alpha * x) + beta) / dgamma
+    # = tanh(alpha * x)
+    # d(gamma * tanh(alpha * x)) / dbeta = 1
+    pid = tl.program_id(0)
+
+    row_start = pid * ROWS_PER_PROGRAM
+    row_end = min((pid + 1) * ROWS_PER_PROGRAM, n_rows)
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_cols
+
+    dalpha = 0.0
+    dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
+    x_ptr += row_start * x_row_stride
+    dx_ptr += row_start * dx_row_stride
+    dy_ptr += row_start * dy_row_stride
+    alpha = tl.load(alpha_ptr)
+    gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)
+
+    for _ in tl.range(row_start, row_end):
+        dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0)
+        x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
+        tanh_ax = tanh((alpha * x).cast(tl.float32))
+        sech2_ax = 1 - tanh_ax * tanh_ax
+
+        dx = dy * gamma * sech2_ax * alpha
+        dalpha += tl.sum(dy * gamma * sech2_ax * x)
+        dgamma += dy * tanh_ax
+        tl.store(dx_ptr + offsets, dx, mask=mask)
+
+        dy_ptr += dy_row_stride
+        x_ptr += x_row_stride
+        dx_ptr += dx_row_stride
+
+    tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
+    tl.store(dalpha_ptr + pid, dalpha)
+
+    pass
+
+
+def liger_dyt_fwd(x, alpha, gamma, beta):
+    shape = x.shape
+    dim = shape[-1]
+    x = x.view(-1, dim)
+    n_rows, n_cols = x.shape
+    y = torch.empty_like(x)
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+    _dyt_fwd_kernel[(n_rows,)](
+        x_ptr=x,
+        alpha_ptr=alpha,
+        gamma_ptr=gamma,
+        beta_ptr=beta,
+        y_ptr=y,
+        x_row_stride=x.stride(0),
+        y_row_stride=y.stride(0),
+        n_cols=n_cols,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+    return y.view(*shape)
+
+
+def liger_dyt_bwd(dy, x, alpha, gamma):
+    shape = dy.shape
+    dtype = x.dtype
+    dim = shape[-1]
+    dy = dy.view(-1, dim)
+    x = x.view(-1, dim)
+    n_rows, n_cols = dy.shape
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+    sm_count = 1
+    device = infer_device()
+    if device == "cuda":
+        sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
+    elif device == "xpu":
+        sm_count = torch.xpu.get_device_properties(x.device).gpu_subslice_count
+    if n_cols > BLOCK_SIZE:
+        raise RuntimeError(
+            f"Feature dimension {dim} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
+        )
+
+    dx = torch.empty_like(x, dtype=torch.float32)
+    _dalpha = torch.empty((sm_count,), dtype=torch.float32, device=x.device)
+    _dgamma = torch.empty((sm_count, n_cols), dtype=torch.float32, device=x.device)
+
+    grid = (sm_count,)
+    rows_per_program = triton.cdiv(n_rows, sm_count)
+    _dyt_bwd_kernel[grid](
+        x_ptr=x,
+        x_row_stride=x.stride(0),
+        dy_ptr=dy,
+        dy_row_stride=dy.stride(0),
+        dx_ptr=dx,
+        dx_row_stride=dx.stride(0),
+        alpha_ptr=alpha,
+        dalpha_ptr=_dalpha,
+        gamma_ptr=gamma,
+        dgamma_ptr=_dgamma,
+        dgamma_row_stride=_dgamma.stride(0),
+        n_cols=n_cols,
+        n_rows=n_rows,
+        ROWS_PER_PROGRAM=rows_per_program,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+    dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype)
+    dgamma = _dgamma.sum(dim=0).to(dtype)
+    dbeta = dy.sum(dim=0).to(dtype)
+    return dx.view(*shape), dalpha, dgamma, dbeta
+
+
+class LigerDyTFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, x, alpha, gamma, beta):
+        y = liger_dyt_fwd(x, alpha, gamma, beta)
+        ctx.save_for_backward(x, alpha, gamma)
+        return y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, grad_output):
+        x, alpha, gamma = ctx.saved_tensors
+        dx, dalpha, dgamma, dbeta = liger_dyt_bwd(
+            grad_output,
+            x,
+            alpha,
+            gamma,
+        )
+
+        return (dx, dalpha, dgamma, dbeta)
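
The new kernels implement DyT ("Dynamic Tanh", arXiv 2503.10622): y = gamma · tanh(alpha · x) + beta, with a single scalar alpha and per-channel gamma/beta. An eager-mode reference the fused autograd function should agree with numerically (a checking sketch, not shipped in the wheel; the fused side assumes a Triton-capable device such as CUDA or XPU):

```python
import torch

def dyt_reference(x, alpha, gamma, beta):
    # Plain PyTorch equivalent of the fused forward pass.
    return gamma * torch.tanh(alpha * x) + beta

# Rough parity check against the fused op:
# from liger_kernel.ops.dyt import LigerDyTFunction
# x = torch.randn(2, 8, 256, device="cuda", requires_grad=True)
# alpha = torch.full((1,), 0.5, device="cuda", requires_grad=True)
# gamma = torch.ones(256, device="cuda", requires_grad=True)
# beta = torch.zeros(256, device="cuda", requires_grad=True)
# torch.testing.assert_close(
#     LigerDyTFunction.apply(x, alpha, gamma, beta),
#     dyt_reference(x, alpha, gamma, beta),
#     rtol=1e-3,
#     atol=1e-3,
# )
```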

liger_kernel/ops/fused_linear_jsd.py CHANGED

@@ -8,11 +8,12 @@ from liger_kernel.ops.utils import amp_custom_bwd
 from liger_kernel.ops.utils import amp_custom_fwd
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2
 
 
 def fused_linear_jsd_forward(

liger_kernel/ops/jsd.py CHANGED

@@ -51,24 +51,43 @@ def _jsd_kernel(
     Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
 
     if beta == 0.0:  # forward KL
-
+        Y_max = tl.max(Y, axis=0)
+        Y_shifted = Y - Y_max
+        Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max)  # Compensate for the shift
         loss = Y_prob * (Y - X)
         dX = -Y_prob
-    elif beta == 1.0:
-
+    elif beta == 1.0:  # reverse KL
+        X_max = tl.max(X, axis=0)
+        X_shifted = X - X_max
+        X_prob = tl.exp(X_shifted) * tl.exp(X_max)  # Compensate for the shift
         loss = X_prob * (X - Y)
        dX = loss + X_prob
     else:
-
-
-
-        log_M = tl.log(M)
+        max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0))
+        X_shifted = X - max_val
+        Y_shifted = Y - max_val
 
-
-
+        # Pre-compute exp(max_val) since it's used twice
+        exp_max = tl.exp(max_val)
+
+        # Compute exp terms with compensation
+        Q = tl.exp(X_shifted) * exp_max  # = exp(X)
+        P = tl.exp(Y_shifted) * exp_max  # = exp(Y)
+
+        # Pre-compute common terms
+        beta_P = beta * P
+        one_minus_beta_Q = (1 - beta) * Q
+        M = beta_P + one_minus_beta_Q
+        log_M = tl.log(M)  # No need to compensate as M is already in original scale
+
+        loss = beta_P * Y + one_minus_beta_Q * X - M * log_M
+        dX = one_minus_beta_Q * (X - log_M)
+
+    # Pre-compute scaling factor
+    scale = 1.0 / n_non_ignore
+    loss = loss * scale
+    dX = dX * scale
 
-    loss = loss / n_non_ignore
-    dX = dX / n_non_ignore
     tl.store(loss_ptr + offsets, loss, mask=mask)
     tl.store(dX_ptr + offsets, dX, mask=mask)
 
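
The rewritten 0 < beta < 1 branch shifts X and Y by their joint maximum before exponentiating and then multiplies the shift back in, so Q and P are exactly exp(X) and exp(Y); the shift only protects the intermediate exponentials. The per-element quantity it stores (before the 1 / n_non_ignore scaling) is the beta-weighted JSD integrand, as this eager sketch shows:

```python
import torch

def jsd_elementwise(X: torch.Tensor, Y: torch.Tensor, beta: float) -> torch.Tensor:
    # X = student log-probs, Y = teacher log-probs, mirroring the kernel's else-branch.
    Q, P = X.exp(), Y.exp()
    M = beta * P + (1 - beta) * Q
    # Summed over the vocabulary this equals beta * KL(P||M) + (1 - beta) * KL(Q||M).
    return beta * P * Y + (1 - beta) * Q * X - M * M.log()
```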

liger_kernel/ops/kl_div.py CHANGED

@@ -185,9 +185,9 @@ class LigerKLDivLossFunction(torch.autograd.Function):
     Class implementing the forward and backward pass for the KL Divergence Loss using Triton, as defined by the following formula:
     ```python
     if log_target:
-        loss = target * (target.log() - input)
-    else:
         loss = target.exp() * (target - input)
+    else:
+        loss = target * (target.log() - input)
     ```,
     then the loss is reduced according to the `reduction` parameter.
     as defined in the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
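
This only reorders the docstring so the branches match `torch.nn.functional.kl_div`: with `log_target=True` the target is already in log space and gets exponentiated. A quick eager check of that equivalence (plain PyTorch, not the Triton path):

```python
import torch
import torch.nn.functional as F

inp = F.log_softmax(torch.randn(4, 10), dim=-1)  # input is always log-probabilities
tgt = F.softmax(torch.randn(4, 10), dim=-1)      # target as probabilities

a = F.kl_div(inp, tgt, reduction="sum", log_target=False)       # target * (target.log() - input)
b = F.kl_div(inp, tgt.log(), reduction="sum", log_target=True)  # target.exp() * (target - input)
torch.testing.assert_close(a, b)
```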

liger_kernel/transformers/__init__.py CHANGED

@@ -1,5 +1,6 @@
 from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM  # noqa: F401
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss  # noqa: F401
+from liger_kernel.transformers.dyt import LigerDyT  # noqa: F401
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss  # noqa: F401
 from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD  # noqa: F401
 from liger_kernel.transformers.geglu import LigerGEGLUMLP  # noqa: F401

@@ -11,10 +12,12 @@ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_paligemma  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl  # noqa: F401
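
0.5.6 re-exports the new DyT module and the LLaVA/PaliGemma patches at the package root. A minimal usage sketch; the keyword arguments accepted by the two new `apply_liger_kernel_to_*` helpers are not visible in this diff, so they are assumed to follow the all-default-keyword pattern of the existing helpers:

```python
from liger_kernel.transformers import LigerDyT  # new in 0.5.6
from liger_kernel.transformers import apply_liger_kernel_to_llava, apply_liger_kernel_to_paligemma

# Patch the Hugging Face model code in place before the models are instantiated.
apply_liger_kernel_to_llava()
apply_liger_kernel_to_paligemma()
```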

liger_kernel/transformers/dyt.py ADDED

@@ -0,0 +1,20 @@
+import torch
+import torch.nn as nn
+
+from liger_kernel.ops.dyt import LigerDyTFunction
+
+
+class LigerDyT(nn.Module):
+    def __init__(self, hidden_size, init_alpha=0.5):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.init_alpha = init_alpha
+        self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
+        self.gamma = nn.Parameter(torch.ones(hidden_size))
+        self.beta = nn.Parameter(torch.zeros(hidden_size))
+
+    def forward(self, x):
+        return LigerDyTFunction.apply(x, self.alpha, self.gamma, self.beta)
+
+    def extra_repr(self):
+        return f"{self.hidden_size}, init_alpha={self.init_alpha}"
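
`LigerDyT` wraps the fused op as an `nn.Module` with a scalar `alpha` (initialised to `init_alpha`), per-channel `gamma` (ones) and `beta` (zeros). A usage sketch (assumes a Triton-capable device):

```python
import torch
from liger_kernel.transformers import LigerDyT

layer = LigerDyT(hidden_size=2048, init_alpha=0.5).to("cuda")
x = torch.randn(4, 128, 2048, device="cuda")
y = layer(x)   # same shape as x: gamma * tanh(alpha * x) + beta
print(layer)   # LigerDyT(2048, init_alpha=0.5), via extra_repr
```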

liger_kernel/transformers/functional.py CHANGED

@@ -1,6 +1,7 @@
 from typing import Optional
 
 from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
+from liger_kernel.ops.dyt import LigerDyTFunction
 from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
 from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
 from liger_kernel.ops.geglu import LigerGELUMulFunction

@@ -192,3 +193,7 @@ def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 
 def liger_swiglu(a, b):
     return LigerSiLUMulFunction.apply(a, b)
+
+
+def liger_dyt(x, alpha, gamma, beta):
+    return LigerDyTFunction.apply(x, alpha, gamma, beta)
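
`liger_dyt` is the stateless counterpart for callers that manage the parameters themselves instead of using the `LigerDyT` module; a sketch:

```python
import torch
from liger_kernel.transformers.functional import liger_dyt

# Parameters kept outside the module wrapper (device must support Triton).
x = torch.randn(4, 128, 2048, device="cuda")
alpha = torch.full((1,), 0.5, device="cuda", requires_grad=True)
gamma = torch.ones(2048, device="cuda", requires_grad=True)
beta = torch.zeros(2048, device="cuda", requires_grad=True)
y = liger_dyt(x, alpha, gamma, beta)
```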

liger_kernel/transformers/model/gemma.py CHANGED

@@ -14,6 +14,7 @@ from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
 @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)

@@ -200,22 +201,13 @@ def lce_forward(
     loss = None
     # if in training mode, don't materialize logits
     if self.training and (labels is not None):
-
-
-
-
-
-
-
-        shift_labels = shift_labels.view(-1)
-
-        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
-        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
-
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-        if reduction == "sum":
-            loss /= loss_kwargs["num_items_in_batch"]
-
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            hidden_size=self.config.hidden_size,
+            **loss_kwargs,
+        )
     else:  # if in inference mode materialize logits
         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
         if labels is not None:
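
From this release on, every patched causal-LM forward delegates the training-time loss to `LigerForCausalLMLoss` from the new `liger_kernel/transformers/model/loss_utils.py` (+57 lines, whose body is not shown in this diff); gemma2 and llama below get the identical treatment. Judging only from the removed code, the helper consolidates roughly this pattern (a hypothetical reconstruction, not the actual loss_utils implementation):

```python
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

def causal_lm_loss_sketch(hidden_states, lm_head_weight, labels, hidden_size, softcap=None, **loss_kwargs):
    # Shift so token t predicts token t+1, then flatten (standard causal-LM loss layout).
    shift_hidden_states = hidden_states[..., :-1, :].contiguous().view(-1, hidden_size)
    shift_labels = labels[..., 1:].contiguous().view(-1)

    # Same reduction logic the per-model code used to inline.
    reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
    lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction, softcap=softcap)
    loss = lce(lm_head_weight, shift_hidden_states, shift_labels)
    if reduction == "sum":
        loss = loss / loss_kwargs["num_items_in_batch"]
    return loss
```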

liger_kernel/transformers/model/gemma2.py CHANGED

@@ -15,6 +15,7 @@ from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 logger = logging.getLogger(__name__)
 

@@ -212,25 +213,15 @@ def lce_forward(
     loss = None
     # if in training mode, don't materialize logits
     if self.training and (labels is not None):
-
-
-
-
-
-        # flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
-        lce = LigerFusedLinearCrossEntropyLoss(
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            hidden_size=self.config.hidden_size,
             softcap=self.config.final_logit_softcapping,
-
+            **loss_kwargs,
         )
 
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-        if reduction == "sum":
-            loss /= loss_kwargs["num_items_in_batch"]
-
     else:  # if in inference mode materialize logits
         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
         if self.config.final_logit_softcapping is not None:

liger_kernel/transformers/model/llama.py CHANGED

@@ -15,6 +15,7 @@ from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 if TYPE_CHECKING:
     from transformers.cache_utils import Cache

@@ -212,21 +213,13 @@ def lce_forward(
     loss = None
     # if in training mode, don't materialize logits
     if self.training and (labels is not None):
-
-
-
-
-
-
-
-        shift_labels = shift_labels.view(-1)
-
-        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
-        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
-
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-        if reduction == "sum":
-            loss /= loss_kwargs["num_items_in_batch"]
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            hidden_size=self.config.hidden_size,
+            **loss_kwargs,
+        )
 
     else:  # if in inference mode materialize logits
         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])