liger-kernel-nightly 0.6.4.dev20260107111351__py3-none-any.whl → 0.6.4.dev20260116023519__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (24)
  1. liger_kernel/ops/backends/_ascend/ops/__init__.py +6 -0
  2. liger_kernel/ops/backends/_ascend/ops/geglu.py +34 -12
  3. liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
  4. liger_kernel/ops/backends/_ascend/ub_manager.py +1 -1
  5. liger_kernel/ops/fused_add_rms_norm.py +16 -22
  6. liger_kernel/ops/group_norm.py +10 -7
  7. liger_kernel/ops/kl_div.py +8 -11
  8. liger_kernel/ops/layer_norm.py +15 -15
  9. liger_kernel/ops/poly_norm.py +14 -20
  10. liger_kernel/ops/rms_norm.py +20 -24
  11. liger_kernel/ops/utils.py +11 -0
  12. liger_kernel/transformers/__init__.py +3 -0
  13. liger_kernel/transformers/model/exaone4.py +136 -0
  14. liger_kernel/transformers/model/gemma2.py +3 -3
  15. liger_kernel/transformers/model/gemma3.py +10 -5
  16. liger_kernel/transformers/model/loss_utils.py +6 -0
  17. liger_kernel/transformers/monkey_patch.py +78 -0
  18. liger_kernel/transformers/tiled_mlp.py +2 -10
  19. {liger_kernel_nightly-0.6.4.dev20260107111351.dist-info → liger_kernel_nightly-0.6.4.dev20260116023519.dist-info}/METADATA +1 -1
  20. {liger_kernel_nightly-0.6.4.dev20260107111351.dist-info → liger_kernel_nightly-0.6.4.dev20260116023519.dist-info}/RECORD +24 -22
  21. {liger_kernel_nightly-0.6.4.dev20260107111351.dist-info → liger_kernel_nightly-0.6.4.dev20260116023519.dist-info}/WHEEL +1 -1
  22. {liger_kernel_nightly-0.6.4.dev20260107111351.dist-info → liger_kernel_nightly-0.6.4.dev20260116023519.dist-info}/LICENSE +0 -0
  23. {liger_kernel_nightly-0.6.4.dev20260107111351.dist-info → liger_kernel_nightly-0.6.4.dev20260116023519.dist-info}/NOTICE +0 -0
  24. {liger_kernel_nightly-0.6.4.dev20260107111351.dist-info → liger_kernel_nightly-0.6.4.dev20260116023519.dist-info}/top_level.txt +0 -0
liger_kernel/ops/backends/_ascend/ops/__init__.py CHANGED
@@ -26,6 +26,9 @@ from liger_kernel.ops.backends._ascend.ops.rope import rope_forward
26
26
  from liger_kernel.ops.backends._ascend.ops.swiglu import LigerSiLUMulFunction
27
27
  from liger_kernel.ops.backends._ascend.ops.swiglu import swiglu_backward
28
28
  from liger_kernel.ops.backends._ascend.ops.swiglu import swiglu_forward
29
+ from liger_kernel.ops.backends._ascend.ops.tvd import LigerTVDLossFunction
30
+ from liger_kernel.ops.backends._ascend.ops.tvd import tv_distance_forward_triton
31
+ from liger_kernel.ops.backends._ascend.ops.tvd import tvd_backward_triton
29
32
 
30
33
  __all__ = [
31
34
  "LigerGELUMulFunction",
@@ -40,4 +43,7 @@ __all__ = [
40
43
  "LigerSiLUMulFunction",
41
44
  "swiglu_forward",
42
45
  "swiglu_backward",
46
+ "LigerTVDLossFunction",
47
+ "tv_distance_forward_triton",
48
+ "tvd_backward_triton",
43
49
  ]
liger_kernel/ops/backends/_ascend/ops/geglu.py CHANGED
@@ -130,20 +130,26 @@ def geglu_forward(a, b):
130
130
  dtype_size = a.element_size()
131
131
  # GEGLU forward tiling strategy:
132
132
  # - Calculates maximum safe block size based on UB capacity
133
- # - Memory analysis:
134
- # * Inputs: a, b
135
- # * Intermediates: a_cubed, tanh_arg, tanh_result, geglu_a
136
- # * Output: c
137
- # * Total: ~7x * BLOCK_SIZE * dtype_size
138
- # - Uses memory_multiplier=7.0 * BLOCK_SIZE * dtype_size * 8 bits for safety
133
+ # - Memory analysis (only buffers that occupy UB, excluding temporary variables):
134
+ # * Inputs: a_row (4 bytes, float32), b_row (dtype_size bytes)
135
+ # * Output: c_row (dtype_size bytes)
136
+ # * Temporary variables (a_cubed, tanh_arg, tanh_result, geglu_a) are optimized to registers
137
+ # and don't occupy UB since they are only used once
138
+ # * For float16: a_row(4) + b_row(2) + c_row(2) = 8 bytes/element, ratio = 8/2 = 4.0
139
+ # * For float32: a_row(4) + b_row(4) + c_row(4) = 12 bytes/element, ratio = 12/4 = 3.0
140
+ # - Uses memory_multiplier=4.0 (float16) or 3.0 (float32) * BLOCK_SIZE * dtype_size * 8 bits
139
141
  # - shapes: ((n_cols,),)
140
142
  # - tiling_dims: (0,) means first dimension can be tiled
141
143
  # - Returns: ((block_size,),)
142
144
  shapes = ((n_cols,),)
145
+ if dtype_size == 2:
146
+ memory_multiplier = 4.0
147
+ else:
148
+ memory_multiplier = 3.0
143
149
  tile_shapes = compute_default_tiling_strategy(
144
150
  safety_margin=0.80,
145
151
  dtype_size=dtype_size,
146
- memory_multiplier=7.0,
152
+ memory_multiplier=memory_multiplier,
147
153
  shapes=shapes,
148
154
  tiling_dims=(0,),
149
155
  )
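As a quick sanity check of the byte counts in the comment above (standalone Python, not part of the package):

```python
# Per-element UB footprint for the forward pass: a_row is upcast to float32,
# while b_row and c_row stay in the input dtype.
for dtype_size in (2, 4):  # 2 bytes = float16/bfloat16, 4 bytes = float32
    bytes_per_element = 4 + dtype_size + dtype_size  # a_row + b_row + c_row
    print(dtype_size, bytes_per_element, bytes_per_element / dtype_size)
    # -> 2, 8, 4.0 and 4, 12, 3.0 (the memory_multiplier values chosen above)
```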
@@ -187,18 +193,34 @@ def geglu_backward(a, b, dc):
187
193
  dtype_size = dc.element_size()
188
194
  # GEGLU backward tiling strategy:
189
195
  # - Calculates maximum safe block size based on UB capacity
190
- # - Memory analysis:
191
- # * More intermediates for gradient computation compared to forward
192
- # * Total: ~10x * BLOCK_SIZE * dtype_size
193
- # - Uses memory_multiplier=10.0 * BLOCK_SIZE * dtype_size * 8 bits for safety
196
+ # - Memory analysis: Peak memory usage occurs when executing line 103 (term1 calculation)
197
+ # At this point, the following buffers simultaneously occupy UB:
198
+ # 1. dc_row = tl.load(dc + col_offsets, ...) # dtype_size bytes
199
+ # 2. a_row = tl.load(a + col_offsets, ...).to(tl.float32) # 4 bytes (float32)
200
+ # 3. b_row = tl.load(b + col_offsets, ...) # dtype_size bytes
201
+ # 4. tanh_result = tanh(tanh_arg) # 4 bytes (float32), used in lines 95, 103, 104
202
+ # 5. geglu_a = 0.5 * a_row * (1 + tanh_result) # 4 bytes (float32), used in lines 96, 98
203
+ # 6. db_row = dc_row.cast(tl.float32) * geglu_a # 4 bytes (float32, computed at line 98, stored at line 109)
204
+ # Note: term1 (line 103) is a temporary variable optimized to registers and doesn't occupy UB
205
+ # Temporary variables (a_cubed, tanh_arg, term1, tanh_sq, term2) are optimized to registers
206
+ # and don't occupy UB since they are only used once
207
+ # * For float16: dc_row(2) + a_row(4) + b_row(2) + tanh_result(4) + geglu_a(4) + db_row(4)
208
+ # = 20 bytes/element, ratio = 20/2 = 10.0
209
+ # * For float32: dc_row(4) + a_row(4) + b_row(4) + tanh_result(4) + geglu_a(4) + db_row(4)
210
+ # = 24 bytes/element, ratio = 24/4 = 6.0
211
+ # - Uses memory_multiplier=10.0 (float16) or 6.0 (float32) * BLOCK_SIZE * dtype_size * 8 bits
194
212
  # - shapes: ((n_cols,),)
195
213
  # - tiling_dims: (0,) means first dimension can be tiled
196
214
  # - Returns: ((block_size,),)
197
215
  shapes = ((n_cols,),)
216
+ if dtype_size == 2:
217
+ memory_multiplier = 10.0
218
+ else:
219
+ memory_multiplier = 6.0
198
220
  tile_shapes = compute_default_tiling_strategy(
199
221
  safety_margin=0.80,
200
222
  dtype_size=dtype_size,
201
- memory_multiplier=10.0,
223
+ memory_multiplier=memory_multiplier,
202
224
  shapes=shapes,
203
225
  tiling_dims=(0,),
204
226
  )
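The same arithmetic applied to the backward buffers listed above yields the 10.0/6.0 multipliers (again just a standalone check, not package code):

```python
# dc_row and b_row stay in the input dtype; a_row, tanh_result, geglu_a and db_row are float32.
for dtype_size in (2, 4):
    bytes_per_element = dtype_size + 4 + dtype_size + 4 + 4 + 4
    print(dtype_size, bytes_per_element, bytes_per_element / dtype_size)
    # -> 2, 20, 10.0 and 4, 24, 6.0
```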
liger_kernel/ops/backends/_ascend/ops/tvd.py ADDED
@@ -0,0 +1,221 @@
1
+ from typing import Literal
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import triton
6
+ import triton.language as tl
7
+
8
+ from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
9
+ from liger_kernel.ops.utils import ensure_contiguous
10
+
11
+ MAX_FUSED_SIZE = 65536 // 4
12
+
13
+ REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
14
+
15
+
16
+ @triton.jit
17
+ def _tv_distance_kernel(
18
+ p_ptr,
19
+ p_stride,
20
+ q_ptr,
21
+ q_stride,
22
+ loss_ptr,
23
+ loss_stride,
24
+ grads_ptr,
25
+ grads_stride,
26
+ label_ptr,
27
+ ignore_index: tl.constexpr,
28
+ n_cols, # V
29
+ total_rows: tl.constexpr, # BT
30
+ BLOCK_SIZE: tl.constexpr,
31
+ HAS_LABEL: tl.constexpr,
32
+ reduction: tl.constexpr = "batchmean",
33
+ ):
34
+ thread_id = tl.program_id(0)
35
+ num_threads = tl.num_programs(0)
36
+
37
+ for pid in range(thread_id, total_rows, num_threads):
38
+ p_row_ptr = p_ptr + pid * p_stride
39
+ q_row_ptr = q_ptr + pid * q_stride
40
+ loss_row_ptr = loss_ptr + pid * loss_stride
41
+ grads_row_ptr = grads_ptr + pid * grads_stride
42
+ label_row_ptr = label_ptr + pid
43
+
44
+ base_offsets = tl.arange(0, BLOCK_SIZE)
45
+
46
+ should_skip = False
47
+ if HAS_LABEL:
48
+ label = tl.load(label_row_ptr)
49
+ if label == ignore_index:
50
+ should_skip = True
51
+
52
+ if should_skip:
53
+ for i in range(0, n_cols, BLOCK_SIZE):
54
+ offsets = i + base_offsets
55
+ mask = offsets < n_cols
56
+ tl.store(grads_row_ptr + offsets, 0.0, mask=mask)
57
+ if reduction == "none":
58
+ tl.store(loss_row_ptr + offsets, 0.0, mask=mask)
59
+ else:
60
+ loss_sum = 0.0
61
+ for i in range(0, n_cols, BLOCK_SIZE):
62
+ offsets = i + base_offsets
63
+ mask = offsets < n_cols
64
+
65
+ p = tl.load(p_row_ptr + offsets, mask=mask, other=0.0)
66
+ q = tl.load(q_row_ptr + offsets, mask=mask, other=0.0)
67
+
68
+ # TVD(P || Q) = 0.5 * |P - Q|
69
+ tv_loss = 0.5 * tl.abs(p - q)
70
+ grad_res = tl.where(p > q, 0.5, -0.5)
71
+
72
+ tl.store(grads_row_ptr + offsets, grad_res, mask=mask)
73
+
74
+ if reduction == "none":
75
+ tl.store(loss_row_ptr + offsets, tv_loss, mask=mask)
76
+ else:
77
+ loss_sum += tl.sum(tv_loss, axis=0)
78
+
79
+ if reduction != "none":
80
+ tl.store(loss_row_ptr, loss_sum)
81
+
82
+
83
+ def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
84
+ BT, V = p.shape
85
+
86
+ # TVD forward tiling strategy
87
+ # - In main loop (calculate loss and grad):
88
+ # * p: BLOCK_SIZE elements
89
+ # * q: BLOCK_SIZE elements
90
+ # * tv_loss: BLOCK_SIZE elements
91
+ # * grad_res: BLOCK_SIZE elements
92
+ # * loss_sum: scalar accumulator (when reduction != "none")
93
+ # * Total: roughly 4-5 * BLOCK_SIZE elements depending on the reduction
94
+ # - Since loss_sum is not used in every configuration, and to account for other shared-memory
95
+ # consumption plus the potential memory use of the HAS_LABEL loop,
96
+ # a conservative estimate of 5 * BLOCK_SIZE * dtype_size * 8 bits is used.
97
+ # - For safety, use: memory_multiplier=5.0 * BLOCK_SIZE * dtype_size * 8 bits
98
+ # - shapes: ((V,),)
99
+ # - tiling_dims: (0,) means the first dimension of each shape can be tiled
100
+ # - Returns: ((block_size,),)
101
+ shapes = ((V,),)
102
+ tile_shapes = compute_default_tiling_strategy(
103
+ safety_margin=0.80,
104
+ # In the TVD calculation, most values are implicitly converted to fp32, so the fp32 element size can be used directly.
105
+ dtype_size=4,
106
+ memory_multiplier=5.0,
107
+ shapes=shapes,
108
+ tiling_dims=(0,),
109
+ )
110
+
111
+ if tile_shapes is not None and len(tile_shapes) > 0 and len(tile_shapes[0]) > 0:
112
+ # Strategy returns ((block_size,),)
113
+ BLOCK_SIZE = tile_shapes[0][0]
114
+ else:
115
+ # Fallback to desired block size if no best practice found (no tiling needed)
116
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
117
+
118
+ MAX_BATCH_PER_KERNEL = 65535 # Maximum number of rows a single kernel launch can process on the NPU
119
+ if BT <= MAX_BATCH_PER_KERNEL:
120
+ grid = (BT,)
121
+ else:
122
+ grid = (MAX_BATCH_PER_KERNEL,)
123
+
124
+ out_size = (BT, V) if reduction == "none" else (BT,)
125
+ output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32)
126
+ grads = torch.empty_like(p)
127
+
128
+ n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT
129
+
130
+ _tv_distance_kernel[grid](
131
+ p,
132
+ p.stride(0),
133
+ q,
134
+ q.stride(0),
135
+ output_tensor,
136
+ output_tensor.stride(0),
137
+ grads,
138
+ grads.stride(0),
139
+ shift_labels if has_label else torch.empty(1, device=p.device),
140
+ ignore_index,
141
+ V,
142
+ BT,
143
+ BLOCK_SIZE=BLOCK_SIZE,
144
+ HAS_LABEL=has_label,
145
+ reduction=reduction,
146
+ )
147
+
148
+ if reduction == "batchmean":
149
+ return output_tensor.sum() / n_non_ignore, grads / n_non_ignore
150
+ elif reduction == "sum":
151
+ return output_tensor.sum(dim=0), grads
152
+ elif reduction == "mean":
153
+ return output_tensor.sum() / (n_non_ignore * V), grads / (n_non_ignore * V)
154
+ else:
155
+ return output_tensor, grads
156
+
157
+
158
+ def tvd_backward_triton(grad_output, grads):
159
+ # If this is the last layer, grad_output is 1.0. Skip the mul then.
160
+ if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
161
+ return grads
162
+
163
+ return grads * grad_output
164
+
165
+
166
+ class LigerTVDLossFunction(torch.autograd.Function):
167
+ """
168
+ Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton.
169
+ """
170
+
171
+ @staticmethod
172
+ @ensure_contiguous
173
+ def forward(
174
+ ctx,
175
+ p: torch.Tensor,
176
+ q: torch.Tensor,
177
+ shift_labels: Optional[torch.Tensor] = None,
178
+ reduction: REDUCTION_LITERAL = "batchmean",
179
+ ignore_index: int = -100,
180
+ ) -> torch.Tensor:
181
+ """A forward pass for the Total Variation Distance Loss.
182
+
183
+ Args:
184
+ ctx: Torch autograd context
185
+ p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution.
186
+ q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution.
187
+ shift_labels (Optional[torch.Tensor]): A tensor of shape (BT,) containing the labels.
188
+ reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean".
189
+ ignore_index (int, optional): The index to ignore during loss calculation. Defaults to -100.
190
+
191
+ Returns:
192
+ torch.Tensor: The computed Total Variation Distance Loss.
193
+ """
194
+ has_label = False
195
+ if shift_labels is not None:
196
+ assert shift_labels.shape == (p.shape[0],), (
197
+ f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
198
+ )
199
+ shift_labels = shift_labels.contiguous()
200
+ has_label = True
201
+
202
+ loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label)
203
+ ctx.save_for_backward(grads)
204
+ return loss
205
+
206
+ @staticmethod
207
+ @ensure_contiguous
208
+ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
209
+ """A backward pass for the Total Variation Distance Loss.
210
+
211
+ Args:
212
+ ctx: Torch autograd context
213
+ grad_output (torch.Tensor): The gradient of the loss with respect to the output.
214
+
215
+ Returns:
216
+ tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs.
217
+ """
218
+ (grads,) = ctx.saved_tensors
219
+ grads = tvd_backward_triton(grad_output, grads)
220
+
221
+ return grads, None, None, None, None
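For context, a minimal usage sketch of the new autograd function, assuming an Ascend NPU environment where this backend is selected (tensor shapes and the `npu` device string are illustrative):

```python
import torch

from liger_kernel.ops.backends._ascend.ops.tvd import LigerTVDLossFunction

device = "npu"  # assumes torch_npu / the Ascend Triton backend is available
p = torch.softmax(torch.randn(8, 32, device=device), dim=-1).requires_grad_(True)
q = torch.softmax(torch.randn(8, 32, device=device), dim=-1)

# Default reduction is "batchmean"; per-element gradients w.r.t. p are cached during forward.
loss = LigerTVDLossFunction.apply(p, q)
loss.backward()
print(loss.item(), p.grad.shape)  # scalar loss, gradient of shape (8, 32)
```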
liger_kernel/ops/backends/_ascend/ub_manager.py CHANGED
@@ -241,7 +241,7 @@ def compute_default_tiling_strategy(
241
241
  dtype_size: Size of data type in bytes (e.g., 2 for float16, 4 for float32).
242
242
  Must be provided. If None or <= 0, defaults to 4 (float32).
243
243
  memory_multiplier: Memory multiplier for estimating peak memory usage.
244
- - For GEGLU: typically 10.0 for backward, 7.0 for forward
244
+ - For GEGLU: typically 10.0 for backward, 4.0 for forward
245
245
  - For ROPE: typically 3.0
246
246
  If None, defaults to 10.0 (conservative estimate).
247
247
  shapes: Tuple of full shapes. Each shape is a tuple of dimension sizes.
liger_kernel/ops/fused_add_rms_norm.py CHANGED
@@ -8,6 +8,7 @@ import triton.language as tl
8
8
  from liger_kernel.ops.utils import calculate_settings
9
9
  from liger_kernel.ops.utils import compare_version
10
10
  from liger_kernel.ops.utils import ensure_contiguous
11
+ from liger_kernel.ops.utils import set_large_grf_mode
11
12
  from liger_kernel.ops.utils import torch_to_triton_dtype
12
13
  from liger_kernel.utils import get_npu_multi_processor_count
13
14
  from liger_kernel.utils import is_npu_available
@@ -162,23 +163,21 @@ def _fused_add_rms_norm_backward_kernel(
162
163
 
163
164
  dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
164
165
 
165
- dY_ptr += row_start * dY_row_stride
166
- dX_ptr += row_start * dX_row_stride
167
- if has_dS_out:
168
- dS_out_ptr += row_start * dS_out_row_stride
169
-
170
- X_ptr += row_start * X_row_stride
171
- RSTD_ptr += row_start
172
-
173
166
  W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
174
167
  W_row = W_row + offset
175
168
 
176
- for _ in range(row_start, row_end):
177
- dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
178
- X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
169
+ for row_idx in range(row_start, row_end):
170
+ dy_base = dY_ptr + row_idx * dY_row_stride
171
+ dx_base = dX_ptr + row_idx * dX_row_stride
172
+
173
+ x_base = X_ptr + row_idx * X_row_stride
174
+ rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
175
+
176
+ dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0)
177
+ X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0)
179
178
 
180
179
  # Get cached rms
181
- rstd_row = tl.load(RSTD_ptr)
180
+ rstd_row = tl.load(rstd_base)
182
181
 
183
182
  X_row = X_row.to(tl.float32)
184
183
 
@@ -195,11 +194,11 @@ def _fused_add_rms_norm_backward_kernel(
195
194
  dX_row = rstd_row * m
196
195
 
197
196
  if has_dS_out:
198
- dS_out_row = tl.load(dS_out_ptr + col_offsets, mask=mask, other=0.0)
197
+ ds_base = dS_out_ptr + row_idx * dS_out_row_stride
198
+ dS_out_row = tl.load(ds_base + col_offsets, mask=mask, other=0.0)
199
199
  dX_row += (rstd_row) * (
200
200
  -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
201
201
  ) + dS_out_row
202
- dS_out_ptr += dS_out_row_stride
203
202
  else:
204
203
  dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
205
204
 
@@ -210,12 +209,7 @@ def _fused_add_rms_norm_backward_kernel(
210
209
  # here X_row is already in fp32 (see previous if block)
211
210
  dW_row += dY_row * (X_row * rstd_row)
212
211
 
213
- tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
214
-
215
- dY_ptr += dY_row_stride
216
- dX_ptr += dX_row_stride
217
- X_ptr += X_row_stride
218
- RSTD_ptr += RSTD_row_stride
212
+ tl.store(dx_base + col_offsets, dX_row.to(X_dtype), mask=mask)
219
213
 
220
214
  tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
221
215
 
@@ -254,7 +248,7 @@ def fused_add_rms_norm_forward(X, R, W, eps, offset, casting_mode):
254
248
  # XPU-specific optimization
255
249
  kernel_args = {}
256
250
  if X.device.type == "xpu":
257
- kernel_args["grf_mode"] = "large"
251
+ set_large_grf_mode(kernel_args)
258
252
 
259
253
  # TODO: add _block_fused_add_rms_norm_forward_kernel
260
254
  _fused_add_rms_norm_forward_kernel[(n_rows,)](
@@ -314,7 +308,7 @@ def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, BL
314
308
  # XPU-specific optimization
315
309
  kernel_args = {}
316
310
  if S.device.type == "xpu":
317
- kernel_args["grf_mode"] = "large"
311
+ set_large_grf_mode(kernel_args)
318
312
 
319
313
  # TODO: add _block_fused_add_rms_norm_backward_kernel
320
314
  _fused_add_rms_norm_backward_kernel[grid](
liger_kernel/ops/group_norm.py CHANGED
@@ -6,6 +6,7 @@ import triton.language as tl
6
6
 
7
7
  from liger_kernel.ops.utils import compare_version
8
8
  from liger_kernel.ops.utils import ensure_contiguous
9
+ from liger_kernel.utils import infer_device
9
10
  from liger_kernel.utils import is_npu_available
10
11
 
11
12
  if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
@@ -18,7 +19,10 @@ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
18
19
  else:
19
20
  from triton.language.math import rsqrt
20
21
 
21
- MAX_FUSED_SIZE = 65536
22
+ if infer_device() == "npu":
23
+ MAX_FUSED_SIZE = 16384 # 8192
24
+ else:
25
+ MAX_FUSED_SIZE = 65536
22
26
 
23
27
 
24
28
  @triton.jit
@@ -78,15 +82,14 @@ def _group_norm_forward_kernel(
78
82
  for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
79
83
  W = tl.load(W_ptr + channel_idx)
80
84
  B = tl.load(B_ptr + channel_idx)
81
- for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
85
+ # Calculate channel offset within the group
86
+ channel_offset = (channel_idx - group_idx * channels_per_group) * hidden_size_per_channel
87
+ for i in tl.range(0, hidden_size_per_channel, BLOCK_SIZE):
82
88
  hidden_size_offsets = i + block_range
83
89
  mask = hidden_size_offsets < hidden_size_per_channel
84
- X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
90
+ X = tl.load(X_ptr + channel_offset + hidden_size_offsets, mask=mask, other=m)
85
91
  Y = (X - m) * rstd * W + B
86
- tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)
87
-
88
- X_ptr += hidden_size_per_channel
89
- Y_ptr += hidden_size_per_channel
92
+ tl.store(Y_ptr + channel_offset + hidden_size_offsets, Y, mask=mask)
90
93
 
91
94
  tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
92
95
  tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
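The fix above replaces the in-place pointer bumps with an explicit per-channel offset; a small standalone check of that offset arithmetic (the values are made up):

```python
# For group_idx = 1 with 4 channels per group and 128 elements per channel,
# channels 4..7 map to offsets 0, 128, 256, 384 within the group's block.
channels_per_group, hidden_size_per_channel = 4, 128
group_idx = 1
for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
    channel_offset = (channel_idx - group_idx * channels_per_group) * hidden_size_per_channel
    print(channel_idx, channel_offset)
```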
liger_kernel/ops/kl_div.py CHANGED
@@ -21,7 +21,12 @@ def get_num_warps(BLOCK_SIZE):
21
21
  return num_warps
22
22
 
23
23
 
24
- MAX_FUSED_SIZE = 65536 // 4 # 65536 // 4 or 8 works the best
24
+ if infer_device() == "xpu":
25
+ MAX_FUSED_SIZE = 8192
26
+ elif infer_device() == "npu":
27
+ MAX_FUSED_SIZE = 8192
28
+ else:
29
+ MAX_FUSED_SIZE = 65536 // 4 # 65536 // 4 or 8 works the best
25
30
 
26
31
  REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
27
32
 
@@ -116,11 +121,7 @@ def _kldiv_kernel_backward(
116
121
 
117
122
  def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
118
123
  BT, V = y_pred.shape
119
- BLOCK_SIZE = (
120
- min(8192, triton.next_power_of_2(V))
121
- if infer_device() == "xpu"
122
- else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
123
- )
124
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
124
125
  num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
125
126
 
126
127
  grid = (BT,)
@@ -159,11 +160,7 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
159
160
 
160
161
  def kldiv_backward_triton(target, grad_output, new_grads, log_target):
161
162
  BT, V = target.shape
162
- BLOCK_SIZE = (
163
- min(8192, triton.next_power_of_2(V))
164
- if infer_device() == "xpu"
165
- else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
166
- )
163
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
167
164
  num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
168
165
 
169
166
  grid = (BT,)
liger_kernel/ops/layer_norm.py CHANGED
@@ -8,6 +8,8 @@ import triton.language as tl
8
8
  from liger_kernel.ops.utils import calculate_settings
9
9
  from liger_kernel.ops.utils import compare_version
10
10
  from liger_kernel.ops.utils import ensure_contiguous
11
+ from liger_kernel.ops.utils import set_large_grf_mode
12
+ from liger_kernel.utils import get_npu_multi_processor_count
11
13
  from liger_kernel.utils import is_npu_available
12
14
 
13
15
  if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
@@ -124,14 +126,14 @@ def _layer_norm_backward_kernel(
124
126
  w = tl.load(W_ptr + cols, mask=mask, other=0.0)
125
127
  w_f32 = w.to(tl.float32)
126
128
 
127
- # Calculate pointers for this specific row
128
- row_X_ptr = X_ptr + row_start * stride_x
129
- row_DX_ptr = DX_ptr + row_start * stride_dx
130
- row_DY_ptr = DY_ptr + row_start * stride_dy
131
- row_Mean_ptr = Mean_ptr + row_start
132
- row_RSTD_ptr = RSTD_ptr + row_start
129
+ for row_idx in range(row_start, row_end):
130
+ # Calculate pointers for this specific row
131
+ row_X_ptr = X_ptr + row_idx * stride_x
132
+ row_DX_ptr = DX_ptr + row_idx * stride_dx
133
+ row_DY_ptr = DY_ptr + row_idx * stride_dy
134
+ row_Mean_ptr = Mean_ptr + row_idx * stride_mean
135
+ row_RSTD_ptr = RSTD_ptr + row_idx * stride_rstd
133
136
 
134
- for _ in range(row_start, row_end):
135
137
  # Load data for this row
136
138
  x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
137
139
  dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
@@ -160,12 +162,6 @@ def _layer_norm_backward_kernel(
160
162
  dW_row += dw
161
163
  db_row += db
162
164
 
163
- row_X_ptr += stride_x
164
- row_DX_ptr += stride_dx
165
- row_DY_ptr += stride_dy
166
- row_Mean_ptr += stride_mean
167
- row_RSTD_ptr += stride_rstd
168
-
169
165
  tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
170
166
  tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)
171
167
 
@@ -204,7 +200,7 @@ def layer_norm_forward(X, W, B, eps):
204
200
  # XPU-specific optimization
205
201
  kernel_args = {}
206
202
  if X.device.type == "xpu":
207
- kernel_args["grf_mode"] = "large"
203
+ set_large_grf_mode(kernel_args)
208
204
 
209
205
  # Launch kernel with one thread block per row for optimal performance
210
206
  grid = (n_rows,)
@@ -254,6 +250,8 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
254
250
  sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
255
251
  elif X.device.type == "xpu":
256
252
  sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
253
+ elif X.device.type == "npu":
254
+ sm_count = get_npu_multi_processor_count()
257
255
 
258
256
  # fp32 for numerical stability especially.
259
257
  _DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
@@ -272,7 +270,8 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
272
270
  kernel_args = {"num_warps": num_warps}
273
271
  # XPU-specific optimization
274
272
  if X.device.type == "xpu":
275
- kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
273
+ kernel_args.update({"num_warps": 32, "num_stages": 4})
274
+ set_large_grf_mode(kernel_args)
276
275
 
277
276
  # Launch kernel with one thread block per row for optimal performance
278
277
  _layer_norm_backward_kernel[grid](
@@ -301,6 +300,7 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
301
300
  DX = DX.view(*shape)
302
301
  DW = _DW.sum(dim=0).to(W.dtype)
303
302
  DB = _DB.sum(dim=0).to(B.dtype)
303
+
304
304
  return DX, DW, DB
305
305
 
306
306
 
liger_kernel/ops/poly_norm.py CHANGED
@@ -7,6 +7,7 @@ import triton.language as tl
7
7
  from liger_kernel.ops.utils import calculate_settings
8
8
  from liger_kernel.ops.utils import compare_version
9
9
  from liger_kernel.ops.utils import ensure_contiguous
10
+ from liger_kernel.ops.utils import set_large_grf_mode
10
11
  from liger_kernel.utils import get_npu_multi_processor_count
11
12
  from liger_kernel.utils import is_npu_available
12
13
 
@@ -140,20 +141,19 @@ def _poly_norm_backward_kernel(
140
141
  w1 = tl.load(W_ptr + 1).to(tl.float32)
141
142
  w2 = tl.load(W_ptr + 2).to(tl.float32)
142
143
 
143
- dY_ptr += row_start * dY_row_stride
144
- dX_ptr += row_start * dX_row_stride
145
- X_ptr += row_start * X_row_stride
146
- RSTD_ptr += row_start * RSTD_row_stride
144
+ for row_idx in range(row_start, row_end):
145
+ dy_base = dY_ptr + row_idx * dY_row_stride
146
+ x_base = X_ptr + row_idx * X_row_stride
147
+ dx_base = dX_ptr + row_idx * dX_row_stride
148
+ rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
147
149
 
148
- for _ in range(row_start, row_end):
149
- # Load input and gradient
150
- dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
151
- X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
150
+ dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0).to(tl.float32)
151
+ X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0).to(tl.float32)
152
152
 
153
153
  # Load cached rstd values
154
- rstd_3 = tl.load(RSTD_ptr + 0).to(tl.float32)
155
- rstd_2 = tl.load(RSTD_ptr + 1).to(tl.float32)
156
- rstd_1 = tl.load(RSTD_ptr + 2).to(tl.float32)
154
+ rstd_3 = tl.load(rstd_base + 0).to(tl.float32)
155
+ rstd_2 = tl.load(rstd_base + 1).to(tl.float32)
156
+ rstd_1 = tl.load(rstd_base + 2).to(tl.float32)
157
157
 
158
158
  # Compute powers
159
159
  X_pow3 = X_row * X_row * X_row
@@ -190,13 +190,7 @@ def _poly_norm_backward_kernel(
190
190
  dX_row = grad_x_3 + grad_x_2 + grad_x_1
191
191
 
192
192
  # Store gradient
193
- tl.store(dX_ptr + col_offsets, dX_row, mask=mask)
194
-
195
- # Update pointers
196
- dY_ptr += dY_row_stride
197
- dX_ptr += dX_row_stride
198
- X_ptr += X_row_stride
199
- RSTD_ptr += RSTD_row_stride
193
+ tl.store(dx_base + col_offsets, dX_row, mask=mask)
200
194
 
201
195
  # Store accumulated gradients (scalars)
202
196
  tl.store(dW_ptr + row_block_id * dW_row_stride + 0, dW0_acc)
@@ -239,7 +233,7 @@ def poly_norm_forward(X, W, B, eps=1e-6):
239
233
  # XPU-specific optimization
240
234
  kernel_args = {}
241
235
  if X.device.type == "xpu":
242
- kernel_args["grf_mode"] = "large"
236
+ set_large_grf_mode(kernel_args)
243
237
 
244
238
  # Launch kernel
245
239
  _poly_norm_forward_kernel[(n_rows,)](
@@ -310,7 +304,7 @@ def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
310
304
  # XPU-specific optimization
311
305
  kernel_args = {}
312
306
  if X.device.type == "xpu":
313
- kernel_args["grf_mode"] = "large"
307
+ set_large_grf_mode(kernel_args)
314
308
 
315
309
  # Launch backward kernel
316
310
  _poly_norm_backward_kernel[grid](
liger_kernel/ops/rms_norm.py CHANGED
@@ -20,6 +20,7 @@ import triton.language as tl
20
20
  from liger_kernel.ops.utils import calculate_settings
21
21
  from liger_kernel.ops.utils import compare_version
22
22
  from liger_kernel.ops.utils import ensure_contiguous
23
+ from liger_kernel.ops.utils import set_large_grf_mode
23
24
  from liger_kernel.ops.utils import torch_to_triton_dtype
24
25
  from liger_kernel.utils import get_npu_multi_processor_count
25
26
  from liger_kernel.utils import is_npu_available
@@ -70,11 +71,11 @@ def _rms_norm_forward_kernel(
70
71
  col_offsets = tl.arange(0, BLOCK_SIZE)
71
72
  mask = col_offsets < n_cols
72
73
 
73
- Y_ptr += row_idx * Y_row_stride
74
- X_ptr += row_idx * X_row_stride
75
- RSTD_ptr += row_idx * RSTD_row_stride
74
+ y_base = Y_ptr + row_idx * Y_row_stride
75
+ x_base = X_ptr + row_idx * X_row_stride
76
+ rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
76
77
 
77
- X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
78
+ X_row = tl.load(x_base + col_offsets, mask=mask, other=0)
78
79
  X_row_dtype = X_row.dtype
79
80
  if elementwise_affine:
80
81
  W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
@@ -99,7 +100,7 @@ def _rms_norm_forward_kernel(
99
100
  # We can save time by caching rms with minimal memory overhead
100
101
  # because rms is much smaller compared to X_row, as rms is for each row.
101
102
  # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
102
- tl.store(RSTD_ptr, rstd)
103
+ tl.store(rstd_base, rstd)
103
104
 
104
105
  X_row = X_row * rstd
105
106
 
@@ -115,7 +116,7 @@ def _rms_norm_forward_kernel(
115
116
  if casting_mode == _CASTING_MODE_GEMMA:
116
117
  Y_row = Y_row.to(X_row_dtype)
117
118
 
118
- tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
119
+ tl.store(y_base + col_offsets, Y_row, mask=mask)
119
120
 
120
121
 
121
122
  @triton.jit
@@ -155,22 +156,22 @@ def _rms_norm_backward_kernel(
155
156
  if elementwise_affine:
156
157
  dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
157
158
 
158
- dY_ptr += row_start * dY_row_stride
159
- dX_ptr += row_start * dX_row_stride
160
-
161
- X_ptr += row_start * X_row_stride
162
- RSTD_ptr += row_start
163
-
164
159
  if elementwise_affine:
165
160
  W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
166
161
  W_row = W_row + offset
167
162
 
168
- for _ in range(row_start, row_end):
169
- dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
170
- X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
163
+ for row_idx in range(row_start, row_end):
164
+ dy_base = dY_ptr + row_idx * dY_row_stride
165
+ dx_base = dX_ptr + row_idx * dX_row_stride
166
+
167
+ x_base = X_ptr + row_idx * X_row_stride
168
+ rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
169
+
170
+ dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0)
171
+ X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0)
171
172
 
172
173
  # Get cached rms
173
- rstd_row = tl.load(RSTD_ptr)
174
+ rstd_row = tl.load(rstd_base)
174
175
 
175
176
  X_row = X_row.to(tl.float32)
176
177
 
@@ -205,12 +206,7 @@ def _rms_norm_backward_kernel(
205
206
  # here X_row is already in fp32 (see previous if block)
206
207
  dW_row += dY_row * (X_row * rstd_row)
207
208
 
208
- tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
209
-
210
- dY_ptr += dY_row_stride
211
- dX_ptr += dX_row_stride
212
- X_ptr += X_row_stride
213
- RSTD_ptr += RSTD_row_stride
209
+ tl.store(dx_base + col_offsets, dX_row.to(X_dtype), mask=mask)
214
210
 
215
211
  if elementwise_affine:
216
212
  tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
@@ -441,7 +437,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
441
437
  # XPU-specific optimization
442
438
  kernel_args = {}
443
439
  if X.device.type == "xpu":
444
- kernel_args["grf_mode"] = "large"
440
+ set_large_grf_mode(kernel_args)
445
441
  if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
446
442
  _rms_norm_forward_kernel[(n_rows,)](
447
443
  Y,
@@ -521,7 +517,7 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
521
517
  # XPU-specific optimization
522
518
  kernel_args = {}
523
519
  if X.device.type == "xpu":
524
- kernel_args["grf_mode"] = "large"
520
+ set_large_grf_mode(kernel_args)
525
521
 
526
522
  if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
527
523
  _rms_norm_backward_kernel[grid](
liger_kernel/ops/utils.py CHANGED
@@ -139,3 +139,14 @@ def get_npu_core_count(default: int = 20) -> int:
139
139
  return int(props.get("num_vectorcore", default))
140
140
  except Exception:
141
141
  return default
142
+
143
+
144
+ def set_large_grf_mode(kernel_args: dict):
145
+ """Set large GRF mode for XPU devices."""
146
+ # On XPU triton installed along with pytorch-xpu will be called `pytorch-triton-xpu`,
147
+ # triton XPU installed from source will be called `triton`.
148
+ if compare_version("pytorch-triton-xpu", operator.ge, "3.6.0") or compare_version("triton", operator.ge, "3.6.0"):
149
+ kernel_args["grf_mode"] = "256"
150
+ else:
151
+ # API was changed in https://github.com/intel/intel-xpu-backend-for-triton/pull/5430
152
+ kernel_args["grf_mode"] = "large"
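A minimal sketch of how the new helper is consumed by the launch sites patched above, assuming an XPU build of Triton (on other devices the call is skipped entirely):

```python
from liger_kernel.ops.utils import set_large_grf_mode

kernel_args = {"num_warps": 8}
set_large_grf_mode(kernel_args)
# With pytorch-triton-xpu / triton >= 3.6.0 this adds {"grf_mode": "256"};
# older XPU backends still get {"grf_mode": "large"}.
print(kernel_args)
# The surrounding code then forwards the dict to the Triton launch as extra keyword arguments.
```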
liger_kernel/transformers/__init__.py CHANGED
@@ -33,6 +33,7 @@ if TYPE_CHECKING:
33
33
  from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM # noqa: F401
34
34
  from liger_kernel.transformers.monkey_patch import _apply_liger_kernel # noqa: F401
35
35
  from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance # noqa: F401
36
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_exaone4 # noqa: F401
36
37
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_falcon_h1 # noqa: F401
37
38
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma # noqa: F401
38
39
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401
@@ -136,6 +137,7 @@ def __getattr__(name: str):
136
137
  "apply_liger_kernel_to_smolvlm",
137
138
  "apply_liger_kernel_to_hunyuan_v1_dense",
138
139
  "apply_liger_kernel_to_hunyuan_v1_moe",
140
+ "apply_liger_kernel_to_exaone4",
139
141
  }
140
142
 
141
143
  if name in monkey_patch_symbols:
@@ -214,5 +216,6 @@ if _TRANSFORMERS_AVAILABLE:
214
216
  "apply_liger_kernel_to_smolvlm",
215
217
  "apply_liger_kernel_to_hunyuan_v1_dense",
216
218
  "apply_liger_kernel_to_hunyuan_v1_moe",
219
+ "apply_liger_kernel_to_exaone4",
217
220
  ]
218
221
  )
liger_kernel/transformers/model/exaone4.py ADDED
@@ -0,0 +1,136 @@
1
+ from typing import List
2
+ from typing import Optional
3
+ from typing import Union
4
+
5
+ import torch
6
+
7
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
8
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
9
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
10
+
11
+
12
+ def lce_forward(
13
+ self,
14
+ input_ids: Optional[torch.LongTensor] = None,
15
+ attention_mask: Optional[torch.Tensor] = None,
16
+ position_ids: Optional[torch.LongTensor] = None,
17
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
18
+ inputs_embeds: Optional[torch.FloatTensor] = None,
19
+ labels: Optional[torch.LongTensor] = None,
20
+ use_cache: Optional[bool] = None,
21
+ output_attentions: Optional[bool] = None,
22
+ output_hidden_states: Optional[bool] = None,
23
+ cache_position: Optional[torch.LongTensor] = None,
24
+ logits_to_keep: Union[int, torch.Tensor] = 0,
25
+ skip_logits: Optional[bool] = None,
26
+ return_dict: Optional[bool] = None,
27
+ **kwargs,
28
+ ) -> LigerCausalLMOutputWithPast:
29
+ r"""
30
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
31
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
32
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
33
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
34
+
35
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
36
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
37
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
38
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
39
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
40
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
41
+
42
+ Returns:
43
+
44
+ Example:
45
+
46
+ ````python
47
+ >>> from transformers import AutoTokenizer, Exaone4ForCausalLM
48
+
49
+ >>> model = Exaone4ForCausalLM.from_pretrained("LGAI-EXAONE/EXAONE-4.0-1.2B")
50
+ >>> tokenizer = AutoTokenizer.from_pretrained("LGAI-EXAONE/EXAONE-4.0-1.2B")
51
+
52
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
53
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
54
+
55
+ >>> # Generate
56
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
57
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
58
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
59
+ ```"""
60
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
61
+ output_hidden_states = (
62
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
63
+ )
64
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
65
+
66
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
67
+ outputs = self.model(
68
+ input_ids=input_ids,
69
+ attention_mask=attention_mask,
70
+ position_ids=position_ids,
71
+ past_key_values=past_key_values,
72
+ inputs_embeds=inputs_embeds,
73
+ use_cache=use_cache,
74
+ output_attentions=output_attentions,
75
+ output_hidden_states=output_hidden_states,
76
+ cache_position=cache_position,
77
+ **kwargs,
78
+ )
79
+
80
+ hidden_states = outputs[0]
81
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
82
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
83
+ kept_hidden_states = hidden_states[:, slice_indices, :]
84
+
85
+ shift_labels = kwargs.pop("shift_labels", None)
86
+ # Remove output-control parameters that shouldn't be passed to loss functions
87
+ kwargs.pop("return_dict", None)
88
+ logits = None
89
+ loss = None
90
+ token_accuracy = None
91
+
92
+ if skip_logits and labels is None and shift_labels is None:
93
+ raise ValueError("skip_logits is True, but labels and shift_labels are None")
94
+
95
+ if skip_logits is None:
96
+ # By default, if in training mode, don't materialize logits
97
+ skip_logits = self.training and (labels is not None or shift_labels is not None)
98
+
99
+ # Compute loss
100
+ if skip_logits:
101
+ result = LigerForCausalLMLoss(
102
+ hidden_states=kept_hidden_states,
103
+ lm_head_weight=self.lm_head.weight,
104
+ labels=labels,
105
+ shift_labels=shift_labels,
106
+ hidden_size=self.config.hidden_size,
107
+ **kwargs,
108
+ )
109
+ loss, _, token_accuracy = unpack_cross_entropy_result(result)
110
+
111
+ else:
112
+ logits = self.lm_head(kept_hidden_states)
113
+ if labels is not None or shift_labels is not None:
114
+ loss = self.loss_function(
115
+ logits=logits,
116
+ labels=labels,
117
+ shift_labels=shift_labels,
118
+ vocab_size=self.config.vocab_size,
119
+ **kwargs,
120
+ )
121
+
122
+ if not return_dict:
123
+ output = (logits,) + outputs[1:]
124
+ output = ((loss,) + output) if loss is not None else output
125
+ output = output + (token_accuracy,) if token_accuracy is not None else output
126
+ return output
127
+
128
+ # Return custom output class with accuracy field
129
+ return LigerCausalLMOutputWithPast(
130
+ loss=loss,
131
+ logits=logits,
132
+ past_key_values=outputs.past_key_values,
133
+ hidden_states=outputs.hidden_states,
134
+ attentions=outputs.attentions,
135
+ token_accuracy=token_accuracy,
136
+ )
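The `skip_logits` default above is the usual Liger pattern: materialize logits only when they are actually needed. Shown standalone for clarity (a hypothetical helper, not part of the file):

```python
def default_skip_logits(training, labels, shift_labels):
    # Mirrors: skip_logits = self.training and (labels is not None or shift_labels is not None)
    return training and (labels is not None or shift_labels is not None)

print(default_skip_logits(True, labels=[1], shift_labels=None))   # True: fused loss path, logits never built
print(default_skip_logits(False, labels=[1], shift_labels=None))  # False: evaluation keeps the logits
print(default_skip_logits(True, labels=None, shift_labels=None))  # False: nothing to compute the loss against
```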
liger_kernel/transformers/model/gemma2.py CHANGED
@@ -7,7 +7,7 @@ from typing import Union
7
7
  import torch
8
8
 
9
9
  from torch.nn import CrossEntropyLoss
10
- from transformers.cache_utils import HybridCache
10
+ from transformers.cache_utils import Cache
11
11
  from transformers.modeling_outputs import CausalLMOutputWithPast
12
12
  from transformers.utils.deprecation import deprecate_kwarg
13
13
 
@@ -24,7 +24,7 @@ def lce_forward_deprecated(
24
24
  input_ids: torch.LongTensor = None,
25
25
  attention_mask: Optional[torch.Tensor] = None,
26
26
  position_ids: Optional[torch.LongTensor] = None,
27
- past_key_values: Optional[HybridCache] = None,
27
+ past_key_values: Optional[Cache] = None,
28
28
  inputs_embeds: Optional[torch.FloatTensor] = None,
29
29
  labels: Optional[torch.LongTensor] = None,
30
30
  use_cache: Optional[bool] = None,
@@ -149,7 +149,7 @@ def lce_forward(
149
149
  input_ids: torch.LongTensor = None,
150
150
  attention_mask: Optional[torch.Tensor] = None,
151
151
  position_ids: Optional[torch.LongTensor] = None,
152
- past_key_values: Optional[HybridCache] = None,
152
+ past_key_values: Optional[Cache] = None,
153
153
  inputs_embeds: Optional[torch.FloatTensor] = None,
154
154
  labels: Optional[torch.LongTensor] = None,
155
155
  use_cache: Optional[bool] = None,
liger_kernel/transformers/model/gemma3.py CHANGED
@@ -6,10 +6,8 @@ import torch
6
6
  import torch.nn as nn
7
7
 
8
8
  from transformers.cache_utils import Cache
9
- from transformers.cache_utils import HybridCache
10
9
  from transformers.utils import logging
11
10
 
12
- from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
13
11
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
14
12
  from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
15
13
  from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
@@ -23,7 +21,7 @@ def causal_forward(
23
21
  input_ids: torch.LongTensor = None,
24
22
  attention_mask: Optional[torch.Tensor] = None,
25
23
  position_ids: Optional[torch.LongTensor] = None,
26
- past_key_values: Optional[HybridCache] = None,
24
+ past_key_values: Optional[Cache] = None,
27
25
  inputs_embeds: Optional[torch.FloatTensor] = None,
28
26
  labels: Optional[torch.LongTensor] = None,
29
27
  use_cache: Optional[bool] = None,
@@ -269,8 +267,15 @@ def multimodal_forward(
269
267
  shift_hidden_states = shift_hidden_states.view(-1, self.config.text_config.hidden_size)
270
268
  shift_labels = shift_labels.view(-1).to(hidden_device)
271
269
 
272
- lce = LigerFusedLinearCrossEntropyLoss()
273
- result = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
270
+ result = LigerForCausalLMLoss(
271
+ hidden_states=shift_hidden_states,
272
+ lm_head_weight=self.lm_head.weight,
273
+ labels=shift_labels,
274
+ hidden_size=self.config.text_config.hidden_size,
275
+ shift_labels=shift_labels,
276
+ final_logit_softcapping=getattr(self.config.text_config, "final_logit_softcapping", None),
277
+ **lm_kwargs,
278
+ )
274
279
  loss, _, token_accuracy = unpack_cross_entropy_result(result)
275
280
 
276
281
  else:
liger_kernel/transformers/model/loss_utils.py CHANGED
@@ -1,3 +1,5 @@
1
+ import inspect
2
+
1
3
  from typing import Optional
2
4
  from typing import Tuple
3
5
 
@@ -71,6 +73,10 @@ def LigerForCausalLMLoss(
71
73
  return_token_accuracy: bool = False,
72
74
  **kwargs,
73
75
  ):
76
+ # Filter out inapplicable kwargs to liger_fused_linear_cross_entropy
77
+ applicable_params = inspect.signature(F.liger_fused_linear_cross_entropy).parameters
78
+ kwargs = {k: v for k, v in kwargs.items() if k in applicable_params}
79
+
74
80
  # Skip upcast since intermediate values for the loss are all fp32 in kernel
75
81
  if shift_labels is None:
76
82
  # Shift so that token < n predict n
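The new `inspect.signature` guard simply drops keyword arguments that the fused cross-entropy entry point does not accept. The same pattern in isolation, with a dummy target function (names here are made up):

```python
import inspect

def filter_kwargs(fn, kwargs):
    # Keep only the kwargs that appear in fn's signature.
    params = inspect.signature(fn).parameters
    return {k: v for k, v in kwargs.items() if k in params}

def dummy_loss(hidden_states, weight, labels, softcap=None):
    pass

print(filter_kwargs(dummy_loss, {"labels": [1], "return_dict": True, "softcap": 30.0}))
# -> {'labels': [1], 'softcap': 30.0}; 'return_dict' is silently dropped
```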
liger_kernel/transformers/monkey_patch.py CHANGED
@@ -2821,6 +2821,83 @@ def apply_liger_kernel_to_hunyuan_v1_moe(
2821
2821
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
2822
2822
 
2823
2823
 
2824
+ def apply_liger_kernel_to_exaone4(
2825
+ rope: bool = True,
2826
+ cross_entropy: bool = False,
2827
+ fused_linear_cross_entropy: bool = True,
2828
+ rms_norm: bool = True,
2829
+ swiglu: bool = True,
2830
+ model: PreTrainedModel = None,
2831
+ ) -> None:
2832
+ """
2833
+ Apply Liger kernels to replace original implementation in HuggingFace EXAONE4 models.
2834
+
2835
+ Args:
2836
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
2837
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
2838
+ fused_linear_cross_entropy (bool):
2839
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
2840
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
2841
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
2842
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
2843
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
2844
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
2845
+ loaded. Default is None.
2846
+ """
2847
+ assert not (cross_entropy and fused_linear_cross_entropy), (
2848
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
2849
+ )
2850
+
2851
+ from transformers.models.exaone4 import modeling_exaone4
2852
+ from transformers.models.exaone4.modeling_exaone4 import Exaone4Model
2853
+
2854
+ from liger_kernel.transformers.model.exaone4 import lce_forward as exaone4_lce_forward
2855
+
2856
+ if rope:
2857
+ modeling_exaone4.apply_rotary_pos_emb = liger_rotary_pos_emb
2858
+
2859
+ if rms_norm:
2860
+ # EXAONE4 requires in_place=False to avoid gradient issues
2861
+ class Exaone4LigerRMSNorm(LigerRMSNorm):
2862
+ def __init__(self, hidden_size, eps=1e-6, **kwargs):
2863
+ super().__init__(hidden_size, eps, **kwargs)
2864
+ self.in_place = False
2865
+
2866
+ modeling_exaone4.Exaone4RMSNorm = Exaone4LigerRMSNorm
2867
+
2868
+ if cross_entropy:
2869
+ from transformers.loss.loss_utils import nn
2870
+
2871
+ nn.functional.cross_entropy = liger_cross_entropy
2872
+
2873
+ if fused_linear_cross_entropy:
2874
+ if model is not None:
2875
+ model.forward = MethodType(exaone4_lce_forward, model)
2876
+ else:
2877
+ modeling_exaone4.Exaone4ForCausalLM.forward = exaone4_lce_forward
2878
+
2879
+ if swiglu:
2880
+ modeling_exaone4.Exaone4MLP = LigerSwiGLUMLP
2881
+
2882
+ if model is not None:
2883
+ # The model instance already exists, so we need to additionally patch the
2884
+ # instance variables that reference already-instantiated modules
2885
+
2886
+ # get the base model from the model instance
2887
+ base_model: Exaone4Model = getattr(model, model.base_model_prefix, model)
2888
+
2889
+ if rms_norm:
2890
+ _patch_rms_norm_module(base_model.norm, in_place=False)
2891
+ for decoder_layer in base_model.layers:
2892
+ if swiglu:
2893
+ _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
2894
+ if rms_norm:
2895
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
2896
+ _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
2897
+ _patch_rms_norm_module(decoder_layer.self_attn.q_norm, in_place=False)
2898
+ _patch_rms_norm_module(decoder_layer.self_attn.k_norm, in_place=False)
2899
+
2900
+
2824
2901
  # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
2825
2902
  MODEL_TYPE_TO_APPLY_LIGER_FN = {
2826
2903
  "gemma": apply_liger_kernel_to_gemma,
@@ -2862,6 +2939,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
2862
2939
  "smolvlm": apply_liger_kernel_to_smolvlm,
2863
2940
  "hunyuan_v1_dense": apply_liger_kernel_to_hunyuan_v1_dense,
2864
2941
  "hunyuan_v1_moe": apply_liger_kernel_to_hunyuan_v1_moe,
2942
+ "exaone4": apply_liger_kernel_to_exaone4,
2865
2943
  }
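With the registration above, the new patch can be used like the other apply_liger_kernel_to_* helpers; a brief illustrative example (the checkpoint id comes from the docstring in model/exaone4.py):

```python
import transformers

from liger_kernel.transformers import apply_liger_kernel_to_exaone4

# Patch the EXAONE4 modeling code before instantiation:
# RoPE, RMSNorm (in_place=False), SwiGLU and fused linear cross entropy by default.
apply_liger_kernel_to_exaone4()
model = transformers.AutoModelForCausalLM.from_pretrained("LGAI-EXAONE/EXAONE-4.0-1.2B")

# Or patch an already-loaded instance:
# apply_liger_kernel_to_exaone4(model=model)
```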
2866
2944
 
2867
2945
 
liger_kernel/transformers/tiled_mlp.py CHANGED
@@ -57,11 +57,7 @@ class LigerTiledGEGLUMLP(nn.Module):
57
57
  Returns:
58
58
  Output tensor of the same shape as input
59
59
  """
60
- compute_params = [
61
- self.gate_proj.weight,
62
- self.up_proj.weight,
63
- self.down_proj.weight,
64
- ]
60
+ compute_params = [p for p in self.parameters() if p.requires_grad]
65
61
 
66
62
  return apply_tiled_mlp(
67
63
  fn=self._mlp_forward,
@@ -118,11 +114,7 @@ class LigerTiledSwiGLUMLP(nn.Module):
118
114
  Returns:
119
115
  Output tensor of the same shape as input
120
116
  """
121
- compute_params = [
122
- self.gate_proj.weight,
123
- self.up_proj.weight,
124
- self.down_proj.weight,
125
- ]
117
+ compute_params = [p for p in self.parameters() if p.requires_grad]
126
118
 
127
119
  return apply_tiled_mlp(
128
120
  fn=self._mlp_forward,
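The compute_params change above swaps a hard-coded weight list for every trainable parameter of the module. Illustrated on a toy module (not the Liger MLP itself):

```python
import torch.nn as nn

toy = nn.Sequential(nn.Linear(8, 16, bias=False), nn.Linear(16, 8))
toy[1].bias.requires_grad_(False)

compute_params = [p for p in toy.parameters() if p.requires_grad]
print(len(compute_params))  # 2: both weight matrices; the frozen bias is excluded
```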
{liger_kernel_nightly-0.6.4.dev20260107111351.dist-info → liger_kernel_nightly-0.6.4.dev20260116023519.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.6.4.dev20260107111351
3
+ Version: 0.6.4.dev20260116023519
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
{liger_kernel_nightly-0.6.4.dev20260107111351.dist-info → liger_kernel_nightly-0.6.4.dev20260116023519.dist-info}/RECORD CHANGED
@@ -19,42 +19,43 @@ liger_kernel/chunked_loss/simpo_loss.py,sha256=fy2w8KbhMrBv7b1jdIeH3bBFxY52bPQPZ
19
19
  liger_kernel/ops/__init__.py,sha256=F3m9qlXbgttykKEBsrMFf1WyK_0H8CKqLuDnFRR-cvc,7237
20
20
  liger_kernel/ops/cross_entropy.py,sha256=DnXFRZ9TGN1SnEo8xGBFFPLNQaen8aLVNPJ1em-LbK4,22910
21
21
  liger_kernel/ops/dyt.py,sha256=4XmkCCZaPPM8Tl4QHo6vSF2m68jrwsnjucrbyOJvZpM,5628
22
- liger_kernel/ops/fused_add_rms_norm.py,sha256=lvwrLsKvoAQqS9KatgBkAyy0Xdecado-g0rvXYXaBak,14237
22
+ liger_kernel/ops/fused_add_rms_norm.py,sha256=E4SqFDw13ixd6S3DMhB1HlvtxAfuPL_DiHkgpk3exCI,14174
23
23
  liger_kernel/ops/fused_linear_cross_entropy.py,sha256=1gx2qljre9PVc861iknFnNCGC-P35D2w1cc_yMDO9ow,16239
24
24
  liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHuncvSV5c,9683
25
25
  liger_kernel/ops/fused_neighborhood_attention.py,sha256=vPi5xbnh6wxyZehaqo6Tuilqo2fN5SGDiONjnNmIKqs,35556
26
26
  liger_kernel/ops/geglu.py,sha256=-ruMACDsFH1YsAak6BGvZ0ktLGIrBE6yGF0dAyR82UU,4307
27
- liger_kernel/ops/group_norm.py,sha256=zoy-TcNkYtKGmGhTFJmnyiG_4Es4ZphpqP8jtUSI6-I,10912
27
+ liger_kernel/ops/group_norm.py,sha256=7BqYIP5-HQCdvHKMJlA6jCQoYKZjbtsoD9-eXld5qzk,11133
28
28
  liger_kernel/ops/grpo_loss.py,sha256=2SyOujtF9I3xiNo4wFf4s6MeiDotE_qeYfRWgj_bOBE,9573
29
29
  liger_kernel/ops/jsd.py,sha256=onHp5T3MbvJaVz5Vup7Ww6EQp_HTaZeayTjJk6FgQMY,7042
30
- liger_kernel/ops/kl_div.py,sha256=ZjGdDLKWksHT9dZ0xF_TDgAkj5cuMTwwT5tr9E-_24o,8734
31
- liger_kernel/ops/layer_norm.py,sha256=-4UEyko9eKgBi5LNmfdEU2hTpJOWVnEy5iYjJkMvHmk,10598
30
+ liger_kernel/ops/kl_div.py,sha256=MZZb7eAPMXlydYVV4uL9aTytXFkdQdp-jmiDw9tC0pg,8652
31
+ liger_kernel/ops/layer_norm.py,sha256=D1qPDn0HVHfyOmNHQyMDKv7f_JEnFsFxzHgfq9B4rI8,10696
32
32
  liger_kernel/ops/llama4_rope.py,sha256=-aqdZzllklTN8b9--e-TsWY_ntGCN8-tyseT4x0bd8s,8223
33
33
  liger_kernel/ops/multi_token_attention.py,sha256=Oz_RXDp-OSS_R_HuGmaETHdAJ7Toda_70OfE7TXMUlY,7645
34
- liger_kernel/ops/poly_norm.py,sha256=5IdJEZnbbhblkL_X8UhSD4A2CooQbOAZJw8nAekWNs4,11372
34
+ liger_kernel/ops/poly_norm.py,sha256=BBwdOtSzW02W-c-UAN8pzn2vAU-AM3gCsWqZnSE5zf4,11288
35
35
  liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML5vBU,8481
36
- liger_kernel/ops/rms_norm.py,sha256=r97gpPmhbKz9qrBjxUEX0XP04aYu4psJeLe3KnhPZyo,21852
36
+ liger_kernel/ops/rms_norm.py,sha256=bd5ZAdiqh2iO7a7FdwWH7woslJEVyPlDKXSoUqDZ3GQ,21874
37
37
  liger_kernel/ops/rope.py,sha256=v-7JHRrv-5ImoROkpKfl30WwWI4qTa2tAl7zQeB4ml4,8956
38
38
  liger_kernel/ops/softmax.py,sha256=tgORx6MK1IDDtZKqGarj0IPIVjqAIEUXXYPiinhRdtI,5864
39
39
  liger_kernel/ops/sparsemax.py,sha256=AeWe1xgkHJFEKWTj2vu_0hj7LztGvjqXAps-QTpCY0U,5087
40
40
  liger_kernel/ops/swiglu.py,sha256=D7nd4u_LInwsIRNCDdY77lqnTz8-W5dJrpEAt8zEO_A,3033
41
41
  liger_kernel/ops/tiled_mlp.py,sha256=eyMFsFFgHch8a_6R6IYRG24_jqKg5GF_BQUoQuAG8SY,4529
42
42
  liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
43
- liger_kernel/ops/utils.py,sha256=Xu6MJ2-lbp4hSmI0JGImKguKU0KqWnFQDgQwOxSieyc,4360
43
+ liger_kernel/ops/utils.py,sha256=90V8P0ElZeBathDhmIKm_506Nhrsr1ojO0qRl53_Tn0,4909
44
44
  liger_kernel/ops/backends/README.md,sha256=ZP59UUqD1WW8LwM5Y-cTpSM-Dtgdp8Wku2mE9kqAc2E,4185
45
45
  liger_kernel/ops/backends/__init__.py,sha256=-mgef3cHfDFeL5NbXbq1TI7ngCahE9qqL3aMaHnXvis,629
46
46
  liger_kernel/ops/backends/registry.py,sha256=yJa_Sh2FZ__iPCIU8h2nOQbnsFQh1I-_czROLtb1uQM,1637
47
47
  liger_kernel/ops/backends/_ascend/__init__.py,sha256=6n0keOX9H-kLadBdVZlx-Ce0ZLVJvLiEfR-9-uxmYUk,221
48
48
  liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md,sha256=FVXHSO1KY4ZFxCAE5r4hOYB2Q8ANyrJZ7WnFJ_GeQOA,19605
49
- liger_kernel/ops/backends/_ascend/ub_manager.py,sha256=3h7sncZk00veBJS37a01YPt1SLeAxJj5N3lPdv1wXAk,13174
50
- liger_kernel/ops/backends/_ascend/ops/__init__.py,sha256=R1iS9R0EtmGbrN0cSkIiRtZouVl7ndiPVZJIoEALb7s,1748
51
- liger_kernel/ops/backends/_ascend/ops/geglu.py,sha256=hs1Cdhw0pbgZFiK1srLuo8DCe8jtnmhjm5SS2vw8-0M,8421
49
+ liger_kernel/ops/backends/_ascend/ub_manager.py,sha256=3Utke2Dwx9huB0Qoch1KU2CXKN3JS5DbP9_JusIbfQU,13174
50
+ liger_kernel/ops/backends/_ascend/ops/__init__.py,sha256=N41VgPn8D_YJpHez1-UEYTtA-JZxpERmAzN7WcDfE2U,2067
51
+ liger_kernel/ops/backends/_ascend/ops/geglu.py,sha256=M3YFE44UREf91PtOvY0X_GZouUxeeDCy3GmXDrvRLQk,10131
52
52
  liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py,sha256=pUYcstJ4FuzDTkuhmQaO3U9gcVQoNCpzuwwUdtES5hM,11015
53
53
  liger_kernel/ops/backends/_ascend/ops/rope.py,sha256=nOwtm6_eSnzDjl2S-jvGpwHrumAOgWfr5pNg6SL3R2k,10842
54
54
  liger_kernel/ops/backends/_ascend/ops/swiglu.py,sha256=yrbEgIgeCZyayMYHCRNq7LntZE9cEemht39_TFPro0k,4682
55
+ liger_kernel/ops/backends/_ascend/ops/tvd.py,sha256=4Q_DXSuVRqummX5dwFT5zOgQpdaWViLbMPjJ3kWy2IE,7745
55
56
  liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
56
57
  liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
57
- liger_kernel/transformers/__init__.py,sha256=4sqcDbOZ_JtS9Ag-7oyuhq5jN298GyzjJFu9J-DyyZQ,10872
58
+ liger_kernel/transformers/__init__.py,sha256=h7U1Vxrg5OoqOstBmZMd-0G0LROYleYt_fS-RpvEq84,11057
58
59
  liger_kernel/transformers/auto_model.py,sha256=RnJhK8xHamRnnswgRLG_muJE1i6T6LszjK8lC6vonhE,2410
59
60
  liger_kernel/transformers/cross_entropy.py,sha256=08H8RxSxGX_52UzrHNnSZ_wWH-uvU8KrRiDmVrkOw14,1996
60
61
  liger_kernel/transformers/dyt.py,sha256=Rng-MZQSprnGGWFtpmYKt7MIX26vFUYbq5ruM4MjH-U,719
@@ -71,7 +72,7 @@ liger_kernel/transformers/jsd.py,sha256=_KlOX8YcdONU0tq0bIRDQ5VDBwtywm3Ro-FmlmI0
71
72
  liger_kernel/transformers/kl_div.py,sha256=94VR4uuj-2dZCTEnwFksvDi-LporrpB5HgmYtQCZnw0,402
72
73
  liger_kernel/transformers/layer_norm.py,sha256=l4nsT_Zj4CdVZOM7F0I0Ox-lmLHyIJzqQvVaF0o0HbI,895
73
74
  liger_kernel/transformers/llama4_rope.py,sha256=A_nxcS_KiUCyNeL2FAZX7yUhDsX7krrI9BG49OaN_nM,3627
74
- liger_kernel/transformers/monkey_patch.py,sha256=ESFIi_7hQMcnUtRLjAMJ9kbzSbwToDhpOfFa6aQ-SrY,135534
75
+ liger_kernel/transformers/monkey_patch.py,sha256=hCFLKviPteLyDTUxjehiUS6k4hEx2GHDEualDhKpEYs,138949
75
76
  liger_kernel/transformers/multi_token_attention.py,sha256=LtEjG7qy1-JK-HIPaz8zZ4P08aSZTnj5D635Pa04Onc,1730
76
77
  liger_kernel/transformers/poly_norm.py,sha256=T3VdLQHLcCY7KzNzrc6IJRs8SzO8Yc7a0BS_2p6d7Wo,1367
77
78
  liger_kernel/transformers/qwen2vl_mrope.py,sha256=0hOBR3j2Yd6xbT4z9BNRKEy1D0eyOUsIW6EmI_3PPNI,1033
@@ -80,16 +81,17 @@ liger_kernel/transformers/rope.py,sha256=-W9aYLa2hMOmmG5yeHcvPsOI5UTc95ylYxUddxk
80
81
  liger_kernel/transformers/softmax.py,sha256=VI5QGHYpXSiXckgovEnDGcXwitimsxKB0GX-AT4dAC4,256
81
82
  liger_kernel/transformers/sparsemax.py,sha256=Os49bSpPX4pWymsasv_3j20m8GFaI54e03XFPkHiPE0,393
82
83
  liger_kernel/transformers/swiglu.py,sha256=LpgikAs9hibAL7G6itygBbOlW9tZe5s4D2IGAKGpbPw,4284
83
- liger_kernel/transformers/tiled_mlp.py,sha256=gPsz7b0kxpk3mre7o1uGBt-XdNvMUN7IIqnUYIur-T0,4628
84
+ liger_kernel/transformers/tiled_mlp.py,sha256=_Go2bN8huL4I0EHBPXNfpIRaEukl8hiQEEJIwpJST20,4498
84
85
  liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
85
86
  liger_kernel/transformers/tvd.py,sha256=GYjhtXgS3RTPveOTN2gyK4uBnjs6ii2vkSZRX21QpqA,446
86
87
  liger_kernel/transformers/experimental/__init__.py,sha256=oQqk-f32JYgWEP9DJCj6ty6bbJSGrdXsFDQFwGeX6vI,127
87
88
  liger_kernel/transformers/experimental/embedding.py,sha256=bjy9hHj--ivy6xEWdiE6qLy9uLyeS4PsBEgl_MdDrng,858
88
89
  liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
90
+ liger_kernel/transformers/model/exaone4.py,sha256=T5Ef2FnkJ-i8ktRWvBB5GXFOIyJmvMPyGsDFt5awpmE,5802
89
91
  liger_kernel/transformers/model/falcon_h1.py,sha256=heUZ4wUt2ATmtBtmv8Rcro3pQl6fV9T0pburjTTW7os,5004
90
92
  liger_kernel/transformers/model/gemma.py,sha256=pAri4PYpknsFfkvyo8Ez2NNlqrUDW-KkExUXTGZAcH4,10621
91
- liger_kernel/transformers/model/gemma2.py,sha256=qa9Ok42vFojVGNmASTH3Ek566Vu507kjd--ZpZDKX9M,12024
92
- liger_kernel/transformers/model/gemma3.py,sha256=ZUrFCc-pfF8jYHV0HsptBr98hx6p2q9ea0kSzVAoFPo,14966
93
+ liger_kernel/transformers/model/gemma2.py,sha256=KgSpXVi04c8hVFa7dqJtjzVobz6z7BNTvGc1WjoV4nk,12006
94
+ liger_kernel/transformers/model/gemma3.py,sha256=2XPmtpZxR55wccKflIDqf2AwHJdxypUbd62fLuZ8two,15092
93
95
  liger_kernel/transformers/model/glm4.py,sha256=bSp22iPIjsli4-c_usUOsyh1Bs2gIK8X6ynS0azseUs,5900
94
96
  liger_kernel/transformers/model/glm4v.py,sha256=dd-BQpccDCp1SbIxcJ5rG8xcwYQK3KOv1Tgm9TGnZc4,6594
95
97
  liger_kernel/transformers/model/glm4v_moe.py,sha256=zKhMdOOrRhlrvCSFaeVYfddL1ubpY8edEO91TN81n98,7135
@@ -99,7 +101,7 @@ liger_kernel/transformers/model/internvl.py,sha256=OOutracs9qrPHSU7FVYar08yinvGr
99
101
  liger_kernel/transformers/model/llama.py,sha256=kqZeONzwTBzudoChlKMzq1w23BtYGbxWZC1l1V__JTw,13410
100
102
  liger_kernel/transformers/model/llama4.py,sha256=PfkynGVI0xxMs3EtyYpCgaALI6stu25OIrTIymE-pvg,4853
101
103
  liger_kernel/transformers/model/llava.py,sha256=yoADM_BuIEummtTDiwWqjfUjXUMZD78VJzS0TRj5GJ4,15687
102
- liger_kernel/transformers/model/loss_utils.py,sha256=mAV6NsE1xR2smQMlr_n9afh4ek3BhIfieZdTn1Z-9Fw,2836
104
+ liger_kernel/transformers/model/loss_utils.py,sha256=tNbC94Z4Ca2mlv3MRhnqfpJ7sBc5MZJtt1-mzMMJT1M,3088
103
105
  liger_kernel/transformers/model/mistral.py,sha256=OcwOzVDMwwDbVccVPv-AaocznzWwzLT3aRaKK5SMaAg,6030
104
106
  liger_kernel/transformers/model/mixtral.py,sha256=YcBDoTEJDgLFJ_RTo180DYGxR8D5Ad9-idumif7kCPE,12130
105
107
  liger_kernel/transformers/model/mllama.py,sha256=vAHwCm63sn4kpAY0rDGf_N0HR7KRTBVpBYDVTPOaZTg,12079
@@ -122,9 +124,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
122
124
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
123
125
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
124
126
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
125
- liger_kernel_nightly-0.6.4.dev20260107111351.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
126
- liger_kernel_nightly-0.6.4.dev20260107111351.dist-info/METADATA,sha256=Mzy4eM7hocfOx4KYOI_qKR056hH-RyAOcd99Ju-qY5k,25660
127
- liger_kernel_nightly-0.6.4.dev20260107111351.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
128
- liger_kernel_nightly-0.6.4.dev20260107111351.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
129
- liger_kernel_nightly-0.6.4.dev20260107111351.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
130
- liger_kernel_nightly-0.6.4.dev20260107111351.dist-info/RECORD,,
127
+ liger_kernel_nightly-0.6.4.dev20260116023519.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
128
+ liger_kernel_nightly-0.6.4.dev20260116023519.dist-info/METADATA,sha256=Ja1hknX3Qd5-8K5-BO7pX4Ln11BgPKgBrYBjf291kzU,25660
129
+ liger_kernel_nightly-0.6.4.dev20260116023519.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
130
+ liger_kernel_nightly-0.6.4.dev20260116023519.dist-info/WHEEL,sha256=WnJ8fYhv8N4SYVK2lLYNI6N0kVATA7b0piVUNvqIIJE,91
131
+ liger_kernel_nightly-0.6.4.dev20260116023519.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
132
+ liger_kernel_nightly-0.6.4.dev20260116023519.dist-info/RECORD,,
{liger_kernel_nightly-0.6.4.dev20260107111351.dist-info → liger_kernel_nightly-0.6.4.dev20260116023519.dist-info}/WHEEL CHANGED
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.3.2)
2
+ Generator: setuptools (75.3.3)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5