liger-kernel-nightly 0.4.0.dev20241106174658__tar.gz → 0.4.0.dev20241107054539__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic.
- {liger_kernel_nightly-0.4.0.dev20241106174658/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.4.0.dev20241107054539}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/pyproject.toml +1 -1
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/ops/cross_entropy.py +104 -20
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/ops/fused_linear_cross_entropy.py +14 -2
- liger_kernel_nightly-0.4.0.dev20241107054539/src/liger_kernel/transformers/cross_entropy.py +43 -0
- liger_kernel_nightly-0.4.0.dev20241107054539/src/liger_kernel/transformers/fused_linear_cross_entropy.py +35 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/transformers/monkey_patch.py +24 -51
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539/src/liger_kernel_nightly.egg-info}/PKG-INFO +1 -1
- liger_kernel_nightly-0.4.0.dev20241106174658/src/liger_kernel/transformers/cross_entropy.py +0 -21
- liger_kernel_nightly-0.4.0.dev20241106174658/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -21
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/LICENSE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/NOTICE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/README.md +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/setup.cfg +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/transformers/__init__.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/transformers/functional.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107054539}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "liger_kernel_nightly"
-version = "0.4.0.dev20241106174658"
+version = "0.4.0.dev20241107054539"
 description = "Efficient Triton kernels for LLM Training"
 urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
 readme = { file = "README.md", content-type = "text/markdown" }
src/liger_kernel/ops/cross_entropy.py
@@ -4,6 +4,9 @@ import triton.language as tl

 from liger_kernel.ops.utils import element_mul_kernel, is_hip

+_TRUE = tl.constexpr(1)
+_FALSE = tl.constexpr(0)
+

 @triton.jit
 def liger_cross_entropy_kernel(
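The `_TRUE`/`_FALSE` constants exist because arguments annotated `tl.constexpr` are baked into the compiled kernel, so a flag like the `RETURN_Z_LOSS` parameter added below is free at runtime. A minimal sketch of the pattern (the toy kernel and its names are illustrative, not part of the package; a CUDA device is assumed):

import torch
import triton
import triton.language as tl

@triton.jit
def square_maybe_store_aux(out_ptr, aux_ptr, x, WRITE_AUX: tl.constexpr):
    # WRITE_AUX is known at compile time; when it is 0 the branch below
    # is eliminated entirely, so that specialization does no extra work.
    tl.store(out_ptr, x * x)
    if WRITE_AUX == 1:
        tl.store(aux_ptr, x)

out = torch.zeros(1, device="cuda")
aux = torch.zeros(1, device="cuda")
# each distinct WRITE_AUX value compiles its own kernel variant, mirroring
# how RETURN_Z_LOSS=0 and RETURN_Z_LOSS=1 produce two specializations
square_maybe_store_aux[(1,)](out, aux, 3.0, WRITE_AUX=1)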
@@ -12,12 +15,15 @@ def liger_cross_entropy_kernel(
     Y_ptr,
     Y_stride,
     loss_ptr,
+    z_loss_ptr,
     loss_stride,
     n_cols,
     n_non_ignore,
     ignore_index,
+    lse_square_scale: tl.constexpr,
     label_smoothing: tl.constexpr,
     reduction: tl.constexpr,  # set it as constexpr since reduction is always known at compile time
+    RETURN_Z_LOSS: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     """
@@ -30,11 +36,14 @@ def liger_cross_entropy_kernel(
     Y_ptr: Pointer to target tensor.
     Y_stride (int): The stride of the target tensor.
     loss_ptr: Pointer to tensor to store the loss.
+    z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
     loss_stride (int): The stride of the loss tensor.
     n_cols (int): The number of columns in the input tensor.
     n_non_ignore (int): The number of non-ignored elements in the batch.
     ignore_index (int): The index to ignore in the target.
     label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
+    lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
+    RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
     reduction (str): The string for the reduction to apply
     BLOCK_SIZE (int): The block size for Triton operations.
     """
@@ -58,6 +67,7 @@ def liger_cross_entropy_kernel(
         return

     loss_ptr += program_id * loss_stride
+    z_loss_ptr += program_id * loss_stride

     # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
     # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
@@ -87,32 +97,40 @@ def liger_cross_entropy_kernel(
         d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
         m = m_new

+    # log (sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X)))))
+    #                    = log (e^(max(X)) * sum(e ^ (X_i - max(X))))
+    #                    = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d
+    lse = m + tl.log(d)
+
     # 4. [Online Softmax] Second pass: compute gradients
     # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N)
     # dx_y = (softmax(x_y) - 1) / N
     # dx_i = softmax(x_i) / N, i != y
     # For label smoothing:
-    # dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols, i != y
+    # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
     # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
     #      = dx_i - (1 - label_smoothing) / N
-    #
+    # With Z loss:
+    # dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y
+    # dx_y = dx_i - (1 - label_smoothing) / N
     # For 'sum' reduction, no normalization is applied:
     # dx_y = softmax(x_y) - 1
     # dx_i = softmax(x_i), for i ≠ y
-    # For label smoothing:
-    # dx_i = (softmax(x_y) - label_smoothing / V), V = n_cols, i != y
-    # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing))
-    #      = dx_i - (1 - label_smoothing)

     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(
             X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
         )
+        # softmax(x_i)
+        X_block = tl.exp(X_block - m) / d
+        # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+        X_block += 2 * lse_square_scale * lse * X_block
+        # smoothing term
+        X_block += -eps
+        # reduction scale
         if reduction == "mean":
-            X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
-        else:
-            X_block = tl.exp(X_block - m) / d - eps
+            X_block = X_block / (n_non_ignore)

         tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

@@ -124,9 +142,10 @@ def liger_cross_entropy_kernel(

     # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
     #      = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
+    #      = X_y - m - log d = X_y - lse
     # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
     # So we can safely calculate log (softmax(X_y)) without overflow
-    loss = -(ori_X_y - m - tl.log(d))
+    loss = lse - ori_X_y

     # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
     # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
@@ -137,11 +156,16 @@ def liger_cross_entropy_kernel(
     # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
     # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
     if label_smoothing > 0:
-        smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d))
+        smooth_loss = scaled_x_sum + label_smoothing * lse
         loss = loss * (1 - label_smoothing) + smooth_loss

+    # An auxiliary loss, z_loss
+    # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
+    z_loss = lse_square_scale * lse * lse
+    loss += z_loss
     # Normalize the loss by the number of non-ignored elements if reduction is "mean"
     if reduction == "mean":
+        z_loss = z_loss / n_non_ignore
         loss = loss / n_non_ignore

     # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N`
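Putting the pieces of this hunk together (label smoothing set aside, writing lse_square_scale as \kappa), the per-token objective the kernel now computes and differentiates is:

\mathrm{lse}(x) = \log\sum_{i=1}^{V} e^{x_i},
\qquad
\mathcal{L} = \underbrace{\mathrm{lse}(x) - x_y}_{\text{cross entropy}}
  + \underbrace{\kappa\,\mathrm{lse}(x)^2}_{z\text{-loss}},
\qquad
\frac{\partial \mathcal{L}}{\partial x_i}
  = \bigl(1 + 2\kappa\,\mathrm{lse}(x)\bigr)\,\mathrm{softmax}(x)_i - \mathbb{1}[i = y],

which matches the gradient comments in the second-pass loop above.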
@@ -152,6 +176,8 @@ def liger_cross_entropy_kernel(
     X_y += -(1 - label_smoothing)

     tl.store(loss_ptr, loss)
+    if RETURN_Z_LOSS == _TRUE:
+        tl.store(z_loss_ptr, z_loss)
     tl.store(X_ptr + y, X_y)

@@ -161,7 +187,31 @@ def liger_cross_entropy_kernel(
 MAX_FUSED_SIZE = 65536 // 2  # the best size we found by manually tuning


-def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction):
+_bool_to_return_z_loss = {
+    True: _TRUE.value,
+    False: _FALSE.value,
+}
+
+
+def cross_entropy_forward(
+    _input,
+    target,
+    ignore_index,
+    lse_square_scale,
+    label_smoothing,
+    reduction,
+    return_z_loss,
+):
+    if not isinstance(return_z_loss, int):
+        assert (
+            return_z_loss in _bool_to_return_z_loss
+        ), f"return_z_loss must be True or False. Got: {return_z_loss}"
+        return_z_loss = _bool_to_return_z_loss[return_z_loss]
+    else:
+        assert (
+            return_z_loss in _bool_to_return_z_loss
+        ), f"return_z_loss must be True or False. Got: {return_z_loss}"
+
     BT, V = _input.shape
     n_rows = BT

@@ -169,6 +219,10 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction):

     # unreduced loss
     loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
+    if return_z_loss == _TRUE.value:
+        z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
+    else:
+        z_loss_1d = loss_1d  # dummy ptr when return_z_loss == False

     n_non_ignore = (target != ignore_index).sum().item()

@@ -185,12 +239,15 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction):
         Y_ptr=target,
         Y_stride=target.stride(-1),  # always 1
         loss_ptr=loss_1d,
+        z_loss_ptr=z_loss_1d,
         loss_stride=loss_1d.stride(-1),  # always 1
         n_cols=V,
         n_non_ignore=n_non_ignore,
         ignore_index=ignore_index,
+        lse_square_scale=lse_square_scale,
         label_smoothing=label_smoothing,
         reduction=reduction,
+        RETURN_Z_LOSS=return_z_loss,
         BLOCK_SIZE=BLOCK_SIZE,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
@@ -198,7 +255,12 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction):
     )

     loss = torch.sum(loss_1d)
-    return loss, _input
+    if return_z_loss == _TRUE.value:
+        z_loss = torch.sum(z_loss_1d)
+    else:
+        z_loss = None
+
+    return loss, z_loss, _input


 def cross_entropy_backward(_input, grad_output):
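As a sanity check on the semantics above, here is a hedged eager-mode PyTorch reference (my own sketch, not part of the package; ignore_index and label smoothing are omitted for brevity):

import torch

def reference_ce_with_z_loss(logits, target, lse_square_scale=0.0, reduction="mean"):
    # logits: (BT, V), target: (BT,)
    lse = torch.logsumexp(logits.float(), dim=-1)  # per-token logsumexp
    ce = lse - logits.float().gather(1, target.unsqueeze(1)).squeeze(1)
    z_loss = lse_square_scale * lse * lse  # PaLM-style auxiliary loss
    loss = ce + z_loss
    if reduction == "mean":
        return loss.sum() / target.numel(), z_loss.sum() / target.numel()
    return loss.sum(), z_loss.sum()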
@@ -233,7 +295,14 @@ class LigerCrossEntropyFunction(torch.autograd.Function):

     @staticmethod
     def forward(
-        ctx, _input, target, ignore_index=-100, label_smoothing=0.0, reduction="mean"
+        ctx,
+        _input,
+        target,
+        ignore_index=-100,
+        lse_square_scale=0.0,
+        label_smoothing=0.0,
+        reduction="mean",
+        return_z_loss=False,
     ):
         """
         The forward pass of the Liger Cross Entropy loss.
@@ -243,33 +312,46 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
         target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
         ignore_index (int): The index to ignore in the target.
+        lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
+        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`

         Returns:
-        tensor: The computed loss.
+        tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None.
         """
-        loss, _input = cross_entropy_forward(
-            _input, target, ignore_index, label_smoothing, reduction
+        loss, z_loss, _input = cross_entropy_forward(
+            _input,
+            target,
+            ignore_index,
+            lse_square_scale,
+            label_smoothing,
+            reduction,
+            return_z_loss,
         )
         # TODO: investigation
         # If we don't detach the _input tensor, the memory will double
         # Not sure why but seems that there will be a time both grad and value exist but in different location
         ctx.save_for_backward(_input.detach())
-        return loss
+        ctx.return_z_loss = return_z_loss
+
+        return loss, z_loss

     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output, grad_ouput2):
         """
         The backward pass of the Liger Cross Entropy loss.

         Parameters:
         ctx : The context object with saved tensors.
         grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
-
+        grad_output2 (tenosr): No use.
         Returns:
         tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
         """
+        if ctx.return_z_loss:
+            del grad_ouput2  # z_loss is only for logging

         (_input,) = ctx.saved_tensors
         _input = cross_entropy_backward(_input, grad_output)
         return (
@@ -278,4 +360,6 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,
+            None,
+            None,
         )
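The second gradient argument and the extra `None`s follow from `torch.autograd.Function`'s contract: `backward` receives one incoming gradient per `forward` output and must return one entry per `forward` input (`None` for non-differentiable ones). A toy illustration (class and variable names are mine):

import torch

class ScaledWithAux(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        # second output is detached and auxiliary, "for logging only"
        return x * scale, (x * x).detach()

    @staticmethod
    def backward(ctx, grad_output, grad_output2):
        del grad_output2  # the auxiliary output contributes no gradient
        # one slot per forward input: x gets a gradient, scale gets None
        return grad_output * ctx.scale, None

x = torch.randn(4, requires_grad=True)
y, aux = ScaledWithAux.apply(x, 2.0)
y.sum().backward()  # fills x.grad; aux never participates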
src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -21,6 +21,7 @@ def fused_linear_cross_entropy_forward(
     target,
     bias=None,
     ignore_index=-100,
+    lse_square_scale=0.0,
     label_smoothing=0.0,
     reduction="mean",
 ):
@@ -86,12 +87,15 @@ def fused_linear_cross_entropy_forward(
             Y_ptr=target_chunk,
             Y_stride=target_chunk.stride(-1),  # always 1
             loss_ptr=loss_1d_slice,
+            z_loss_ptr=loss_1d_slice,  # dummy ptr, not used
             loss_stride=loss_1d_slice.stride(-1),  # always 1
             n_cols=V,
             n_non_ignore=n_non_ignore,
             ignore_index=ignore_index,
+            lse_square_scale=lse_square_scale,
             label_smoothing=label_smoothing,
             reduction=reduction,
+            RETURN_Z_LOSS=0,  # False
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=32 if not is_hip() else 16,
         )
@@ -200,6 +204,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         target,
         bias=None,
         ignore_index=-100,
+        lse_square_scale=0.0,
         label_smoothing=0.0,
         reduction="mean",
     ):
@@ -221,7 +226,14 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         reduction: reduction to apply
         """
         loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
-            _input, weight, target, bias, ignore_index, label_smoothing, reduction
+            _input,
+            weight,
+            target,
+            bias,
+            ignore_index,
+            lse_square_scale,
+            label_smoothing,
+            reduction,
         )
         # downcast to dtype and store for backward
         ctx.save_for_backward(
@@ -238,4 +250,4 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
             grad_output, grad_input, grad_weight, grad_bias
         )
-        return (grad_input, grad_weight, None, grad_bias, None, None, None)
+        return (grad_input, grad_weight, None, grad_bias, None, None, None, None)
src/liger_kernel/transformers/cross_entropy.py (new file)
@@ -0,0 +1,43 @@
+import torch.nn as nn
+
+from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
+
+
+class LigerCrossEntropyLoss(nn.Module):
+    def __init__(
+        self,
+        ignore_index=-100,
+        lse_square_scale=0.0,
+        label_smoothing=0.0,
+        reduction="mean",
+        return_z_loss=False,
+    ):
+        super().__init__()
+        self.ignore_index = ignore_index
+        self.lse_square_scale = lse_square_scale
+        self.label_smoothing = label_smoothing
+        self.reduction = reduction
+        self.return_z_loss = return_z_loss
+
+        assert (self.label_smoothing >= 0) and (
+            self.label_smoothing <= 1
+        ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"
+        assert self.reduction in {
+            "mean",
+            "sum",
+            "none",
+        }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {self.reduction}"
+
+    def forward(self, _input, target):
+        loss, z_loss = LigerCrossEntropyFunction.apply(
+            _input,
+            target,
+            self.ignore_index,
+            self.lse_square_scale,
+            self.label_smoothing,
+            self.reduction,
+            self.return_z_loss,
+        )
+        if not self.return_z_loss:
+            return loss
+        return loss, z_loss
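A usage sketch for the new module (the shapes and the CUDA device are my assumptions; the underlying Triton kernel needs a GPU):

import torch
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

logits = torch.randn(8, 32000, device="cuda", requires_grad=True)  # (BT, V)
target = torch.randint(0, 32000, (8,), device="cuda")              # (BT,)

loss_fn = LigerCrossEntropyLoss(lse_square_scale=1e-4, return_z_loss=True)
loss, z_loss = loss_fn(logits, target)  # a plain tensor when return_z_loss=False
loss.backward()  # z_loss is already folded into loss; the return value is for logging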
src/liger_kernel/transformers/fused_linear_cross_entropy.py (new file)
@@ -0,0 +1,35 @@
+import torch.nn as nn
+
+from liger_kernel.ops.fused_linear_cross_entropy import (
+    LigerFusedLinearCrossEntropyFunction,
+)
+
+
+class LigerFusedLinearCrossEntropyLoss(nn.Module):
+    def __init__(
+        self,
+        ignore_index=-100,
+        label_smoothing=0.0,
+        reduction="mean",
+        lse_square_scale=0.0,
+    ):
+        super().__init__()
+        self.ignore_index = ignore_index
+        self.label_smoothing = label_smoothing
+        self.reduction = reduction
+        self.lse_square_scale = lse_square_scale
+        assert (self.label_smoothing >= 0) and (
+            self.label_smoothing <= 1
+        ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"
+
+    def forward(self, lin_weight, _input, target, bias=None):
+        return LigerFusedLinearCrossEntropyFunction.apply(
+            _input,
+            lin_weight,
+            target,
+            bias,
+            self.ignore_index,
+            self.lse_square_scale,
+            self.label_smoothing,
+            self.reduction,
+        )
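The fused variant takes the lm_head weight and the pre-projection hidden states instead of materialized logits; note the weight-first argument order of `forward`. A sketch under assumed shapes (CUDA device assumed):

import torch
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

hidden = torch.randn(8, 4096, device="cuda", requires_grad=True)        # (BT, H)
lm_head = torch.randn(32000, 4096, device="cuda", requires_grad=True)   # (V, H)
target = torch.randint(0, 32000, (8,), device="cuda")

loss_fn = LigerFusedLinearCrossEntropyLoss(lse_square_scale=1e-4)
loss = loss_fn(lm_head, hidden, target)  # the full (BT, V) logits are never materialized at once
loss.backward()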
src/liger_kernel/transformers/monkey_patch.py
@@ -99,6 +99,7 @@ def apply_liger_kernel_to_llama(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

     from transformers.models.llama import modeling_llama
+    from transformers.models.llama.modeling_llama import LlamaModel

     if rope:
         modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -119,15 +120,8 @@ def apply_liger_kernel_to_llama(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)

-        if hasattr(model, "model"):
-            # The case for LlamaForCausalLM or LlamaForSequenceClassification, for example
-            base_model = model.model
-        elif hasattr(model, "transformer"):
-            # LlamaForQuestionAnswering uses "transformer" instead of "model"
-            base_model = model.transformer
-        else:
-            # Direct LlamaModel
-            base_model = model
+        # get the base model from the model instance
+        base_model: LlamaModel = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module(base_model.norm)
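The `getattr(model, model.base_model_prefix, model)` one-liner covers all three deleted branches: `LlamaForCausalLM` sets `base_model_prefix = "model"`, `LlamaForQuestionAnswering` sets it to `"transformer"`, and a bare `LlamaModel` has no attribute named by its own prefix, so `getattr` falls back to the model itself. A sketch (the checkpoint name is a placeholder):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("some-org/some-llama-checkpoint")
# resolves to model.model for *ForCausalLM, model.transformer for
# *ForQuestionAnswering, and to the model itself for a bare LlamaModel
base_model = getattr(model, model.base_model_prefix, model)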
@@ -275,6 +269,7 @@ def apply_liger_kernel_to_mistral(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

     from transformers.models.mistral import modeling_mistral
+    from transformers.models.mistral.modeling_mistral import MistralModel

     if rope:
         modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -291,12 +286,8 @@ def apply_liger_kernel_to_mistral(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-        if hasattr(model, "model"):
-            # The case for MistralForCausalLM, for example
-            base_model = model.model
-        else:
-            # Direct MistralModel
-            base_model = model
+        # get the base model from the model instance
+        base_model: MistralModel = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module(base_model.norm)
@@ -340,6 +331,7 @@ def apply_liger_kernel_to_mixtral(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

     from transformers.models.mixtral import modeling_mixtral
+    from transformers.models.mixtral.modeling_mixtral import MixtralModel

     if rope:
         modeling_mixtral.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -360,12 +352,8 @@ def apply_liger_kernel_to_mixtral(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-        if hasattr(model, "model"):
-            # The case for MixtralForCausalLM, for example
-            base_model = model.model
-        else:
-            # Direct MixtralModel
-            base_model = model
+        # get the base model from the model instance
+        base_model: MixtralModel = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module(base_model.norm)
@@ -410,6 +398,7 @@ def apply_liger_kernel_to_gemma(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

     from transformers.models.gemma import modeling_gemma
+    from transformers.models.gemma.modeling_gemma import GemmaModel

     # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
     LigerRMSNormForGemma = partial(
@@ -438,12 +427,8 @@ def apply_liger_kernel_to_gemma(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-        if hasattr(model, "model"):
-            # The case for GemmaForCausalLM, for example
-            base_model = model.model
-        else:
-            # Direct GemmaModel
-            base_model = model
+        # get the base model from the model instance
+        base_model: GemmaModel = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module_for_gemma(base_model.norm)
@@ -478,6 +463,7 @@ def apply_liger_kernel_to_gemma2(
         loaded. Default is None.
     """
     from transformers.models.gemma2 import modeling_gemma2
+    from transformers.models.gemma2.modeling_gemma2 import Gemma2Model

     LigerRMSNormForGemma2 = partial(
         LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros"
@@ -500,12 +486,8 @@ def apply_liger_kernel_to_gemma2(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-        if hasattr(model, "model"):
-            # The case for Gemma2ForCausalLM, for example
-            base_model = model.model
-        else:
-            # Direct Gemma2Model
-            base_model = model
+        # get the base model from the model instance
+        base_model: Gemma2Model = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module_for_gemma2(base_model.norm)
@@ -556,6 +538,7 @@ def apply_liger_kernel_to_qwen2(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

     from transformers.models.qwen2 import modeling_qwen2
+    from transformers.models.qwen2.modeling_qwen2 import Qwen2Model

     if rope:
         modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -580,12 +563,8 @@ def apply_liger_kernel_to_qwen2(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-        if hasattr(model, "model"):
-            # The case for Qwen2ForCausalLM, for example
-            base_model = model.model
-        else:
-            # Direct Qwen2Model
-            base_model = model
+        # get the base model from the model instance
+        base_model: Qwen2Model = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module(base_model.norm)
@@ -630,6 +609,7 @@ def apply_liger_kernel_to_qwen2_vl(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

     from transformers.models.qwen2_vl import modeling_qwen2_vl
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel

     from liger_kernel.transformers.model.qwen2_vl import (
         lce_forward as qwen2_vl_lce_forward,
@@ -653,12 +633,8 @@ def apply_liger_kernel_to_qwen2_vl(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-        if hasattr(model, "model"):
-            # The case for Qwen2VLForConditionalGeneration, for example
-            base_model = model.model
-        else:
-            # Direct Qwen2VLModel
-            base_model = model
+        # get the base model from the model instance
+        base_model: Qwen2VLModel = getattr(model, model.base_model_prefix, model)

         if hasattr(model, "visual"):
             # Patch Qwen2VisionTransformerPretrainedModel
@@ -707,6 +683,7 @@ def apply_liger_kernel_to_phi3(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

     from transformers.models.phi3 import modeling_phi3
+    from transformers.models.phi3.modeling_phi3 import Phi3Model

     if rope:
         modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb  # Same as Gemma
@@ -727,12 +704,8 @@ def apply_liger_kernel_to_phi3(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules

-        if hasattr(model, "model"):
-            # The case for Phi3ForCausalLM, for example
-            base_model = model.model
-        else:
-            # Direct Phi3Model
-            base_model = model
+        # get the base model from the model instance
+        base_model: Phi3Model = getattr(model, model.base_model_prefix, model)

         if rms_norm:
             _patch_rms_norm_module(base_model.norm)
src/liger_kernel/transformers/cross_entropy.py (old version, deleted)
@@ -1,21 +0,0 @@
-from torch.nn import CrossEntropyLoss
-
-from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
-
-
-class LigerCrossEntropyLoss(CrossEntropyLoss):
-    def __init__(self, *args, **kwargs):
-        super(LigerCrossEntropyLoss, self).__init__(*args, **kwargs)
-        assert (self.label_smoothing >= 0) and (
-            self.label_smoothing <= 1
-        ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"
-        assert self.reduction in {
-            "mean",
-            "sum",
-            "none",
-        }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {self.reduction}"
-
-    def forward(self, _input, target):
-        return LigerCrossEntropyFunction.apply(
-            _input, target, self.ignore_index, self.label_smoothing, self.reduction
-        )
src/liger_kernel/transformers/fused_linear_cross_entropy.py (old version, deleted)
@@ -1,21 +0,0 @@
-from torch.nn import CrossEntropyLoss
-
-from liger_kernel.ops.fused_linear_cross_entropy import (
-    LigerFusedLinearCrossEntropyFunction,
-)
-
-
-class LigerFusedLinearCrossEntropyLoss(CrossEntropyLoss):
-    def __init__(self, *args, **kwargs):
-        super(LigerFusedLinearCrossEntropyLoss, self).__init__(*args, **kwargs)
-
-    def forward(self, lin_weight, _input, target, bias=None):
-        return LigerFusedLinearCrossEntropyFunction.apply(
-            _input,
-            lin_weight,
-            target,
-            bias,
-            self.ignore_index,
-            self.label_smoothing,
-            self.reduction,
-        )