liger-kernel 0.5.5__py3-none-any.whl → 0.5.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. liger_kernel/chunked_loss/functional.py +2 -0
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +17 -2
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +331 -0
  4. liger_kernel/chunked_loss/grpo_loss.py +103 -61
  5. liger_kernel/chunked_loss/jsd_loss.py +12 -7
  6. liger_kernel/ops/cross_entropy.py +3 -2
  7. liger_kernel/ops/dyt.py +225 -0
  8. liger_kernel/ops/fused_linear_jsd.py +2 -1
  9. liger_kernel/ops/jsd.py +30 -11
  10. liger_kernel/ops/kl_div.py +2 -2
  11. liger_kernel/transformers/__init__.py +3 -0
  12. liger_kernel/transformers/dyt.py +20 -0
  13. liger_kernel/transformers/functional.py +5 -0
  14. liger_kernel/transformers/model/gemma.py +8 -16
  15. liger_kernel/transformers/model/gemma2.py +7 -16
  16. liger_kernel/transformers/model/llama.py +8 -15
  17. liger_kernel/transformers/model/llava.py +369 -0
  18. liger_kernel/transformers/model/loss_utils.py +57 -0
  19. liger_kernel/transformers/model/mistral.py +9 -10
  20. liger_kernel/transformers/model/mixtral.py +8 -15
  21. liger_kernel/transformers/model/mllama.py +8 -15
  22. liger_kernel/transformers/model/olmo2.py +8 -16
  23. liger_kernel/transformers/model/paligemma.py +397 -0
  24. liger_kernel/transformers/model/phi3.py +8 -15
  25. liger_kernel/transformers/model/qwen2.py +8 -15
  26. liger_kernel/transformers/model/qwen2_5_vl.py +9 -10
  27. liger_kernel/transformers/model/qwen2_vl.py +9 -10
  28. liger_kernel/transformers/monkey_patch.py +219 -13
  29. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/METADATA +9 -6
  30. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/RECORD +34 -29
  31. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/WHEEL +1 -1
  32. liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -240
  33. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info/licenses}/LICENSE +0 -0
  34. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info/licenses}/NOTICE +0 -0
  35. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.6.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/jsd_loss.py CHANGED
@@ -19,15 +19,20 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
         student_log_probs = F.log_softmax(student_logits, dim=-1)
         teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
 
-        # Compute probabilities (only required for mean calculation)
-        mean_probs = beta * student_log_probs.exp() + (1 - beta) * teacher_log_probs.exp()
-        log_mean_probs = mean_probs.log()
+        if beta == 0:
+            jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="sum", log_target=True)
+        elif beta == 1:
+            jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
+        else:
+            # Compute probabilities (only required for mean calculation)
+            mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
+            log_mean_probs = mean_probs.log()
 
-        student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
-        teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
+            student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
+            teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
 
-        # JSD is the weighted average of the KL divergences
-        jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
+            # JSD is the weighted average of the KL divergences
+            jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
 
         return jsd_loss
 
     @classmethod
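Note: with the new special-casing, `beta=0` reduces to a pure forward KL (teacher as the target distribution) and `beta=1` to a reverse KL, while intermediate values use the generalized JSD against the `beta`-weighted mixture. A plain-PyTorch restatement of the branch logic above, useful for sanity-checking the weighting convention (illustrative only, not the fused chunked implementation):

```python
import torch
import torch.nn.functional as F

def naive_jsd(student_logits, teacher_logits, beta=0.5):
    """Unfused restatement of the branch logic in the hunk above."""
    student_log_probs = F.log_softmax(student_logits.float(), dim=-1)
    teacher_log_probs = F.log_softmax(teacher_logits.float(), dim=-1)
    if beta == 0:  # forward KL(teacher || student)
        return F.kl_div(student_log_probs, teacher_log_probs, reduction="sum", log_target=True)
    if beta == 1:  # reverse KL(student || teacher)
        return F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
    mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
    log_mean_probs = mean_probs.log()
    student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
    teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
    return beta * teacher_kl + (1 - beta) * student_kl

s, t = torch.randn(4, 32), torch.randn(4, 32)
print(naive_jsd(s, t, 0.5), naive_jsd(s, t, 0.0), naive_jsd(s, t, 1.0))
```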
liger_kernel/ops/cross_entropy.py CHANGED
@@ -9,6 +9,7 @@ import triton.language as tl
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device
 
 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -59,7 +60,7 @@ def liger_cross_entropy_kernel(
         z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
         loss_stride (int): The stride of the loss tensor.
         n_cols (int): The number of columns in the input tensor.
-        n_non_ignore (flaot): The number of non-ignored elements in the batch.
+        n_non_ignore (float): The number of non-ignored elements in the batch.
         sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
         weight_sum (float): The sum of weight tensor.
         ignore_index (int): The index to ignore in the target.
@@ -258,7 +259,7 @@ def liger_cross_entropy_kernel(
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2  # the best size we found by manually tuning
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2  # the best size we found by manually tuning
 
 
 def cross_entropy_forward(
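Note: `MAX_FUSED_SIZE` caps the Triton block size chosen for a row of logits; the XPU backend now gets a smaller cap (4096) than CUDA/ROCm (65536 // 2 = 32768). A hedged sketch of how such a cap typically feeds into the block-size choice (the helper below is illustrative and assumes the usual next-power-of-2 pattern, not the exact body of `cross_entropy_forward`):

```python
import triton

def pick_block_size(n_cols: int, device: str) -> int:
    # Device-dependent cap mirroring the change above: XPU uses a smaller
    # maximum block size (4096) than CUDA/ROCm (65536 // 2 = 32768).
    max_fused_size = 4096 if device == "xpu" else 65536 // 2
    return min(max_fused_size, triton.next_power_of_2(n_cols))

# e.g. a 128k-entry vocabulary row
print(pick_block_size(131072, "cuda"))  # 32768
print(pick_block_size(131072, "xpu"))   # 4096
```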
liger_kernel/ops/dyt.py ADDED
@@ -0,0 +1,225 @@
+import operator
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import infer_device
+
+if compare_version("triton", operator.ge, "3.0.0"):
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import tanh
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import tanh
+else:
+    from triton.language.math import tanh
+
+
+@triton.jit
+def _dyt_fwd_kernel(
+    x_ptr,
+    x_row_stride,
+    alpha_ptr,
+    gamma_ptr,
+    beta_ptr,
+    y_ptr,
+    y_row_stride,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Reference:
+    https://arxiv.org/abs/2503.10622
+
+    Shapes:
+        - x: (BT, C)
+        - alpha: (1)
+        - gamma: (C)
+        - beta: (C)
+    """
+    row_idx = tl.program_id(0)
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_cols
+
+    x_ptr += row_idx * x_row_stride
+    y_ptr += row_idx * y_row_stride
+
+    alpha = tl.load(alpha_ptr)
+    gamma = tl.load(gamma_ptr + offsets, mask=mask)
+    beta = tl.load(beta_ptr + offsets, mask=mask)
+    x = tl.load(x_ptr + offsets, mask=mask)
+    y = gamma * tanh((alpha * x).cast(tl.float32)) + beta
+    tl.store(y_ptr + offsets, y, mask=mask)
+
+
+@triton.jit
+def _dyt_bwd_kernel(
+    x_ptr,
+    x_row_stride,
+    dy_ptr,
+    dy_row_stride,
+    dx_ptr,
+    dx_row_stride,
+    alpha_ptr,
+    dalpha_ptr,
+    gamma_ptr,
+    dgamma_ptr,
+    dgamma_row_stride,
+    n_cols,
+    n_rows,
+    ROWS_PER_PROGRAM: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Reference:
+    https://arxiv.org/abs/2503.10622
+
+    Shapes:
+        - x: (BT, C)
+        - alpha: (1)
+        - gamma: (C)
+        - dx: (BT, C)
+        - dy: (BT, C)
+        - dgamma: (sm_count, C)
+        - dalpha: (sm_count,)
+    """
+    # d(gamma * tanh(alpha * x) + beta) / dx
+    # = gamma * (1 - tanh^2(alpha * x)) * alpha
+    # d(gamma * tanh(alpha * x) + beta) / dalpha
+    # = gamma * (1 - tanh^2(alpha * x)) * x
+    # d(gamma * tanh(alpha * x) + beta) / dgamma
+    # = tanh(alpha * x)
+    # d(gamma * tanh(alpha * x)) / dbeta = 1
+    pid = tl.program_id(0)
+
+    row_start = pid * ROWS_PER_PROGRAM
+    row_end = min((pid + 1) * ROWS_PER_PROGRAM, n_rows)
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_cols
+
+    dalpha = 0.0
+    dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
+    x_ptr += row_start * x_row_stride
+    dx_ptr += row_start * dx_row_stride
+    dy_ptr += row_start * dy_row_stride
+    alpha = tl.load(alpha_ptr)
+    gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)
+
+    for _ in tl.range(row_start, row_end):
+        dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0)
+        x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
+        tanh_ax = tanh((alpha * x).cast(tl.float32))
+        sech2_ax = 1 - tanh_ax * tanh_ax
+
+        dx = dy * gamma * sech2_ax * alpha
+        dalpha += tl.sum(dy * gamma * sech2_ax * x)
+        dgamma += dy * tanh_ax
+        tl.store(dx_ptr + offsets, dx, mask=mask)
+
+        dy_ptr += dy_row_stride
+        x_ptr += x_row_stride
+        dx_ptr += dx_row_stride
+
+    tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
+    tl.store(dalpha_ptr + pid, dalpha)
+
+    pass
+
+
+def liger_dyt_fwd(x, alpha, gamma, beta):
+    shape = x.shape
+    dim = shape[-1]
+    x = x.view(-1, dim)
+    n_rows, n_cols = x.shape
+    y = torch.empty_like(x)
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+    _dyt_fwd_kernel[(n_rows,)](
+        x_ptr=x,
+        alpha_ptr=alpha,
+        gamma_ptr=gamma,
+        beta_ptr=beta,
+        y_ptr=y,
+        x_row_stride=x.stride(0),
+        y_row_stride=y.stride(0),
+        n_cols=n_cols,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+    return y.view(*shape)
+
+
+def liger_dyt_bwd(dy, x, alpha, gamma):
+    shape = dy.shape
+    dtype = x.dtype
+    dim = shape[-1]
+    dy = dy.view(-1, dim)
+    x = x.view(-1, dim)
+    n_rows, n_cols = dy.shape
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+    sm_count = 1
+    device = infer_device()
+    if device == "cuda":
+        sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
+    elif device == "xpu":
+        sm_count = torch.xpu.get_device_properties(x.device).gpu_subslice_count
+    if n_cols > BLOCK_SIZE:
+        raise RuntimeError(
+            f"Feature dimension {dim} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
+        )
+
+    dx = torch.empty_like(x, dtype=torch.float32)
+    _dalpha = torch.empty((sm_count,), dtype=torch.float32, device=x.device)
+    _dgamma = torch.empty((sm_count, n_cols), dtype=torch.float32, device=x.device)
+
+    grid = (sm_count,)
+    rows_per_program = triton.cdiv(n_rows, sm_count)
+    _dyt_bwd_kernel[grid](
+        x_ptr=x,
+        x_row_stride=x.stride(0),
+        dy_ptr=dy,
+        dy_row_stride=dy.stride(0),
+        dx_ptr=dx,
+        dx_row_stride=dx.stride(0),
+        alpha_ptr=alpha,
+        dalpha_ptr=_dalpha,
+        gamma_ptr=gamma,
+        dgamma_ptr=_dgamma,
+        dgamma_row_stride=_dgamma.stride(0),
+        n_cols=n_cols,
+        n_rows=n_rows,
+        ROWS_PER_PROGRAM=rows_per_program,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+    dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype)
+    dgamma = _dgamma.sum(dim=0).to(dtype)
+    dbeta = dy.sum(dim=0).to(dtype)
+    return dx.view(*shape), dalpha, dgamma, dbeta
+
+
+class LigerDyTFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, x, alpha, gamma, beta):
+        y = liger_dyt_fwd(x, alpha, gamma, beta)
+        ctx.save_for_backward(x, alpha, gamma)
+        return y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, grad_output):
+        x, alpha, gamma = ctx.saved_tensors
+        dx, dalpha, dgamma, dbeta = liger_dyt_bwd(
+            grad_output,
+            x,
+            alpha,
+            gamma,
+        )
+
+        return (dx, dalpha, dgamma, dbeta)
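Note: the new kernels implement Dynamic Tanh, DyT(x) = gamma * tanh(alpha * x) + beta (arXiv:2503.10622). A plain-PyTorch restatement of the forward pass and of the gradients the backward kernel accumulates, useful as a mental model or a correctness reference (illustrative, not part of the wheel):

```python
import torch

def dyt_forward_ref(x, alpha, gamma, beta):
    # y = gamma * tanh(alpha * x) + beta, with tanh evaluated in fp32 like the kernel
    return gamma * torch.tanh((alpha * x).float()) + beta

def dyt_backward_ref(dy, x, alpha, gamma):
    tanh_ax = torch.tanh((alpha * x).float())
    sech2_ax = 1 - tanh_ax * tanh_ax
    dx = dy * gamma * sech2_ax * alpha           # matches the kernel's dx
    dalpha = (dy * gamma * sech2_ax * x).sum()   # scalar, summed over rows and channels
    dgamma = (dy * tanh_ax).sum(dim=0)           # per-channel
    dbeta = dy.sum(dim=0)                        # per-channel (done on the host in liger_dyt_bwd)
    return dx, dalpha, dgamma, dbeta

x = torch.randn(8, 64)
alpha, gamma, beta = torch.tensor(0.5), torch.ones(64), torch.zeros(64)
y = dyt_forward_ref(x, alpha, gamma, beta)
dx, dalpha, dgamma, dbeta = dyt_backward_ref(torch.ones_like(y), x, alpha, gamma)
```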
liger_kernel/ops/fused_linear_jsd.py CHANGED
@@ -8,11 +8,12 @@ from liger_kernel.ops.utils import amp_custom_bwd
 from liger_kernel.ops.utils import amp_custom_fwd
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2
 
 
 def fused_linear_jsd_forward(
liger_kernel/ops/jsd.py CHANGED
@@ -51,24 +51,43 @@ def _jsd_kernel(
         Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
 
         if beta == 0.0:  # forward KL
-            Y_prob = tl.exp(Y)
+            Y_max = tl.max(Y, axis=0)
+            Y_shifted = Y - Y_max
+            Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max)  # Compensate for the shift
             loss = Y_prob * (Y - X)
             dX = -Y_prob
-        elif beta == 1.0:
-            X_prob = tl.exp(X)
+        elif beta == 1.0:  # reverse KL
+            X_max = tl.max(X, axis=0)
+            X_shifted = X - X_max
+            X_prob = tl.exp(X_shifted) * tl.exp(X_max)  # Compensate for the shift
             loss = X_prob * (X - Y)
             dX = loss + X_prob
         else:
-            Q = tl.exp(X)
-            P = tl.exp(Y)
-            M = beta * P + (1 - beta) * Q
-            log_M = tl.log(M)
+            max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0))
+            X_shifted = X - max_val
+            Y_shifted = Y - max_val
 
-            loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
-            dX = (1 - beta) * Q * (X - log_M)
+            # Pre-compute exp(max_val) since it's used twice
+            exp_max = tl.exp(max_val)
+
+            # Compute exp terms with compensation
+            Q = tl.exp(X_shifted) * exp_max  # = exp(X)
+            P = tl.exp(Y_shifted) * exp_max  # = exp(Y)
+
+            # Pre-compute common terms
+            beta_P = beta * P
+            one_minus_beta_Q = (1 - beta) * Q
+            M = beta_P + one_minus_beta_Q
+            log_M = tl.log(M)  # No need to compensate as M is already in original scale
+
+            loss = beta_P * Y + one_minus_beta_Q * X - M * log_M
+            dX = one_minus_beta_Q * (X - log_M)
+
+        # Pre-compute scaling factor
+        scale = 1.0 / n_non_ignore
+        loss = loss * scale
+        dX = dX * scale
 
-        loss = loss / n_non_ignore
-        dX = dX / n_non_ignore
         tl.store(loss_ptr + offsets, loss, mask=mask)
         tl.store(dX_ptr + offsets, dX, mask=mask)
 
liger_kernel/ops/kl_div.py CHANGED
@@ -185,9 +185,9 @@ class LigerKLDivLossFunction(torch.autograd.Function):
     Class implementing the forward and backward pass for the KL Divergence Loss using Triton, as defined by the following formula:
     ```python
     if log_target:
-        loss = target * (target.log() - input)
-    else:
         loss = target.exp() * (target - input)
+    else:
+        loss = target * (target.log() - input)
     ```,
     then the loss is reduced according to the `reduction` parameter.
     as defined in the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
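Note: the swap fixes the docstring so the two branches match `torch.nn.functional.kl_div` semantics, where `log_target=True` means the target is supplied in log-space. A quick check against PyTorch (illustrative):

```python
import torch
import torch.nn.functional as F

inp = torch.log_softmax(torch.randn(3, 5), dim=-1)  # input is always log-probabilities
tgt = torch.softmax(torch.randn(3, 5), dim=-1)      # target as probabilities
log_tgt = tgt.log()                                 # target as log-probabilities

# log_target=True: loss = target.exp() * (target - input), target given in log-space
assert torch.allclose(
    F.kl_div(inp, log_tgt, reduction="none", log_target=True),
    log_tgt.exp() * (log_tgt - inp),
)
# log_target=False: loss = target * (target.log() - input), target given as probabilities
assert torch.allclose(
    F.kl_div(inp, tgt, reduction="none", log_target=False),
    tgt * (tgt.log() - inp),
)
```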
liger_kernel/transformers/__init__.py CHANGED
@@ -1,5 +1,6 @@
 from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM  # noqa: F401
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss  # noqa: F401
+from liger_kernel.transformers.dyt import LigerDyT  # noqa: F401
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss  # noqa: F401
 from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD  # noqa: F401
 from liger_kernel.transformers.geglu import LigerGEGLUMLP  # noqa: F401
@@ -11,10 +12,12 @@ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_paligemma  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl  # noqa: F401
liger_kernel/transformers/dyt.py ADDED
@@ -0,0 +1,20 @@
+import torch
+import torch.nn as nn
+
+from liger_kernel.ops.dyt import LigerDyTFunction
+
+
+class LigerDyT(nn.Module):
+    def __init__(self, hidden_size, init_alpha=0.5):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.init_alpha = init_alpha
+        self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
+        self.gamma = nn.Parameter(torch.ones(hidden_size))
+        self.beta = nn.Parameter(torch.zeros(hidden_size))
+
+    def forward(self, x):
+        return LigerDyTFunction.apply(x, self.alpha, self.gamma, self.beta)
+
+    def extra_repr(self):
+        return f"{self.hidden_size}, init_alpha={self.init_alpha}"
liger_kernel/transformers/functional.py CHANGED
@@ -1,6 +1,7 @@
 from typing import Optional
 
 from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
+from liger_kernel.ops.dyt import LigerDyTFunction
 from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
 from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
 from liger_kernel.ops.geglu import LigerGELUMulFunction
@@ -192,3 +193,7 @@ def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 
 def liger_swiglu(a, b):
     return LigerSiLUMulFunction.apply(a, b)
+
+
+def liger_dyt(x, alpha, gamma, beta):
+    return LigerDyTFunction.apply(x, alpha, gamma, beta)
liger_kernel/transformers/model/gemma.py CHANGED
@@ -14,6 +14,7 @@ from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
 @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
@@ -200,22 +201,13 @@ def lce_forward(
     loss = None
     # if in training mode, don't materialize logits
     if self.training and (labels is not None):
-        # We do the same thing as ForCausalLMLoss but using Liger FLCE
-
-        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        # flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
-        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
-
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-        if reduction == "sum":
-            loss /= loss_kwargs["num_items_in_batch"]
-
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            hidden_size=self.config.hidden_size,
+            **loss_kwargs,
+        )
     else:  # if in inference mode materialize logits
         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
         if labels is not None:
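Note: the per-model shift-and-flatten blocks are consolidated into the new `liger_kernel/transformers/model/loss_utils.py` (`LigerForCausalLMLoss`), whose body is not shown in this diff. Based on the code it replaces, it plausibly does something like the sketch below; everything here is inferred from the removed lines (including the optional `softcap` pass-through used by gemma2) and the helper name/signature details are guesses, not the actual implementation:

```python
import torch
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

def liger_for_causal_lm_loss_sketch(hidden_states, lm_head_weight, labels, hidden_size,
                                    num_items_in_batch=None, softcap=None, **kwargs):
    # Shift so tokens < n predict token n, then flatten -- same as the removed per-model code
    shift_hidden_states = hidden_states[..., :-1, :].contiguous().view(-1, hidden_size)
    shift_labels = labels[..., 1:].contiguous().view(-1)

    # "sum" + manual normalization when the trainer supplies num_items_in_batch, else "mean"
    reduction = "sum" if num_items_in_batch is not None else "mean"
    lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction, softcap=softcap)
    loss = lce(lm_head_weight, shift_hidden_states, shift_labels)
    if reduction == "sum":
        loss = loss / num_items_in_batch
    return loss
```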
liger_kernel/transformers/model/gemma2.py CHANGED
@@ -15,6 +15,7 @@ from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 logger = logging.getLogger(__name__)
 
@@ -212,25 +213,15 @@ def lce_forward(
     loss = None
     # if in training mode, don't materialize logits
     if self.training and (labels is not None):
-        # We do the same thing as ForCausalLMLoss but using Liger FLCE
-
-        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        # flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
-        lce = LigerFusedLinearCrossEntropyLoss(
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            hidden_size=self.config.hidden_size,
             softcap=self.config.final_logit_softcapping,
-            reduction=reduction,
+            **loss_kwargs,
         )
 
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-        if reduction == "sum":
-            loss /= loss_kwargs["num_items_in_batch"]
-
     else:  # if in inference mode materialize logits
         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
         if self.config.final_logit_softcapping is not None:
liger_kernel/transformers/model/llama.py CHANGED
@@ -15,6 +15,7 @@ from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 if TYPE_CHECKING:
     from transformers.cache_utils import Cache
@@ -212,21 +213,13 @@ def lce_forward(
     loss = None
     # if in training mode, don't materialize logits
     if self.training and (labels is not None):
-        # We do the same thing as ForCausalLMLoss but using Liger FLCE
-
-        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        # flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
-        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
-
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-        if reduction == "sum":
-            loss /= loss_kwargs["num_items_in_batch"]
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            hidden_size=self.config.hidden_size,
+            **loss_kwargs,
+        )
 
     else:  # if in inference mode materialize logits
         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])