liger-kernel-nightly 0.6.4.dev20260107111351__py3-none-any.whl → 0.6.4.dev20260116023519__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic. Click here for more details.
- liger_kernel/ops/backends/_ascend/ops/__init__.py +6 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +34 -12
- liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +1 -1
- liger_kernel/ops/fused_add_rms_norm.py +16 -22
- liger_kernel/ops/group_norm.py +10 -7
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +15 -15
- liger_kernel/ops/poly_norm.py +14 -20
- liger_kernel/ops/rms_norm.py +20 -24
- liger_kernel/ops/utils.py +11 -0
- liger_kernel/transformers/__init__.py +3 -0
- liger_kernel/transformers/model/exaone4.py +136 -0
- liger_kernel/transformers/model/gemma2.py +3 -3
- liger_kernel/transformers/model/gemma3.py +10 -5
- liger_kernel/transformers/model/loss_utils.py +6 -0
- liger_kernel/transformers/monkey_patch.py +78 -0
- liger_kernel/transformers/tiled_mlp.py +2 -10
- {liger_kernel_nightly-0.6.4.dev20260107111351.dist-info → liger_kernel_nightly-0.6.4.dev20260116023519.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.6.4.dev20260107111351.dist-info → liger_kernel_nightly-0.6.4.dev20260116023519.dist-info}/RECORD +24 -22
- {liger_kernel_nightly-0.6.4.dev20260107111351.dist-info → liger_kernel_nightly-0.6.4.dev20260116023519.dist-info}/WHEEL +1 -1
- {liger_kernel_nightly-0.6.4.dev20260107111351.dist-info → liger_kernel_nightly-0.6.4.dev20260116023519.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.4.dev20260107111351.dist-info → liger_kernel_nightly-0.6.4.dev20260116023519.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.4.dev20260107111351.dist-info → liger_kernel_nightly-0.6.4.dev20260116023519.dist-info}/top_level.txt +0 -0
|
liger_kernel/ops/backends/_ascend/ops/__init__.py
CHANGED
|
@@ -26,6 +26,9 @@ from liger_kernel.ops.backends._ascend.ops.rope import rope_forward
|
|
|
26
26
|
from liger_kernel.ops.backends._ascend.ops.swiglu import LigerSiLUMulFunction
|
|
27
27
|
from liger_kernel.ops.backends._ascend.ops.swiglu import swiglu_backward
|
|
28
28
|
from liger_kernel.ops.backends._ascend.ops.swiglu import swiglu_forward
|
|
29
|
+
from liger_kernel.ops.backends._ascend.ops.tvd import LigerTVDLossFunction
|
|
30
|
+
from liger_kernel.ops.backends._ascend.ops.tvd import tv_distance_forward_triton
|
|
31
|
+
from liger_kernel.ops.backends._ascend.ops.tvd import tvd_backward_triton
|
|
29
32
|
|
|
30
33
|
__all__ = [
|
|
31
34
|
"LigerGELUMulFunction",
|
|
@@ -40,4 +43,7 @@ __all__ = [
|
|
|
40
43
|
"LigerSiLUMulFunction",
|
|
41
44
|
"swiglu_forward",
|
|
42
45
|
"swiglu_backward",
|
|
46
|
+
"LigerTVDLossFunction",
|
|
47
|
+
"tv_distance_forward_triton",
|
|
48
|
+
"tvd_backward_triton",
|
|
43
49
|
]
|
|
liger_kernel/ops/backends/_ascend/ops/geglu.py
CHANGED
|
@@ -130,20 +130,26 @@ def geglu_forward(a, b):
|
|
|
130
130
|
dtype_size = a.element_size()
|
|
131
131
|
# GEGLU forward tiling strategy:
|
|
132
132
|
# - Calculates maximum safe block size based on UB capacity
|
|
133
|
-
# - Memory analysis:
|
|
134
|
-
# * Inputs:
|
|
135
|
-
# *
|
|
136
|
-
# *
|
|
137
|
-
#
|
|
138
|
-
#
|
|
133
|
+
# - Memory analysis (only buffers that occupy UB, excluding temporary variables):
|
|
134
|
+
# * Inputs: a_row (4 bytes, float32), b_row (dtype_size bytes)
|
|
135
|
+
# * Output: c_row (dtype_size bytes)
|
|
136
|
+
# * Temporary variables (a_cubed, tanh_arg, tanh_result, geglu_a) are optimized to registers
|
|
137
|
+
# and don't occupy UB since they are only used once
|
|
138
|
+
# * For float16: a_row(4) + b_row(2) + c_row(2) = 8 bytes/element, ratio = 8/2 = 4.0
|
|
139
|
+
# * For float32: a_row(4) + b_row(4) + c_row(4) = 12 bytes/element, ratio = 12/4 = 3.0
|
|
140
|
+
# - Uses memory_multiplier=4.0 (float16) or 3.0 (float32) * BLOCK_SIZE * dtype_size * 8 bits
|
|
139
141
|
# - shapes: ((n_cols,),)
|
|
140
142
|
# - tiling_dims: (0,) means first dimension can be tiled
|
|
141
143
|
# - Returns: ((block_size,),)
|
|
142
144
|
shapes = ((n_cols,),)
|
|
145
|
+
if dtype_size == 2:
|
|
146
|
+
memory_multiplier = 4.0
|
|
147
|
+
else:
|
|
148
|
+
memory_multiplier = 3.0
|
|
143
149
|
tile_shapes = compute_default_tiling_strategy(
|
|
144
150
|
safety_margin=0.80,
|
|
145
151
|
dtype_size=dtype_size,
|
|
146
|
-
memory_multiplier=
|
|
152
|
+
memory_multiplier=memory_multiplier,
|
|
147
153
|
shapes=shapes,
|
|
148
154
|
tiling_dims=(0,),
|
|
149
155
|
)
|
|
@@ -187,18 +193,34 @@ def geglu_backward(a, b, dc):
|
|
|
187
193
|
dtype_size = dc.element_size()
|
|
188
194
|
# GEGLU backward tiling strategy:
|
|
189
195
|
# - Calculates maximum safe block size based on UB capacity
|
|
190
|
-
# - Memory analysis:
|
|
191
|
-
#
|
|
192
|
-
#
|
|
193
|
-
#
|
|
196
|
+
# - Memory analysis: Peak memory usage occurs when executing line 103 (term1 calculation)
|
|
197
|
+
# At this point, the following buffers simultaneously occupy UB:
|
|
198
|
+
# 1. dc_row = tl.load(dc + col_offsets, ...) # dtype_size bytes
|
|
199
|
+
# 2. a_row = tl.load(a + col_offsets, ...).to(tl.float32) # 4 bytes (float32)
|
|
200
|
+
# 3. b_row = tl.load(b + col_offsets, ...) # dtype_size bytes
|
|
201
|
+
# 4. tanh_result = tanh(tanh_arg) # 4 bytes (float32), used in lines 95, 103, 104
|
|
202
|
+
# 5. geglu_a = 0.5 * a_row * (1 + tanh_result) # 4 bytes (float32), used in lines 96, 98
|
|
203
|
+
# 6. db_row = dc_row.cast(tl.float32) * geglu_a # 4 bytes (float32, computed at line 98, stored at line 109)
|
|
204
|
+
# Note: term1 (line 103) is a temporary variable optimized to registers and doesn't occupy UB
|
|
205
|
+
# Temporary variables (a_cubed, tanh_arg, term1, tanh_sq, term2) are optimized to registers
|
|
206
|
+
# and don't occupy UB since they are only used once
|
|
207
|
+
# * For float16: dc_row(2) + a_row(4) + b_row(2) + tanh_result(4) + geglu_a(4) + db_row(4)
|
|
208
|
+
# = 20 bytes/element, ratio = 20/2 = 10.0
|
|
209
|
+
# * For float32: dc_row(4) + a_row(4) + b_row(4) + tanh_result(4) + geglu_a(4) + db_row(4)
|
|
210
|
+
# = 24 bytes/element, ratio = 24/4 = 6.0
|
|
211
|
+
# - Uses memory_multiplier=10.0 (float16) or 6.0 (float32) * BLOCK_SIZE * dtype_size * 8 bits
|
|
194
212
|
# - shapes: ((n_cols,),)
|
|
195
213
|
# - tiling_dims: (0,) means first dimension can be tiled
|
|
196
214
|
# - Returns: ((block_size,),)
|
|
197
215
|
shapes = ((n_cols,),)
|
|
216
|
+
if dtype_size == 2:
|
|
217
|
+
memory_multiplier = 10.0
|
|
218
|
+
else:
|
|
219
|
+
memory_multiplier = 6.0
|
|
198
220
|
tile_shapes = compute_default_tiling_strategy(
|
|
199
221
|
safety_margin=0.80,
|
|
200
222
|
dtype_size=dtype_size,
|
|
201
|
-
memory_multiplier=
|
|
223
|
+
memory_multiplier=memory_multiplier,
|
|
202
224
|
shapes=shapes,
|
|
203
225
|
tiling_dims=(0,),
|
|
204
226
|
)
|
|
liger_kernel/ops/backends/_ascend/ops/tvd.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import triton
|
|
6
|
+
import triton.language as tl
|
|
7
|
+
|
|
8
|
+
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
|
|
9
|
+
from liger_kernel.ops.utils import ensure_contiguous
|
|
10
|
+
|
|
11
|
+
MAX_FUSED_SIZE = 65536 // 4
|
|
12
|
+
|
|
13
|
+
REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@triton.jit
|
|
17
|
+
def _tv_distance_kernel(
|
|
18
|
+
p_ptr,
|
|
19
|
+
p_stride,
|
|
20
|
+
q_ptr,
|
|
21
|
+
q_stride,
|
|
22
|
+
loss_ptr,
|
|
23
|
+
loss_stride,
|
|
24
|
+
grads_ptr,
|
|
25
|
+
grads_stride,
|
|
26
|
+
label_ptr,
|
|
27
|
+
ignore_index: tl.constexpr,
|
|
28
|
+
n_cols, # V
|
|
29
|
+
total_rows: tl.constexpr, # BT
|
|
30
|
+
BLOCK_SIZE: tl.constexpr,
|
|
31
|
+
HAS_LABEL: tl.constexpr,
|
|
32
|
+
reduction: tl.constexpr = "batchmean",
|
|
33
|
+
):
|
|
34
|
+
thread_id = tl.program_id(0)
|
|
35
|
+
num_threads = tl.num_programs(0)
|
|
36
|
+
|
|
37
|
+
for pid in range(thread_id, total_rows, num_threads):
|
|
38
|
+
p_row_ptr = p_ptr + pid * p_stride
|
|
39
|
+
q_row_ptr = q_ptr + pid * q_stride
|
|
40
|
+
loss_row_ptr = loss_ptr + pid * loss_stride
|
|
41
|
+
grads_row_ptr = grads_ptr + pid * grads_stride
|
|
42
|
+
label_row_ptr = label_ptr + pid
|
|
43
|
+
|
|
44
|
+
base_offsets = tl.arange(0, BLOCK_SIZE)
|
|
45
|
+
|
|
46
|
+
should_skip = False
|
|
47
|
+
if HAS_LABEL:
|
|
48
|
+
label = tl.load(label_row_ptr)
|
|
49
|
+
if label == ignore_index:
|
|
50
|
+
should_skip = True
|
|
51
|
+
|
|
52
|
+
if should_skip:
|
|
53
|
+
for i in range(0, n_cols, BLOCK_SIZE):
|
|
54
|
+
offsets = i + base_offsets
|
|
55
|
+
mask = offsets < n_cols
|
|
56
|
+
tl.store(grads_row_ptr + offsets, 0.0, mask=mask)
|
|
57
|
+
if reduction == "none":
|
|
58
|
+
tl.store(loss_row_ptr + offsets, 0.0, mask=mask)
|
|
59
|
+
else:
|
|
60
|
+
loss_sum = 0.0
|
|
61
|
+
for i in range(0, n_cols, BLOCK_SIZE):
|
|
62
|
+
offsets = i + base_offsets
|
|
63
|
+
mask = offsets < n_cols
|
|
64
|
+
|
|
65
|
+
p = tl.load(p_row_ptr + offsets, mask=mask, other=0.0)
|
|
66
|
+
q = tl.load(q_row_ptr + offsets, mask=mask, other=0.0)
|
|
67
|
+
|
|
68
|
+
# TVD(P || Q) = 0.5 * |P - Q|
|
|
69
|
+
tv_loss = 0.5 * tl.abs(p - q)
|
|
70
|
+
grad_res = tl.where(p > q, 0.5, -0.5)
|
|
71
|
+
|
|
72
|
+
tl.store(grads_row_ptr + offsets, grad_res, mask=mask)
|
|
73
|
+
|
|
74
|
+
if reduction == "none":
|
|
75
|
+
tl.store(loss_row_ptr + offsets, tv_loss, mask=mask)
|
|
76
|
+
else:
|
|
77
|
+
loss_sum += tl.sum(tv_loss, axis=0)
|
|
78
|
+
|
|
79
|
+
if reduction != "none":
|
|
80
|
+
tl.store(loss_row_ptr, loss_sum)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
|
|
84
|
+
BT, V = p.shape
|
|
85
|
+
|
|
86
|
+
# TVD forward tiling strategy
|
|
87
|
+
# - In main loop (calculate loss and grad):
|
|
88
|
+
# * p: BLOCK_Q elements
|
|
89
|
+
# * q: BLOCK_Q elements
|
|
90
|
+
# * tv_loss: BLOCK_Q elements
|
|
91
|
+
# * grad_res: BLOCK_Q elements
|
|
92
|
+
# * loss_sum: BLOCK_Q elements (when reduction != "none")
|
|
93
|
+
# * Total: 4 * BLOCK_Q elements or 5 * BLOCK_Q elements when reduction != "none"
|
|
94
|
+
# - Since loss_sum is not necessarily used in every calculation,
|
|
95
|
+
# - and considering the consumption of other shared memory and the potential memory consumption of the HAS_LABEL loop.
|
|
96
|
+
# - Conservative estimate: 5 * BLOCK_Q * dtype_size * 8 bits
|
|
97
|
+
# - For safety, use: memory_multiplier=5.0 * BLOCK_SIZE * pad_hd * dtype_size * 8 bits
|
|
98
|
+
# - shapes: ((V,),)
|
|
99
|
+
# - tiling_dims: (0,) means first dimension of each shape can be tiled
|
|
100
|
+
# - Returns: ((block_size,),
|
|
101
|
+
shapes = ((V,),)
|
|
102
|
+
tile_shapes = compute_default_tiling_strategy(
|
|
103
|
+
safety_margin=0.80,
|
|
104
|
+
# In the TVD calculation, many data are implicitly converted to f32, so the size of f32 can be directly used.
|
|
105
|
+
dtype_size=4,
|
|
106
|
+
memory_multiplier=5.0,
|
|
107
|
+
shapes=shapes,
|
|
108
|
+
tiling_dims=(0,),
|
|
109
|
+
)
|
|
110
|
+
|
|
111
|
+
if tile_shapes is not None and len(tile_shapes) > 0 and len(tile_shapes[0]) > 0:
|
|
112
|
+
# Strategy returns ((block_size,),)
|
|
113
|
+
BLOCK_SIZE = tile_shapes[0][0]
|
|
114
|
+
else:
|
|
115
|
+
# Fallback to desired block size if no best practice found (no tiling needed)
|
|
116
|
+
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
|
117
|
+
|
|
118
|
+
MAX_BATCH_PER_KERNEL = 65535 # The maximum processing capacity of each kernel in npu
|
|
119
|
+
if BT <= MAX_BATCH_PER_KERNEL:
|
|
120
|
+
grid = (BT,)
|
|
121
|
+
else:
|
|
122
|
+
grid = (MAX_BATCH_PER_KERNEL,)
|
|
123
|
+
|
|
124
|
+
out_size = (BT, V) if reduction == "none" else (BT,)
|
|
125
|
+
output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32)
|
|
126
|
+
grads = torch.empty_like(p)
|
|
127
|
+
|
|
128
|
+
n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT
|
|
129
|
+
|
|
130
|
+
_tv_distance_kernel[grid](
|
|
131
|
+
p,
|
|
132
|
+
p.stride(0),
|
|
133
|
+
q,
|
|
134
|
+
q.stride(0),
|
|
135
|
+
output_tensor,
|
|
136
|
+
output_tensor.stride(0),
|
|
137
|
+
grads,
|
|
138
|
+
grads.stride(0),
|
|
139
|
+
shift_labels if has_label else torch.empty(1, device=p.device),
|
|
140
|
+
ignore_index,
|
|
141
|
+
V,
|
|
142
|
+
BT,
|
|
143
|
+
BLOCK_SIZE=BLOCK_SIZE,
|
|
144
|
+
HAS_LABEL=has_label,
|
|
145
|
+
reduction=reduction,
|
|
146
|
+
)
|
|
147
|
+
|
|
148
|
+
if reduction == "batchmean":
|
|
149
|
+
return output_tensor.sum() / n_non_ignore, grads / n_non_ignore
|
|
150
|
+
elif reduction == "sum":
|
|
151
|
+
return output_tensor.sum(dim=0), grads
|
|
152
|
+
elif reduction == "mean":
|
|
153
|
+
return output_tensor.sum() / (n_non_ignore * V), grads / (n_non_ignore * V)
|
|
154
|
+
else:
|
|
155
|
+
return output_tensor, grads
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def tvd_backward_triton(grad_output, grads):
|
|
159
|
+
# If this is the last layer, grad_output is 1.0. Skip the mul then.
|
|
160
|
+
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
|
|
161
|
+
return grads
|
|
162
|
+
|
|
163
|
+
return grads * grad_output
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
class LigerTVDLossFunction(torch.autograd.Function):
|
|
167
|
+
"""
|
|
168
|
+
Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton.
|
|
169
|
+
"""
|
|
170
|
+
|
|
171
|
+
@staticmethod
|
|
172
|
+
@ensure_contiguous
|
|
173
|
+
def forward(
|
|
174
|
+
ctx,
|
|
175
|
+
p: torch.Tensor,
|
|
176
|
+
q: torch.Tensor,
|
|
177
|
+
shift_labels: Optional[torch.Tensor] = None,
|
|
178
|
+
reduction: REDUCTION_LITERAL = "batchmean",
|
|
179
|
+
ignore_index: int = -100,
|
|
180
|
+
) -> torch.Tensor:
|
|
181
|
+
"""A forward pass for the Total Variation Distance Loss.
|
|
182
|
+
|
|
183
|
+
Args:
|
|
184
|
+
ctx: Torch autograd context
|
|
185
|
+
p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution.
|
|
186
|
+
q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution.
|
|
187
|
+
shift_labels (Optional[torch.Tensor]): A tensor of shape (BT,) containing the labels.
|
|
188
|
+
reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean".
|
|
189
|
+
ignore_index (int, optional): The index to ignore during loss calculation. Defaults to -100.
|
|
190
|
+
|
|
191
|
+
Returns:
|
|
192
|
+
torch.Tensor: The computed Total Variation Distance Loss.
|
|
193
|
+
"""
|
|
194
|
+
has_label = False
|
|
195
|
+
if shift_labels is not None:
|
|
196
|
+
assert shift_labels.shape == (p.shape[0],), (
|
|
197
|
+
f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
|
|
198
|
+
)
|
|
199
|
+
shift_labels = shift_labels.contiguous()
|
|
200
|
+
has_label = True
|
|
201
|
+
|
|
202
|
+
loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label)
|
|
203
|
+
ctx.save_for_backward(grads)
|
|
204
|
+
return loss
|
|
205
|
+
|
|
206
|
+
@staticmethod
|
|
207
|
+
@ensure_contiguous
|
|
208
|
+
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
|
|
209
|
+
"""A backward pass for the Total Variation Distance Loss.
|
|
210
|
+
|
|
211
|
+
Args:
|
|
212
|
+
ctx: Torch autograd context
|
|
213
|
+
grad_output (torch.Tensor): The gradient of the loss with respect to the output.
|
|
214
|
+
|
|
215
|
+
Returns:
|
|
216
|
+
tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs.
|
|
217
|
+
"""
|
|
218
|
+
(grads,) = ctx.saved_tensors
|
|
219
|
+
grads = tvd_backward_triton(grad_output, grads)
|
|
220
|
+
|
|
221
|
+
return grads, None, None, None, None
|
|
liger_kernel/ops/backends/_ascend/ub_manager.py
CHANGED
|
@@ -241,7 +241,7 @@ def compute_default_tiling_strategy(
|
|
|
241
241
|
dtype_size: Size of data type in bytes (e.g., 2 for float16, 4 for float32).
|
|
242
242
|
Must be provided. If None or <= 0, defaults to 4 (float32).
|
|
243
243
|
memory_multiplier: Memory multiplier for estimating peak memory usage.
|
|
244
|
-
- For GEGLU: typically 10.0 for backward,
|
|
244
|
+
- For GEGLU: typically 10.0 for backward, 4.0 for forward
|
|
245
245
|
- For ROPE: typically 3.0
|
|
246
246
|
If None, defaults to 10.0 (conservative estimate).
|
|
247
247
|
shapes: Tuple of full shapes. Each shape is a tuple of dimension sizes.
|
|
liger_kernel/ops/fused_add_rms_norm.py
CHANGED
|
@@ -8,6 +8,7 @@ import triton.language as tl
|
|
|
8
8
|
from liger_kernel.ops.utils import calculate_settings
|
|
9
9
|
from liger_kernel.ops.utils import compare_version
|
|
10
10
|
from liger_kernel.ops.utils import ensure_contiguous
|
|
11
|
+
from liger_kernel.ops.utils import set_large_grf_mode
|
|
11
12
|
from liger_kernel.ops.utils import torch_to_triton_dtype
|
|
12
13
|
from liger_kernel.utils import get_npu_multi_processor_count
|
|
13
14
|
from liger_kernel.utils import is_npu_available
|
|
@@ -162,23 +163,21 @@ def _fused_add_rms_norm_backward_kernel(
|
|
|
162
163
|
|
|
163
164
|
dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
|
164
165
|
|
|
165
|
-
dY_ptr += row_start * dY_row_stride
|
|
166
|
-
dX_ptr += row_start * dX_row_stride
|
|
167
|
-
if has_dS_out:
|
|
168
|
-
dS_out_ptr += row_start * dS_out_row_stride
|
|
169
|
-
|
|
170
|
-
X_ptr += row_start * X_row_stride
|
|
171
|
-
RSTD_ptr += row_start
|
|
172
|
-
|
|
173
166
|
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
|
|
174
167
|
W_row = W_row + offset
|
|
175
168
|
|
|
176
|
-
for
|
|
177
|
-
|
|
178
|
-
|
|
169
|
+
for row_idx in range(row_start, row_end):
|
|
170
|
+
dy_base = dY_ptr + row_idx * dY_row_stride
|
|
171
|
+
dx_base = dX_ptr + row_idx * dX_row_stride
|
|
172
|
+
|
|
173
|
+
x_base = X_ptr + row_idx * X_row_stride
|
|
174
|
+
rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
|
|
175
|
+
|
|
176
|
+
dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0)
|
|
177
|
+
X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0)
|
|
179
178
|
|
|
180
179
|
# Get cached rms
|
|
181
|
-
rstd_row = tl.load(
|
|
180
|
+
rstd_row = tl.load(rstd_base)
|
|
182
181
|
|
|
183
182
|
X_row = X_row.to(tl.float32)
|
|
184
183
|
|
|
@@ -195,11 +194,11 @@ def _fused_add_rms_norm_backward_kernel(
|
|
|
195
194
|
dX_row = rstd_row * m
|
|
196
195
|
|
|
197
196
|
if has_dS_out:
|
|
198
|
-
|
|
197
|
+
ds_base = dS_out_ptr + row_idx * dS_out_row_stride
|
|
198
|
+
dS_out_row = tl.load(ds_base + col_offsets, mask=mask, other=0.0)
|
|
199
199
|
dX_row += (rstd_row) * (
|
|
200
200
|
-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
|
|
201
201
|
) + dS_out_row
|
|
202
|
-
dS_out_ptr += dS_out_row_stride
|
|
203
202
|
else:
|
|
204
203
|
dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
|
|
205
204
|
|
|
@@ -210,12 +209,7 @@ def _fused_add_rms_norm_backward_kernel(
|
|
|
210
209
|
# here X_row is already in fp32 (see previous if block)
|
|
211
210
|
dW_row += dY_row * (X_row * rstd_row)
|
|
212
211
|
|
|
213
|
-
tl.store(
|
|
214
|
-
|
|
215
|
-
dY_ptr += dY_row_stride
|
|
216
|
-
dX_ptr += dX_row_stride
|
|
217
|
-
X_ptr += X_row_stride
|
|
218
|
-
RSTD_ptr += RSTD_row_stride
|
|
212
|
+
tl.store(dx_base + col_offsets, dX_row.to(X_dtype), mask=mask)
|
|
219
213
|
|
|
220
214
|
tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
|
|
221
215
|
|
|
@@ -254,7 +248,7 @@ def fused_add_rms_norm_forward(X, R, W, eps, offset, casting_mode):
|
|
|
254
248
|
# XPU-specific optimization
|
|
255
249
|
kernel_args = {}
|
|
256
250
|
if X.device.type == "xpu":
|
|
257
|
-
kernel_args
|
|
251
|
+
set_large_grf_mode(kernel_args)
|
|
258
252
|
|
|
259
253
|
# TODO: add _block_fused_add_rms_norm_forward_kernel
|
|
260
254
|
_fused_add_rms_norm_forward_kernel[(n_rows,)](
|
|
@@ -314,7 +308,7 @@ def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, BL
|
|
|
314
308
|
# XPU-specific optimization
|
|
315
309
|
kernel_args = {}
|
|
316
310
|
if S.device.type == "xpu":
|
|
317
|
-
kernel_args
|
|
311
|
+
set_large_grf_mode(kernel_args)
|
|
318
312
|
|
|
319
313
|
# TODO: add _block_fused_add_rms_norm_backward_kernel
|
|
320
314
|
_fused_add_rms_norm_backward_kernel[grid](
|
liger_kernel/ops/group_norm.py
CHANGED
|
@@ -6,6 +6,7 @@ import triton.language as tl
|
|
|
6
6
|
|
|
7
7
|
from liger_kernel.ops.utils import compare_version
|
|
8
8
|
from liger_kernel.ops.utils import ensure_contiguous
|
|
9
|
+
from liger_kernel.utils import infer_device
|
|
9
10
|
from liger_kernel.utils import is_npu_available
|
|
10
11
|
|
|
11
12
|
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
|
@@ -18,7 +19,10 @@ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
|
|
18
19
|
else:
|
|
19
20
|
from triton.language.math import rsqrt
|
|
20
21
|
|
|
21
|
-
|
|
22
|
+
if infer_device() == "npu":
|
|
23
|
+
MAX_FUSED_SIZE = 16384 # 8192
|
|
24
|
+
else:
|
|
25
|
+
MAX_FUSED_SIZE = 65536
|
|
22
26
|
|
|
23
27
|
|
|
24
28
|
@triton.jit
|
|
@@ -78,15 +82,14 @@ def _group_norm_forward_kernel(
|
|
|
78
82
|
for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
|
|
79
83
|
W = tl.load(W_ptr + channel_idx)
|
|
80
84
|
B = tl.load(B_ptr + channel_idx)
|
|
81
|
-
|
|
85
|
+
# Calculate channel offset within the group
|
|
86
|
+
channel_offset = (channel_idx - group_idx * channels_per_group) * hidden_size_per_channel
|
|
87
|
+
for i in tl.range(0, hidden_size_per_channel, BLOCK_SIZE):
|
|
82
88
|
hidden_size_offsets = i + block_range
|
|
83
89
|
mask = hidden_size_offsets < hidden_size_per_channel
|
|
84
|
-
X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
|
|
90
|
+
X = tl.load(X_ptr + channel_offset + hidden_size_offsets, mask=mask, other=m)
|
|
85
91
|
Y = (X - m) * rstd * W + B
|
|
86
|
-
tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)
|
|
87
|
-
|
|
88
|
-
X_ptr += hidden_size_per_channel
|
|
89
|
-
Y_ptr += hidden_size_per_channel
|
|
92
|
+
tl.store(Y_ptr + channel_offset + hidden_size_offsets, Y, mask=mask)
|
|
90
93
|
|
|
91
94
|
tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
|
|
92
95
|
tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
|
liger_kernel/ops/kl_div.py
CHANGED
|
@@ -21,7 +21,12 @@ def get_num_warps(BLOCK_SIZE):
|
|
|
21
21
|
return num_warps
|
|
22
22
|
|
|
23
23
|
|
|
24
|
-
|
|
24
|
+
if infer_device() == "xpu":
|
|
25
|
+
MAX_FUSED_SIZE = 8192
|
|
26
|
+
elif infer_device() == "npu":
|
|
27
|
+
MAX_FUSED_SIZE = 8192
|
|
28
|
+
else:
|
|
29
|
+
MAX_FUSED_SIZE = 65536 // 4 # 65536 // 4 or 8 works the best
|
|
25
30
|
|
|
26
31
|
REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
|
|
27
32
|
|
|
@@ -116,11 +121,7 @@ def _kldiv_kernel_backward(
|
|
|
116
121
|
|
|
117
122
|
def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
|
|
118
123
|
BT, V = y_pred.shape
|
|
119
|
-
BLOCK_SIZE = (
|
|
120
|
-
min(8192, triton.next_power_of_2(V))
|
|
121
|
-
if infer_device() == "xpu"
|
|
122
|
-
else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
|
123
|
-
)
|
|
124
|
+
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
|
124
125
|
num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
|
|
125
126
|
|
|
126
127
|
grid = (BT,)
|
|
@@ -159,11 +160,7 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
|
|
|
159
160
|
|
|
160
161
|
def kldiv_backward_triton(target, grad_output, new_grads, log_target):
|
|
161
162
|
BT, V = target.shape
|
|
162
|
-
BLOCK_SIZE = (
|
|
163
|
-
min(8192, triton.next_power_of_2(V))
|
|
164
|
-
if infer_device() == "xpu"
|
|
165
|
-
else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
|
166
|
-
)
|
|
163
|
+
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
|
167
164
|
num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
|
|
168
165
|
|
|
169
166
|
grid = (BT,)
|
liger_kernel/ops/layer_norm.py
CHANGED
|
@@ -8,6 +8,8 @@ import triton.language as tl
|
|
|
8
8
|
from liger_kernel.ops.utils import calculate_settings
|
|
9
9
|
from liger_kernel.ops.utils import compare_version
|
|
10
10
|
from liger_kernel.ops.utils import ensure_contiguous
|
|
11
|
+
from liger_kernel.ops.utils import set_large_grf_mode
|
|
12
|
+
from liger_kernel.utils import get_npu_multi_processor_count
|
|
11
13
|
from liger_kernel.utils import is_npu_available
|
|
12
14
|
|
|
13
15
|
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
|
@@ -124,14 +126,14 @@ def _layer_norm_backward_kernel(
|
|
|
124
126
|
w = tl.load(W_ptr + cols, mask=mask, other=0.0)
|
|
125
127
|
w_f32 = w.to(tl.float32)
|
|
126
128
|
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
129
|
+
for row_idx in range(row_start, row_end):
|
|
130
|
+
# Calculate pointers for this specific row
|
|
131
|
+
row_X_ptr = X_ptr + row_idx * stride_x
|
|
132
|
+
row_DX_ptr = DX_ptr + row_idx * stride_dx
|
|
133
|
+
row_DY_ptr = DY_ptr + row_idx * stride_dy
|
|
134
|
+
row_Mean_ptr = Mean_ptr + row_idx * stride_mean
|
|
135
|
+
row_RSTD_ptr = RSTD_ptr + row_idx * stride_rstd
|
|
133
136
|
|
|
134
|
-
for _ in range(row_start, row_end):
|
|
135
137
|
# Load data for this row
|
|
136
138
|
x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
|
|
137
139
|
dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
|
|
@@ -160,12 +162,6 @@ def _layer_norm_backward_kernel(
|
|
|
160
162
|
dW_row += dw
|
|
161
163
|
db_row += db
|
|
162
164
|
|
|
163
|
-
row_X_ptr += stride_x
|
|
164
|
-
row_DX_ptr += stride_dx
|
|
165
|
-
row_DY_ptr += stride_dy
|
|
166
|
-
row_Mean_ptr += stride_mean
|
|
167
|
-
row_RSTD_ptr += stride_rstd
|
|
168
|
-
|
|
169
165
|
tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
|
|
170
166
|
tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)
|
|
171
167
|
|
|
@@ -204,7 +200,7 @@ def layer_norm_forward(X, W, B, eps):
|
|
|
204
200
|
# XPU-specific optimization
|
|
205
201
|
kernel_args = {}
|
|
206
202
|
if X.device.type == "xpu":
|
|
207
|
-
kernel_args
|
|
203
|
+
set_large_grf_mode(kernel_args)
|
|
208
204
|
|
|
209
205
|
# Launch kernel with one thread block per row for optimal performance
|
|
210
206
|
grid = (n_rows,)
|
|
@@ -254,6 +250,8 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
|
|
|
254
250
|
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
|
|
255
251
|
elif X.device.type == "xpu":
|
|
256
252
|
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
|
|
253
|
+
elif X.device.type == "npu":
|
|
254
|
+
sm_count = get_npu_multi_processor_count()
|
|
257
255
|
|
|
258
256
|
# fp32 for numerical stability especially.
|
|
259
257
|
_DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
|
|
@@ -272,7 +270,8 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
|
|
|
272
270
|
kernel_args = {"num_warps": num_warps}
|
|
273
271
|
# XPU-specific optimization
|
|
274
272
|
if X.device.type == "xpu":
|
|
275
|
-
kernel_args.update({"
|
|
273
|
+
kernel_args.update({"num_warps": 32, "num_stages": 4})
|
|
274
|
+
set_large_grf_mode(kernel_args)
|
|
276
275
|
|
|
277
276
|
# Launch kernel with one thread block per row for optimal performance
|
|
278
277
|
_layer_norm_backward_kernel[grid](
|
|
@@ -301,6 +300,7 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
|
|
|
301
300
|
DX = DX.view(*shape)
|
|
302
301
|
DW = _DW.sum(dim=0).to(W.dtype)
|
|
303
302
|
DB = _DB.sum(dim=0).to(B.dtype)
|
|
303
|
+
|
|
304
304
|
return DX, DW, DB
|
|
305
305
|
|
|
306
306
|
|
liger_kernel/ops/poly_norm.py
CHANGED
|
@@ -7,6 +7,7 @@ import triton.language as tl
|
|
|
7
7
|
from liger_kernel.ops.utils import calculate_settings
|
|
8
8
|
from liger_kernel.ops.utils import compare_version
|
|
9
9
|
from liger_kernel.ops.utils import ensure_contiguous
|
|
10
|
+
from liger_kernel.ops.utils import set_large_grf_mode
|
|
10
11
|
from liger_kernel.utils import get_npu_multi_processor_count
|
|
11
12
|
from liger_kernel.utils import is_npu_available
|
|
12
13
|
|
|
@@ -140,20 +141,19 @@ def _poly_norm_backward_kernel(
|
|
|
140
141
|
w1 = tl.load(W_ptr + 1).to(tl.float32)
|
|
141
142
|
w2 = tl.load(W_ptr + 2).to(tl.float32)
|
|
142
143
|
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
144
|
+
for row_idx in range(row_start, row_end):
|
|
145
|
+
dy_base = dY_ptr + row_idx * dY_row_stride
|
|
146
|
+
x_base = X_ptr + row_idx * X_row_stride
|
|
147
|
+
dx_base = dX_ptr + row_idx * dX_row_stride
|
|
148
|
+
rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
|
|
147
149
|
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
|
|
151
|
-
X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
|
|
150
|
+
dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0).to(tl.float32)
|
|
151
|
+
X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0).to(tl.float32)
|
|
152
152
|
|
|
153
153
|
# Load cached rstd values
|
|
154
|
-
rstd_3 = tl.load(
|
|
155
|
-
rstd_2 = tl.load(
|
|
156
|
-
rstd_1 = tl.load(
|
|
154
|
+
rstd_3 = tl.load(rstd_base + 0).to(tl.float32)
|
|
155
|
+
rstd_2 = tl.load(rstd_base + 1).to(tl.float32)
|
|
156
|
+
rstd_1 = tl.load(rstd_base + 2).to(tl.float32)
|
|
157
157
|
|
|
158
158
|
# Compute powers
|
|
159
159
|
X_pow3 = X_row * X_row * X_row
|
|
@@ -190,13 +190,7 @@ def _poly_norm_backward_kernel(
|
|
|
190
190
|
dX_row = grad_x_3 + grad_x_2 + grad_x_1
|
|
191
191
|
|
|
192
192
|
# Store gradient
|
|
193
|
-
tl.store(
|
|
194
|
-
|
|
195
|
-
# Update pointers
|
|
196
|
-
dY_ptr += dY_row_stride
|
|
197
|
-
dX_ptr += dX_row_stride
|
|
198
|
-
X_ptr += X_row_stride
|
|
199
|
-
RSTD_ptr += RSTD_row_stride
|
|
193
|
+
tl.store(dx_base + col_offsets, dX_row, mask=mask)
|
|
200
194
|
|
|
201
195
|
# Store accumulated gradients (scalars)
|
|
202
196
|
tl.store(dW_ptr + row_block_id * dW_row_stride + 0, dW0_acc)
|
|
@@ -239,7 +233,7 @@ def poly_norm_forward(X, W, B, eps=1e-6):
|
|
|
239
233
|
# XPU-specific optimization
|
|
240
234
|
kernel_args = {}
|
|
241
235
|
if X.device.type == "xpu":
|
|
242
|
-
kernel_args
|
|
236
|
+
set_large_grf_mode(kernel_args)
|
|
243
237
|
|
|
244
238
|
# Launch kernel
|
|
245
239
|
_poly_norm_forward_kernel[(n_rows,)](
|
|
@@ -310,7 +304,7 @@ def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
|
|
|
310
304
|
# XPU-specific optimization
|
|
311
305
|
kernel_args = {}
|
|
312
306
|
if X.device.type == "xpu":
|
|
313
|
-
kernel_args
|
|
307
|
+
set_large_grf_mode(kernel_args)
|
|
314
308
|
|
|
315
309
|
# Launch backward kernel
|
|
316
310
|
_poly_norm_backward_kernel[grid](
|
liger_kernel/ops/rms_norm.py
CHANGED
|
@@ -20,6 +20,7 @@ import triton.language as tl
|
|
|
20
20
|
from liger_kernel.ops.utils import calculate_settings
|
|
21
21
|
from liger_kernel.ops.utils import compare_version
|
|
22
22
|
from liger_kernel.ops.utils import ensure_contiguous
|
|
23
|
+
from liger_kernel.ops.utils import set_large_grf_mode
|
|
23
24
|
from liger_kernel.ops.utils import torch_to_triton_dtype
|
|
24
25
|
from liger_kernel.utils import get_npu_multi_processor_count
|
|
25
26
|
from liger_kernel.utils import is_npu_available
|
|
@@ -70,11 +71,11 @@ def _rms_norm_forward_kernel(
|
|
|
70
71
|
col_offsets = tl.arange(0, BLOCK_SIZE)
|
|
71
72
|
mask = col_offsets < n_cols
|
|
72
73
|
|
|
73
|
-
Y_ptr
|
|
74
|
-
X_ptr
|
|
75
|
-
RSTD_ptr
|
|
74
|
+
y_base = Y_ptr + row_idx * Y_row_stride
|
|
75
|
+
x_base = X_ptr + row_idx * X_row_stride
|
|
76
|
+
rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
|
|
76
77
|
|
|
77
|
-
X_row = tl.load(
|
|
78
|
+
X_row = tl.load(x_base + col_offsets, mask=mask, other=0)
|
|
78
79
|
X_row_dtype = X_row.dtype
|
|
79
80
|
if elementwise_affine:
|
|
80
81
|
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
|
|
@@ -99,7 +100,7 @@ def _rms_norm_forward_kernel(
|
|
|
99
100
|
# We can save time by caching rms with minimal memory overhead
|
|
100
101
|
# because rms is much smaller compared to X_row, as rms is for each row.
|
|
101
102
|
# However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
|
|
102
|
-
tl.store(
|
|
103
|
+
tl.store(rstd_base, rstd)
|
|
103
104
|
|
|
104
105
|
X_row = X_row * rstd
|
|
105
106
|
|
|
@@ -115,7 +116,7 @@ def _rms_norm_forward_kernel(
|
|
|
115
116
|
if casting_mode == _CASTING_MODE_GEMMA:
|
|
116
117
|
Y_row = Y_row.to(X_row_dtype)
|
|
117
118
|
|
|
118
|
-
tl.store(
|
|
119
|
+
tl.store(y_base + col_offsets, Y_row, mask=mask)
|
|
119
120
|
|
|
120
121
|
|
|
121
122
|
@triton.jit
|
|
@@ -155,22 +156,22 @@ def _rms_norm_backward_kernel(
|
|
|
155
156
|
if elementwise_affine:
|
|
156
157
|
dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
|
157
158
|
|
|
158
|
-
dY_ptr += row_start * dY_row_stride
|
|
159
|
-
dX_ptr += row_start * dX_row_stride
|
|
160
|
-
|
|
161
|
-
X_ptr += row_start * X_row_stride
|
|
162
|
-
RSTD_ptr += row_start
|
|
163
|
-
|
|
164
159
|
if elementwise_affine:
|
|
165
160
|
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
|
|
166
161
|
W_row = W_row + offset
|
|
167
162
|
|
|
168
|
-
for
|
|
169
|
-
|
|
170
|
-
|
|
163
|
+
for row_idx in range(row_start, row_end):
|
|
164
|
+
dy_base = dY_ptr + row_idx * dY_row_stride
|
|
165
|
+
dx_base = dX_ptr + row_idx * dX_row_stride
|
|
166
|
+
|
|
167
|
+
x_base = X_ptr + row_idx * X_row_stride
|
|
168
|
+
rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
|
|
169
|
+
|
|
170
|
+
dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0)
|
|
171
|
+
X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0)
|
|
171
172
|
|
|
172
173
|
# Get cached rms
|
|
173
|
-
rstd_row = tl.load(
|
|
174
|
+
rstd_row = tl.load(rstd_base)
|
|
174
175
|
|
|
175
176
|
X_row = X_row.to(tl.float32)
|
|
176
177
|
|
|
@@ -205,12 +206,7 @@ def _rms_norm_backward_kernel(
|
|
|
205
206
|
# here X_row is already in fp32 (see previous if block)
|
|
206
207
|
dW_row += dY_row * (X_row * rstd_row)
|
|
207
208
|
|
|
208
|
-
tl.store(
|
|
209
|
-
|
|
210
|
-
dY_ptr += dY_row_stride
|
|
211
|
-
dX_ptr += dX_row_stride
|
|
212
|
-
X_ptr += X_row_stride
|
|
213
|
-
RSTD_ptr += RSTD_row_stride
|
|
209
|
+
tl.store(dx_base + col_offsets, dX_row.to(X_dtype), mask=mask)
|
|
214
210
|
|
|
215
211
|
if elementwise_affine:
|
|
216
212
|
tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
|
|
@@ -441,7 +437,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
|
|
|
441
437
|
# XPU-specific optimization
|
|
442
438
|
kernel_args = {}
|
|
443
439
|
if X.device.type == "xpu":
|
|
444
|
-
kernel_args
|
|
440
|
+
set_large_grf_mode(kernel_args)
|
|
445
441
|
if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
|
|
446
442
|
_rms_norm_forward_kernel[(n_rows,)](
|
|
447
443
|
Y,
|
|
@@ -521,7 +517,7 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
|
|
|
521
517
|
# XPU-specific optimization
|
|
522
518
|
kernel_args = {}
|
|
523
519
|
if X.device.type == "xpu":
|
|
524
|
-
kernel_args
|
|
520
|
+
set_large_grf_mode(kernel_args)
|
|
525
521
|
|
|
526
522
|
if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
|
|
527
523
|
_rms_norm_backward_kernel[grid](
|
liger_kernel/ops/utils.py
CHANGED
|
@@ -139,3 +139,14 @@ def get_npu_core_count(default: int = 20) -> int:
|
|
|
139
139
|
return int(props.get("num_vectorcore", default))
|
|
140
140
|
except Exception:
|
|
141
141
|
return default
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def set_large_grf_mode(kernel_args: dict):
|
|
145
|
+
"""Set large GRF mode for XPU devices."""
|
|
146
|
+
# On XPU triton installed along with pytorch-xpu will be called `pytorch-triton-xpu`,
|
|
147
|
+
# triton XPU installed from source will be called `triton`.
|
|
148
|
+
if compare_version("pytorch-triton-xpu", operator.ge, "3.6.0") or compare_version("triton", operator.ge, "3.6.0"):
|
|
149
|
+
kernel_args["grf_mode"] = "256"
|
|
150
|
+
else:
|
|
151
|
+
# API was changed in https://github.com/intel/intel-xpu-backend-for-triton/pull/5430
|
|
152
|
+
kernel_args["grf_mode"] = "large"
|
|
@@ -33,6 +33,7 @@ if TYPE_CHECKING:
|
|
|
33
33
|
from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM # noqa: F401
|
|
34
34
|
from liger_kernel.transformers.monkey_patch import _apply_liger_kernel # noqa: F401
|
|
35
35
|
from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance # noqa: F401
|
|
36
|
+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_exaone4 # noqa: F401
|
|
36
37
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_falcon_h1 # noqa: F401
|
|
37
38
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma # noqa: F401
|
|
38
39
|
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401
|
|
@@ -136,6 +137,7 @@ def __getattr__(name: str):
|
|
|
136
137
|
"apply_liger_kernel_to_smolvlm",
|
|
137
138
|
"apply_liger_kernel_to_hunyuan_v1_dense",
|
|
138
139
|
"apply_liger_kernel_to_hunyuan_v1_moe",
|
|
140
|
+
"apply_liger_kernel_to_exaone4",
|
|
139
141
|
}
|
|
140
142
|
|
|
141
143
|
if name in monkey_patch_symbols:
|
|
@@ -214,5 +216,6 @@ if _TRANSFORMERS_AVAILABLE:
|
|
|
214
216
|
"apply_liger_kernel_to_smolvlm",
|
|
215
217
|
"apply_liger_kernel_to_hunyuan_v1_dense",
|
|
216
218
|
"apply_liger_kernel_to_hunyuan_v1_moe",
|
|
219
|
+
"apply_liger_kernel_to_exaone4",
|
|
217
220
|
]
|
|
218
221
|
)
|
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
from typing import Optional
|
|
3
|
+
from typing import Union
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
8
|
+
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
|
|
9
|
+
from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def lce_forward(
|
|
13
|
+
self,
|
|
14
|
+
input_ids: Optional[torch.LongTensor] = None,
|
|
15
|
+
attention_mask: Optional[torch.Tensor] = None,
|
|
16
|
+
position_ids: Optional[torch.LongTensor] = None,
|
|
17
|
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
|
18
|
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
19
|
+
labels: Optional[torch.LongTensor] = None,
|
|
20
|
+
use_cache: Optional[bool] = None,
|
|
21
|
+
output_attentions: Optional[bool] = None,
|
|
22
|
+
output_hidden_states: Optional[bool] = None,
|
|
23
|
+
cache_position: Optional[torch.LongTensor] = None,
|
|
24
|
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
25
|
+
skip_logits: Optional[bool] = None,
|
|
26
|
+
return_dict: Optional[bool] = None,
|
|
27
|
+
**kwargs,
|
|
28
|
+
) -> LigerCausalLMOutputWithPast:
|
|
29
|
+
r"""
|
|
30
|
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
31
|
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
32
|
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
33
|
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
34
|
+
|
|
35
|
+
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
36
|
+
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
37
|
+
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
38
|
+
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
39
|
+
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
40
|
+
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
41
|
+
|
|
42
|
+
Returns:
|
|
43
|
+
|
|
44
|
+
Example:
|
|
45
|
+
|
|
46
|
+
````python
|
|
47
|
+
>>> from transformers import AutoTokenizer, Exaone4ForCausalLM
|
|
48
|
+
|
|
49
|
+
>>> model = Exaone4ForCausalLM.from_pretrained("LGAI-EXAONE/EXAONE-4.0-1.2B")
|
|
50
|
+
>>> tokenizer = AutoTokenizer.from_pretrained("LGAI-EXAONE/EXAONE-4.0-1.2B")
|
|
51
|
+
|
|
52
|
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
|
53
|
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
|
54
|
+
|
|
55
|
+
>>> # Generate
|
|
56
|
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
|
57
|
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
58
|
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
|
59
|
+
```"""
|
|
60
|
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
61
|
+
output_hidden_states = (
|
|
62
|
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
63
|
+
)
|
|
64
|
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
65
|
+
|
|
66
|
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
67
|
+
outputs = self.model(
|
|
68
|
+
input_ids=input_ids,
|
|
69
|
+
attention_mask=attention_mask,
|
|
70
|
+
position_ids=position_ids,
|
|
71
|
+
past_key_values=past_key_values,
|
|
72
|
+
inputs_embeds=inputs_embeds,
|
|
73
|
+
use_cache=use_cache,
|
|
74
|
+
output_attentions=output_attentions,
|
|
75
|
+
output_hidden_states=output_hidden_states,
|
|
76
|
+
cache_position=cache_position,
|
|
77
|
+
**kwargs,
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
hidden_states = outputs[0]
|
|
81
|
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
82
|
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
|
83
|
+
kept_hidden_states = hidden_states[:, slice_indices, :]
|
|
84
|
+
|
|
85
|
+
shift_labels = kwargs.pop("shift_labels", None)
|
|
86
|
+
# Remove output-control parameters that shouldn't be passed to loss functions
|
|
87
|
+
kwargs.pop("return_dict", None)
|
|
88
|
+
logits = None
|
|
89
|
+
loss = None
|
|
90
|
+
token_accuracy = None
|
|
91
|
+
|
|
92
|
+
if skip_logits and labels is None and shift_labels is None:
|
|
93
|
+
raise ValueError("skip_logits is True, but labels and shift_labels are None")
|
|
94
|
+
|
|
95
|
+
if skip_logits is None:
|
|
96
|
+
# By default, if in training mode, don't materialize logits
|
|
97
|
+
skip_logits = self.training and (labels is not None or shift_labels is not None)
|
|
98
|
+
|
|
99
|
+
# Compute loss
|
|
100
|
+
if skip_logits:
|
|
101
|
+
result = LigerForCausalLMLoss(
|
|
102
|
+
hidden_states=kept_hidden_states,
|
|
103
|
+
lm_head_weight=self.lm_head.weight,
|
|
104
|
+
labels=labels,
|
|
105
|
+
shift_labels=shift_labels,
|
|
106
|
+
hidden_size=self.config.hidden_size,
|
|
107
|
+
**kwargs,
|
|
108
|
+
)
|
|
109
|
+
loss, _, token_accuracy = unpack_cross_entropy_result(result)
|
|
110
|
+
|
|
111
|
+
else:
|
|
112
|
+
logits = self.lm_head(kept_hidden_states)
|
|
113
|
+
if labels is not None or shift_labels is not None:
|
|
114
|
+
loss = self.loss_function(
|
|
115
|
+
logits=logits,
|
|
116
|
+
labels=labels,
|
|
117
|
+
shift_labels=shift_labels,
|
|
118
|
+
vocab_size=self.config.vocab_size,
|
|
119
|
+
**kwargs,
|
|
120
|
+
)
|
|
121
|
+
|
|
122
|
+
if not return_dict:
|
|
123
|
+
output = (logits,) + outputs[1:]
|
|
124
|
+
output = ((loss,) + output) if loss is not None else output
|
|
125
|
+
output = output + (token_accuracy,) if token_accuracy is not None else output
|
|
126
|
+
return output
|
|
127
|
+
|
|
128
|
+
# Return custom output class with accuracy field
|
|
129
|
+
return LigerCausalLMOutputWithPast(
|
|
130
|
+
loss=loss,
|
|
131
|
+
logits=logits,
|
|
132
|
+
past_key_values=outputs.past_key_values,
|
|
133
|
+
hidden_states=outputs.hidden_states,
|
|
134
|
+
attentions=outputs.attentions,
|
|
135
|
+
token_accuracy=token_accuracy,
|
|
136
|
+
)
|
|
@@ -7,7 +7,7 @@ from typing import Union
|
|
|
7
7
|
import torch
|
|
8
8
|
|
|
9
9
|
from torch.nn import CrossEntropyLoss
|
|
10
|
-
from transformers.cache_utils import
|
|
10
|
+
from transformers.cache_utils import Cache
|
|
11
11
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
12
12
|
from transformers.utils.deprecation import deprecate_kwarg
|
|
13
13
|
|
|
@@ -24,7 +24,7 @@ def lce_forward_deprecated(
|
|
|
24
24
|
input_ids: torch.LongTensor = None,
|
|
25
25
|
attention_mask: Optional[torch.Tensor] = None,
|
|
26
26
|
position_ids: Optional[torch.LongTensor] = None,
|
|
27
|
-
past_key_values: Optional[
|
|
27
|
+
past_key_values: Optional[Cache] = None,
|
|
28
28
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
29
29
|
labels: Optional[torch.LongTensor] = None,
|
|
30
30
|
use_cache: Optional[bool] = None,
|
|
@@ -149,7 +149,7 @@ def lce_forward(
|
|
|
149
149
|
input_ids: torch.LongTensor = None,
|
|
150
150
|
attention_mask: Optional[torch.Tensor] = None,
|
|
151
151
|
position_ids: Optional[torch.LongTensor] = None,
|
|
152
|
-
past_key_values: Optional[
|
|
152
|
+
past_key_values: Optional[Cache] = None,
|
|
153
153
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
154
154
|
labels: Optional[torch.LongTensor] = None,
|
|
155
155
|
use_cache: Optional[bool] = None,
|
|
@@ -6,10 +6,8 @@ import torch
|
|
|
6
6
|
import torch.nn as nn
|
|
7
7
|
|
|
8
8
|
from transformers.cache_utils import Cache
|
|
9
|
-
from transformers.cache_utils import HybridCache
|
|
10
9
|
from transformers.utils import logging
|
|
11
10
|
|
|
12
|
-
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
|
|
13
11
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
14
12
|
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
|
|
15
13
|
from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
|
|
@@ -23,7 +21,7 @@ def causal_forward(
|
|
|
23
21
|
input_ids: torch.LongTensor = None,
|
|
24
22
|
attention_mask: Optional[torch.Tensor] = None,
|
|
25
23
|
position_ids: Optional[torch.LongTensor] = None,
|
|
26
|
-
past_key_values: Optional[
|
|
24
|
+
past_key_values: Optional[Cache] = None,
|
|
27
25
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
28
26
|
labels: Optional[torch.LongTensor] = None,
|
|
29
27
|
use_cache: Optional[bool] = None,
|
|
@@ -269,8 +267,15 @@ def multimodal_forward(
|
|
|
269
267
|
shift_hidden_states = shift_hidden_states.view(-1, self.config.text_config.hidden_size)
|
|
270
268
|
shift_labels = shift_labels.view(-1).to(hidden_device)
|
|
271
269
|
|
|
272
|
-
|
|
273
|
-
|
|
270
|
+
result = LigerForCausalLMLoss(
|
|
271
|
+
hidden_states=shift_hidden_states,
|
|
272
|
+
lm_head_weight=self.lm_head.weight,
|
|
273
|
+
labels=shift_labels,
|
|
274
|
+
hidden_size=self.config.text_config.hidden_size,
|
|
275
|
+
shift_labels=shift_labels,
|
|
276
|
+
final_logit_softcapping=getattr(self.config.text_config, "final_logit_softcapping", None),
|
|
277
|
+
**lm_kwargs,
|
|
278
|
+
)
|
|
274
279
|
loss, _, token_accuracy = unpack_cross_entropy_result(result)
|
|
275
280
|
|
|
276
281
|
else:
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
|
|
1
3
|
from typing import Optional
|
|
2
4
|
from typing import Tuple
|
|
3
5
|
|
|
@@ -71,6 +73,10 @@ def LigerForCausalLMLoss(
|
|
|
71
73
|
return_token_accuracy: bool = False,
|
|
72
74
|
**kwargs,
|
|
73
75
|
):
|
|
76
|
+
# Filter out inapplicable kwargs to liger_fused_linear_cross_entropy
|
|
77
|
+
applicable_params = inspect.signature(F.liger_fused_linear_cross_entropy).parameters
|
|
78
|
+
kwargs = {k: v for k, v in kwargs.items() if k in applicable_params}
|
|
79
|
+
|
|
74
80
|
# Skip upcast since intermediate values for the loss are all fp32 in kernel
|
|
75
81
|
if shift_labels is None:
|
|
76
82
|
# Shift so that token < n predict n
|
|
@@ -2821,6 +2821,83 @@ def apply_liger_kernel_to_hunyuan_v1_moe(
|
|
|
2821
2821
|
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
|
|
2822
2822
|
|
|
2823
2823
|
|
|
2824
|
+
def apply_liger_kernel_to_exaone4(
|
|
2825
|
+
rope: bool = True,
|
|
2826
|
+
cross_entropy: bool = False,
|
|
2827
|
+
fused_linear_cross_entropy: bool = True,
|
|
2828
|
+
rms_norm: bool = True,
|
|
2829
|
+
swiglu: bool = True,
|
|
2830
|
+
model: PreTrainedModel = None,
|
|
2831
|
+
) -> None:
|
|
2832
|
+
"""
|
|
2833
|
+
Apply Liger kernels to replace original implementation in HuggingFace EXAONE4 models.
|
|
2834
|
+
|
|
2835
|
+
Args:
|
|
2836
|
+
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
|
|
2837
|
+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
2838
|
+
fused_linear_cross_entropy (bool):
|
|
2839
|
+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
|
|
2840
|
+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
2841
|
+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
2842
|
+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
|
|
2843
|
+
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
|
|
2844
|
+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
|
|
2845
|
+
loaded. Default is None.
|
|
2846
|
+
"""
|
|
2847
|
+
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
2848
|
+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
2849
|
+
)
|
|
2850
|
+
|
|
2851
|
+
from transformers.models.exaone4 import modeling_exaone4
|
|
2852
|
+
from transformers.models.exaone4.modeling_exaone4 import Exaone4Model
|
|
2853
|
+
|
|
2854
|
+
from liger_kernel.transformers.model.exaone4 import lce_forward as exaone4_lce_forward
|
|
2855
|
+
|
|
2856
|
+
if rope:
|
|
2857
|
+
modeling_exaone4.apply_rotary_pos_emb = liger_rotary_pos_emb
|
|
2858
|
+
|
|
2859
|
+
if rms_norm:
|
|
2860
|
+
# EXAONE4 requires in_place=False to avoid gradient issues
|
|
2861
|
+
class Exaone4LigerRMSNorm(LigerRMSNorm):
|
|
2862
|
+
def __init__(self, hidden_size, eps=1e-6, **kwargs):
|
|
2863
|
+
super().__init__(hidden_size, eps, **kwargs)
|
|
2864
|
+
self.in_place = False
|
|
2865
|
+
|
|
2866
|
+
modeling_exaone4.Exaone4RMSNorm = Exaone4LigerRMSNorm
|
|
2867
|
+
|
|
2868
|
+
if cross_entropy:
|
|
2869
|
+
from transformers.loss.loss_utils import nn
|
|
2870
|
+
|
|
2871
|
+
nn.functional.cross_entropy = liger_cross_entropy
|
|
2872
|
+
|
|
2873
|
+
if fused_linear_cross_entropy:
|
|
2874
|
+
if model is not None:
|
|
2875
|
+
model.forward = MethodType(exaone4_lce_forward, model)
|
|
2876
|
+
else:
|
|
2877
|
+
modeling_exaone4.Exaone4ForCausalLM.forward = exaone4_lce_forward
|
|
2878
|
+
|
|
2879
|
+
if swiglu:
|
|
2880
|
+
modeling_exaone4.Exaone4MLP = LigerSwiGLUMLP
|
|
2881
|
+
|
|
2882
|
+
if model is not None:
|
|
2883
|
+
# The model instance already exists, so we need to additionally patch the
|
|
2884
|
+
# instance variables that reference already-instantiated modules
|
|
2885
|
+
|
|
2886
|
+
# get the base model from the model instance
|
|
2887
|
+
base_model: Exaone4Model = getattr(model, model.base_model_prefix, model)
|
|
2888
|
+
|
|
2889
|
+
if rms_norm:
|
|
2890
|
+
_patch_rms_norm_module(base_model.norm, in_place=False)
|
|
2891
|
+
for decoder_layer in base_model.layers:
|
|
2892
|
+
if swiglu:
|
|
2893
|
+
_bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
|
|
2894
|
+
if rms_norm:
|
|
2895
|
+
_patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
|
|
2896
|
+
_patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
|
|
2897
|
+
_patch_rms_norm_module(decoder_layer.self_attn.q_norm, in_place=False)
|
|
2898
|
+
_patch_rms_norm_module(decoder_layer.self_attn.k_norm, in_place=False)
|
|
2899
|
+
|
|
2900
|
+
|
|
2824
2901
|
# Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
|
|
2825
2902
|
MODEL_TYPE_TO_APPLY_LIGER_FN = {
|
|
2826
2903
|
"gemma": apply_liger_kernel_to_gemma,
|
|
@@ -2862,6 +2939,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
|
|
|
2862
2939
|
"smolvlm": apply_liger_kernel_to_smolvlm,
|
|
2863
2940
|
"hunyuan_v1_dense": apply_liger_kernel_to_hunyuan_v1_dense,
|
|
2864
2941
|
"hunyuan_v1_moe": apply_liger_kernel_to_hunyuan_v1_moe,
|
|
2942
|
+
"exaone4": apply_liger_kernel_to_exaone4,
|
|
2865
2943
|
}
|
|
2866
2944
|
|
|
2867
2945
|
|
|
@@ -57,11 +57,7 @@ class LigerTiledGEGLUMLP(nn.Module):
|
|
|
57
57
|
Returns:
|
|
58
58
|
Output tensor of the same shape as input
|
|
59
59
|
"""
|
|
60
|
-
compute_params = [
|
|
61
|
-
self.gate_proj.weight,
|
|
62
|
-
self.up_proj.weight,
|
|
63
|
-
self.down_proj.weight,
|
|
64
|
-
]
|
|
60
|
+
compute_params = [p for p in self.parameters() if p.requires_grad]
|
|
65
61
|
|
|
66
62
|
return apply_tiled_mlp(
|
|
67
63
|
fn=self._mlp_forward,
|
|
@@ -118,11 +114,7 @@ class LigerTiledSwiGLUMLP(nn.Module):
|
|
|
118
114
|
Returns:
|
|
119
115
|
Output tensor of the same shape as input
|
|
120
116
|
"""
|
|
121
|
-
compute_params = [
|
|
122
|
-
self.gate_proj.weight,
|
|
123
|
-
self.up_proj.weight,
|
|
124
|
-
self.down_proj.weight,
|
|
125
|
-
]
|
|
117
|
+
compute_params = [p for p in self.parameters() if p.requires_grad]
|
|
126
118
|
|
|
127
119
|
return apply_tiled_mlp(
|
|
128
120
|
fn=self._mlp_forward,
|
|
@@ -19,42 +19,43 @@ liger_kernel/chunked_loss/simpo_loss.py,sha256=fy2w8KbhMrBv7b1jdIeH3bBFxY52bPQPZ
|
|
|
19
19
|
liger_kernel/ops/__init__.py,sha256=F3m9qlXbgttykKEBsrMFf1WyK_0H8CKqLuDnFRR-cvc,7237
|
|
20
20
|
liger_kernel/ops/cross_entropy.py,sha256=DnXFRZ9TGN1SnEo8xGBFFPLNQaen8aLVNPJ1em-LbK4,22910
|
|
21
21
|
liger_kernel/ops/dyt.py,sha256=4XmkCCZaPPM8Tl4QHo6vSF2m68jrwsnjucrbyOJvZpM,5628
|
|
22
|
-
liger_kernel/ops/fused_add_rms_norm.py,sha256=
|
|
22
|
+
liger_kernel/ops/fused_add_rms_norm.py,sha256=E4SqFDw13ixd6S3DMhB1HlvtxAfuPL_DiHkgpk3exCI,14174
|
|
23
23
|
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=1gx2qljre9PVc861iknFnNCGC-P35D2w1cc_yMDO9ow,16239
|
|
24
24
|
liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHuncvSV5c,9683
|
|
25
25
|
liger_kernel/ops/fused_neighborhood_attention.py,sha256=vPi5xbnh6wxyZehaqo6Tuilqo2fN5SGDiONjnNmIKqs,35556
|
|
26
26
|
liger_kernel/ops/geglu.py,sha256=-ruMACDsFH1YsAak6BGvZ0ktLGIrBE6yGF0dAyR82UU,4307
|
|
27
|
-
liger_kernel/ops/group_norm.py,sha256=
|
|
27
|
+
liger_kernel/ops/group_norm.py,sha256=7BqYIP5-HQCdvHKMJlA6jCQoYKZjbtsoD9-eXld5qzk,11133
|
|
28
28
|
liger_kernel/ops/grpo_loss.py,sha256=2SyOujtF9I3xiNo4wFf4s6MeiDotE_qeYfRWgj_bOBE,9573
|
|
29
29
|
liger_kernel/ops/jsd.py,sha256=onHp5T3MbvJaVz5Vup7Ww6EQp_HTaZeayTjJk6FgQMY,7042
|
|
30
|
-
liger_kernel/ops/kl_div.py,sha256=
|
|
31
|
-
liger_kernel/ops/layer_norm.py,sha256
|
|
30
|
+
liger_kernel/ops/kl_div.py,sha256=MZZb7eAPMXlydYVV4uL9aTytXFkdQdp-jmiDw9tC0pg,8652
|
|
31
|
+
liger_kernel/ops/layer_norm.py,sha256=D1qPDn0HVHfyOmNHQyMDKv7f_JEnFsFxzHgfq9B4rI8,10696
|
|
32
32
|
liger_kernel/ops/llama4_rope.py,sha256=-aqdZzllklTN8b9--e-TsWY_ntGCN8-tyseT4x0bd8s,8223
|
|
33
33
|
liger_kernel/ops/multi_token_attention.py,sha256=Oz_RXDp-OSS_R_HuGmaETHdAJ7Toda_70OfE7TXMUlY,7645
|
|
34
|
-
liger_kernel/ops/poly_norm.py,sha256=
|
|
34
|
+
liger_kernel/ops/poly_norm.py,sha256=BBwdOtSzW02W-c-UAN8pzn2vAU-AM3gCsWqZnSE5zf4,11288
|
|
35
35
|
liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML5vBU,8481
|
|
36
|
-
liger_kernel/ops/rms_norm.py,sha256=
|
|
36
|
+
liger_kernel/ops/rms_norm.py,sha256=bd5ZAdiqh2iO7a7FdwWH7woslJEVyPlDKXSoUqDZ3GQ,21874
|
|
37
37
|
liger_kernel/ops/rope.py,sha256=v-7JHRrv-5ImoROkpKfl30WwWI4qTa2tAl7zQeB4ml4,8956
|
|
38
38
|
liger_kernel/ops/softmax.py,sha256=tgORx6MK1IDDtZKqGarj0IPIVjqAIEUXXYPiinhRdtI,5864
|
|
39
39
|
liger_kernel/ops/sparsemax.py,sha256=AeWe1xgkHJFEKWTj2vu_0hj7LztGvjqXAps-QTpCY0U,5087
|
|
40
40
|
liger_kernel/ops/swiglu.py,sha256=D7nd4u_LInwsIRNCDdY77lqnTz8-W5dJrpEAt8zEO_A,3033
|
|
41
41
|
liger_kernel/ops/tiled_mlp.py,sha256=eyMFsFFgHch8a_6R6IYRG24_jqKg5GF_BQUoQuAG8SY,4529
|
|
42
42
|
liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
|
|
43
|
-
liger_kernel/ops/utils.py,sha256=
|
|
43
|
+
liger_kernel/ops/utils.py,sha256=90V8P0ElZeBathDhmIKm_506Nhrsr1ojO0qRl53_Tn0,4909
|
|
44
44
|
liger_kernel/ops/backends/README.md,sha256=ZP59UUqD1WW8LwM5Y-cTpSM-Dtgdp8Wku2mE9kqAc2E,4185
|
|
45
45
|
liger_kernel/ops/backends/__init__.py,sha256=-mgef3cHfDFeL5NbXbq1TI7ngCahE9qqL3aMaHnXvis,629
|
|
46
46
|
liger_kernel/ops/backends/registry.py,sha256=yJa_Sh2FZ__iPCIU8h2nOQbnsFQh1I-_czROLtb1uQM,1637
|
|
47
47
|
liger_kernel/ops/backends/_ascend/__init__.py,sha256=6n0keOX9H-kLadBdVZlx-Ce0ZLVJvLiEfR-9-uxmYUk,221
|
|
48
48
|
liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md,sha256=FVXHSO1KY4ZFxCAE5r4hOYB2Q8ANyrJZ7WnFJ_GeQOA,19605
|
|
49
|
-
liger_kernel/ops/backends/_ascend/ub_manager.py,sha256=
|
|
50
|
-
liger_kernel/ops/backends/_ascend/ops/__init__.py,sha256=
|
|
51
|
-
liger_kernel/ops/backends/_ascend/ops/geglu.py,sha256=
|
|
49
|
+
liger_kernel/ops/backends/_ascend/ub_manager.py,sha256=3Utke2Dwx9huB0Qoch1KU2CXKN3JS5DbP9_JusIbfQU,13174
|
|
50
|
+
liger_kernel/ops/backends/_ascend/ops/__init__.py,sha256=N41VgPn8D_YJpHez1-UEYTtA-JZxpERmAzN7WcDfE2U,2067
|
|
51
|
+
liger_kernel/ops/backends/_ascend/ops/geglu.py,sha256=M3YFE44UREf91PtOvY0X_GZouUxeeDCy3GmXDrvRLQk,10131
|
|
52
52
|
liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py,sha256=pUYcstJ4FuzDTkuhmQaO3U9gcVQoNCpzuwwUdtES5hM,11015
|
|
53
53
|
liger_kernel/ops/backends/_ascend/ops/rope.py,sha256=nOwtm6_eSnzDjl2S-jvGpwHrumAOgWfr5pNg6SL3R2k,10842
|
|
54
54
|
liger_kernel/ops/backends/_ascend/ops/swiglu.py,sha256=yrbEgIgeCZyayMYHCRNq7LntZE9cEemht39_TFPro0k,4682
|
|
55
|
+
liger_kernel/ops/backends/_ascend/ops/tvd.py,sha256=4Q_DXSuVRqummX5dwFT5zOgQpdaWViLbMPjJ3kWy2IE,7745
|
|
55
56
|
liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
|
|
56
57
|
liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
|
|
57
|
-
liger_kernel/transformers/__init__.py,sha256=
|
|
58
|
+
liger_kernel/transformers/__init__.py,sha256=h7U1Vxrg5OoqOstBmZMd-0G0LROYleYt_fS-RpvEq84,11057
|
|
58
59
|
liger_kernel/transformers/auto_model.py,sha256=RnJhK8xHamRnnswgRLG_muJE1i6T6LszjK8lC6vonhE,2410
|
|
59
60
|
liger_kernel/transformers/cross_entropy.py,sha256=08H8RxSxGX_52UzrHNnSZ_wWH-uvU8KrRiDmVrkOw14,1996
|
|
60
61
|
liger_kernel/transformers/dyt.py,sha256=Rng-MZQSprnGGWFtpmYKt7MIX26vFUYbq5ruM4MjH-U,719
|
|
@@ -71,7 +72,7 @@ liger_kernel/transformers/jsd.py,sha256=_KlOX8YcdONU0tq0bIRDQ5VDBwtywm3Ro-FmlmI0
|
|
|
71
72
|
liger_kernel/transformers/kl_div.py,sha256=94VR4uuj-2dZCTEnwFksvDi-LporrpB5HgmYtQCZnw0,402
|
|
72
73
|
liger_kernel/transformers/layer_norm.py,sha256=l4nsT_Zj4CdVZOM7F0I0Ox-lmLHyIJzqQvVaF0o0HbI,895
|
|
73
74
|
liger_kernel/transformers/llama4_rope.py,sha256=A_nxcS_KiUCyNeL2FAZX7yUhDsX7krrI9BG49OaN_nM,3627
|
|
74
|
-
liger_kernel/transformers/monkey_patch.py,sha256=
|
|
75
|
+
liger_kernel/transformers/monkey_patch.py,sha256=hCFLKviPteLyDTUxjehiUS6k4hEx2GHDEualDhKpEYs,138949
|
|
75
76
|
liger_kernel/transformers/multi_token_attention.py,sha256=LtEjG7qy1-JK-HIPaz8zZ4P08aSZTnj5D635Pa04Onc,1730
|
|
76
77
|
liger_kernel/transformers/poly_norm.py,sha256=T3VdLQHLcCY7KzNzrc6IJRs8SzO8Yc7a0BS_2p6d7Wo,1367
|
|
77
78
|
liger_kernel/transformers/qwen2vl_mrope.py,sha256=0hOBR3j2Yd6xbT4z9BNRKEy1D0eyOUsIW6EmI_3PPNI,1033
|
|
@@ -80,16 +81,17 @@ liger_kernel/transformers/rope.py,sha256=-W9aYLa2hMOmmG5yeHcvPsOI5UTc95ylYxUddxk
|
|
|
80
81
|
liger_kernel/transformers/softmax.py,sha256=VI5QGHYpXSiXckgovEnDGcXwitimsxKB0GX-AT4dAC4,256
|
|
81
82
|
liger_kernel/transformers/sparsemax.py,sha256=Os49bSpPX4pWymsasv_3j20m8GFaI54e03XFPkHiPE0,393
|
|
82
83
|
liger_kernel/transformers/swiglu.py,sha256=LpgikAs9hibAL7G6itygBbOlW9tZe5s4D2IGAKGpbPw,4284
|
|
83
|
-
liger_kernel/transformers/tiled_mlp.py,sha256=
|
|
84
|
+
liger_kernel/transformers/tiled_mlp.py,sha256=_Go2bN8huL4I0EHBPXNfpIRaEukl8hiQEEJIwpJST20,4498
|
|
84
85
|
liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
|
|
85
86
|
liger_kernel/transformers/tvd.py,sha256=GYjhtXgS3RTPveOTN2gyK4uBnjs6ii2vkSZRX21QpqA,446
|
|
86
87
|
liger_kernel/transformers/experimental/__init__.py,sha256=oQqk-f32JYgWEP9DJCj6ty6bbJSGrdXsFDQFwGeX6vI,127
|
|
87
88
|
liger_kernel/transformers/experimental/embedding.py,sha256=bjy9hHj--ivy6xEWdiE6qLy9uLyeS4PsBEgl_MdDrng,858
|
|
88
89
|
liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
90
|
+
liger_kernel/transformers/model/exaone4.py,sha256=T5Ef2FnkJ-i8ktRWvBB5GXFOIyJmvMPyGsDFt5awpmE,5802
|
|
89
91
|
liger_kernel/transformers/model/falcon_h1.py,sha256=heUZ4wUt2ATmtBtmv8Rcro3pQl6fV9T0pburjTTW7os,5004
|
|
90
92
|
liger_kernel/transformers/model/gemma.py,sha256=pAri4PYpknsFfkvyo8Ez2NNlqrUDW-KkExUXTGZAcH4,10621
|
|
91
|
-
liger_kernel/transformers/model/gemma2.py,sha256=
|
|
92
|
-
liger_kernel/transformers/model/gemma3.py,sha256=
|
|
93
|
+
liger_kernel/transformers/model/gemma2.py,sha256=KgSpXVi04c8hVFa7dqJtjzVobz6z7BNTvGc1WjoV4nk,12006
|
|
94
|
+
liger_kernel/transformers/model/gemma3.py,sha256=2XPmtpZxR55wccKflIDqf2AwHJdxypUbd62fLuZ8two,15092
|
|
93
95
|
liger_kernel/transformers/model/glm4.py,sha256=bSp22iPIjsli4-c_usUOsyh1Bs2gIK8X6ynS0azseUs,5900
|
|
94
96
|
liger_kernel/transformers/model/glm4v.py,sha256=dd-BQpccDCp1SbIxcJ5rG8xcwYQK3KOv1Tgm9TGnZc4,6594
|
|
95
97
|
liger_kernel/transformers/model/glm4v_moe.py,sha256=zKhMdOOrRhlrvCSFaeVYfddL1ubpY8edEO91TN81n98,7135
|
|
@@ -99,7 +101,7 @@ liger_kernel/transformers/model/internvl.py,sha256=OOutracs9qrPHSU7FVYar08yinvGr
|
|
|
99
101
|
liger_kernel/transformers/model/llama.py,sha256=kqZeONzwTBzudoChlKMzq1w23BtYGbxWZC1l1V__JTw,13410
|
|
100
102
|
liger_kernel/transformers/model/llama4.py,sha256=PfkynGVI0xxMs3EtyYpCgaALI6stu25OIrTIymE-pvg,4853
|
|
101
103
|
liger_kernel/transformers/model/llava.py,sha256=yoADM_BuIEummtTDiwWqjfUjXUMZD78VJzS0TRj5GJ4,15687
|
|
102
|
-
liger_kernel/transformers/model/loss_utils.py,sha256=
|
|
104
|
+
liger_kernel/transformers/model/loss_utils.py,sha256=tNbC94Z4Ca2mlv3MRhnqfpJ7sBc5MZJtt1-mzMMJT1M,3088
|
|
103
105
|
liger_kernel/transformers/model/mistral.py,sha256=OcwOzVDMwwDbVccVPv-AaocznzWwzLT3aRaKK5SMaAg,6030
|
|
104
106
|
liger_kernel/transformers/model/mixtral.py,sha256=YcBDoTEJDgLFJ_RTo180DYGxR8D5Ad9-idumif7kCPE,12130
|
|
105
107
|
liger_kernel/transformers/model/mllama.py,sha256=vAHwCm63sn4kpAY0rDGf_N0HR7KRTBVpBYDVTPOaZTg,12079
|
|
@@ -122,9 +124,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
|
122
124
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
|
|
123
125
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
|
124
126
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
|
125
|
-
liger_kernel_nightly-0.6.4.
|
|
126
|
-
liger_kernel_nightly-0.6.4.
|
|
127
|
-
liger_kernel_nightly-0.6.4.
|
|
128
|
-
liger_kernel_nightly-0.6.4.
|
|
129
|
-
liger_kernel_nightly-0.6.4.
|
|
130
|
-
liger_kernel_nightly-0.6.4.
|
|
127
|
+
liger_kernel_nightly-0.6.4.dev20260116023519.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
|
128
|
+
liger_kernel_nightly-0.6.4.dev20260116023519.dist-info/METADATA,sha256=Ja1hknX3Qd5-8K5-BO7pX4Ln11BgPKgBrYBjf291kzU,25660
|
|
129
|
+
liger_kernel_nightly-0.6.4.dev20260116023519.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
|
130
|
+
liger_kernel_nightly-0.6.4.dev20260116023519.dist-info/WHEEL,sha256=WnJ8fYhv8N4SYVK2lLYNI6N0kVATA7b0piVUNvqIIJE,91
|
|
131
|
+
liger_kernel_nightly-0.6.4.dev20260116023519.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
|
132
|
+
liger_kernel_nightly-0.6.4.dev20260116023519.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|