liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (114)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +304 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +21 -4
  18. liger_kernel/ops/cross_entropy.py +235 -84
  19. liger_kernel/ops/dyt.py +157 -0
  20. liger_kernel/ops/experimental/embedding.py +1 -3
  21. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  22. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  23. liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
  24. liger_kernel/ops/fused_linear_jsd.py +17 -34
  25. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  26. liger_kernel/ops/geglu.py +7 -18
  27. liger_kernel/ops/group_norm.py +305 -0
  28. liger_kernel/ops/grpo_loss.py +310 -0
  29. liger_kernel/ops/jsd.py +46 -21
  30. liger_kernel/ops/kl_div.py +23 -19
  31. liger_kernel/ops/layer_norm.py +150 -86
  32. liger_kernel/ops/llama4_rope.py +225 -0
  33. liger_kernel/ops/multi_token_attention.py +207 -0
  34. liger_kernel/ops/poly_norm.py +386 -0
  35. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  36. liger_kernel/ops/rms_norm.py +314 -84
  37. liger_kernel/ops/rope.py +32 -34
  38. liger_kernel/ops/softmax.py +201 -0
  39. liger_kernel/ops/sparsemax.py +179 -0
  40. liger_kernel/ops/swiglu.py +5 -9
  41. liger_kernel/ops/tiled_mlp.py +136 -0
  42. liger_kernel/ops/tvd.py +207 -0
  43. liger_kernel/ops/utils.py +8 -4
  44. liger_kernel/transformers/__init__.py +199 -24
  45. liger_kernel/transformers/auto_model.py +6 -13
  46. liger_kernel/transformers/cross_entropy.py +33 -20
  47. liger_kernel/transformers/dyt.py +22 -0
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -3
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +291 -13
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -4
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -4
  57. liger_kernel/transformers/group_norm.py +50 -0
  58. liger_kernel/transformers/grpo_loss.py +98 -0
  59. liger_kernel/transformers/jsd.py +2 -7
  60. liger_kernel/transformers/kl_div.py +1 -3
  61. liger_kernel/transformers/layer_norm.py +3 -9
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/falcon_h1.py +122 -0
  64. liger_kernel/transformers/model/gemma.py +77 -77
  65. liger_kernel/transformers/model/gemma2.py +283 -0
  66. liger_kernel/transformers/model/gemma3.py +331 -0
  67. liger_kernel/transformers/model/glm4.py +141 -0
  68. liger_kernel/transformers/model/glm4v.py +163 -0
  69. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  70. liger_kernel/transformers/model/internvl.py +157 -0
  71. liger_kernel/transformers/model/llama.py +128 -79
  72. liger_kernel/transformers/model/llama4.py +121 -0
  73. liger_kernel/transformers/model/llava.py +344 -0
  74. liger_kernel/transformers/model/loss_utils.py +95 -0
  75. liger_kernel/transformers/model/mistral.py +68 -64
  76. liger_kernel/transformers/model/mixtral.py +75 -91
  77. liger_kernel/transformers/model/mllama.py +63 -68
  78. liger_kernel/transformers/model/olmo2.py +141 -0
  79. liger_kernel/transformers/model/output_classes.py +147 -0
  80. liger_kernel/transformers/model/paligemma.py +432 -0
  81. liger_kernel/transformers/model/phi3.py +59 -213
  82. liger_kernel/transformers/model/qwen2.py +75 -72
  83. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  84. liger_kernel/transformers/model/qwen2_vl.py +78 -98
  85. liger_kernel/transformers/model/qwen3.py +136 -0
  86. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  87. liger_kernel/transformers/model/qwen3_next.py +146 -0
  88. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  89. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  90. liger_kernel/transformers/model/smollm3.py +199 -0
  91. liger_kernel/transformers/model/smolvlm.py +158 -0
  92. liger_kernel/transformers/monkey_patch.py +2106 -289
  93. liger_kernel/transformers/multi_token_attention.py +64 -0
  94. liger_kernel/transformers/poly_norm.py +42 -0
  95. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  96. liger_kernel/transformers/rms_norm.py +57 -6
  97. liger_kernel/transformers/rope.py +45 -2
  98. liger_kernel/transformers/softmax.py +12 -0
  99. liger_kernel/transformers/sparsemax.py +16 -0
  100. liger_kernel/transformers/swiglu.py +23 -8
  101. liger_kernel/transformers/tiled_mlp.py +133 -0
  102. liger_kernel/transformers/trainer/__init__.py +4 -0
  103. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  104. liger_kernel/transformers/tvd.py +13 -0
  105. liger_kernel/triton/__init__.py +1 -3
  106. liger_kernel/triton/monkey_patch.py +1 -3
  107. liger_kernel/utils.py +71 -0
  108. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
  109. liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
  110. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
  111. liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
  112. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
liger_kernel/ops/fused_linear_cross_entropy.py

@@ -2,12 +2,10 @@ import torch
 import triton
 
 from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel
-from liger_kernel.ops.utils import (
-    amp_custom_bwd,
-    amp_custom_fwd,
-    element_mul_kernel,
-    is_hip,
-)
+from liger_kernel.ops.utils import amp_custom_bwd
+from liger_kernel.ops.utils import amp_custom_fwd
+from liger_kernel.ops.utils import element_mul_kernel
+from liger_kernel.ops.utils import is_hip
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -19,15 +17,26 @@ def fused_linear_cross_entropy_forward(
     _input,
     weight,
     target,
+    ce_weight=None,
     bias=None,
     ignore_index=-100,
     lse_square_scale=0.0,
     label_smoothing=0.0,
     reduction="mean",
+    softcap=None,
+    return_z_loss=False,
+    accum_dtype=None,
+    use_token_scaling=False,
+    return_token_accuracy=False,
 ):
-    dtype = _input.dtype
+    assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_token_accuracy, bool), (
+        f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+    )
     device = _input.device
 
+    input_requires_grad = _input.requires_grad
+
     # inputs have shape: BT x H
     # materialized activations will have shape: BT x V
     # the increase in memory = BT x V
@@ -40,21 +49,43 @@ def fused_linear_cross_entropy_forward(
     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
 
     inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
-    chunk_size = triton.next_power_of_2(
-        triton.cdiv(BT, inc_factor)
-    )  # (BT + inc_factor - 1) // inc_factor
+    chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))  # (BT + inc_factor - 1) // inc_factor
     num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size
 
-    grad_weight = (
-        torch.zeros_like(weight, device=device) if weight.requires_grad else None
-    )
     grad_input = torch.zeros_like(_input, device=device)
-    grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
-    # we use fp32 for loss accumulator
-    loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
 
-    # NOTE: skip .item() here to avoid CUDA synchronization
-    total_n_non_ignore = (target != ignore_index).sum()
+    # we use fp32 for loss and gradients accumulator
+    if input_requires_grad:
+        if accum_dtype is None:
+            grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
+            grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
+        else:
+            grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
+            grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
+    else:
+        grad_weight = None
+        grad_bias = None
+
+    loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
+    z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+    token_accuracy_1d = torch.zeros(BT, dtype=torch.float32, device=device) if return_token_accuracy else None
+
+    # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
+    target_mask = target != ignore_index
+    total_n_non_ignore = target_mask.sum().item()
+    total_sum_non_ignore_ce_weight = total_n_non_ignore
+    ce_weight_sum = 0.0
+    if ce_weight is not None:
+        assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}"
+        assert torch.is_floating_point(ce_weight), (
+            f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}"
+        )
+        total_sum_non_ignore_ce_weight = (
+            torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item()
+        )
+        ce_weight_sum = ce_weight.sum().item()
+        if ce_weight.stride(-1) != 1:
+            ce_weight = ce_weight.contiguous()
 
     for chunk_id in range(num_chunks):
         start_idx = chunk_id * chunk_size
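
The chunking arithmetic above is what bounds the memory spike from materializing logits. A worked example with illustrative sizes (BT, H, V here are assumptions, not values taken from this diff):

def cdiv(a, b):  # ceil division, same role as triton.cdiv
    return (a + b - 1) // b

def next_power_of_2(n):  # same role as triton.next_power_of_2
    return 1 << (n - 1).bit_length()

BT, H, V = 4096, 4096, 128256  # batch*seq, hidden size, vocab size (illustrative)
inc_factor = cdiv(V, H)                             # 32
chunk_size = next_power_of_2(cdiv(BT, inc_factor))  # next_power_of_2(128) = 128
num_chunks = cdiv(BT, chunk_size)                   # 32

# logits materialized per chunk: chunk_size * V ~= 16.4M elements,
# instead of BT * V ~= 525M for the unchunked computation.
print(chunk_size * V, "vs", BT * V)
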
@@ -65,16 +96,45 @@ def fused_linear_cross_entropy_forward(
         logits_chunk = _input_chunk @ weight.t()  # chunk_size x V
         if bias is not None:
             logits_chunk = logits_chunk + bias
+
         target_chunk = target[start_idx:end_idx]  # chunk_size,
 
         n_rows = logits_chunk.shape[0]
 
+        # Compute predicted probabilities for token scaling if needed
+        if use_token_scaling:
+            # Compute softmax probabilities for scaling
+            # We need to compute this before the cross entropy kernel modifies logits_chunk
+            logits_for_softmax = logits_chunk.detach().clone()  # Detach to avoid gradient flow
+            if softcap is not None:
+                logits_for_softmax = softcap * torch.tanh(logits_for_softmax / softcap)
+
+            # Compute softmax to get predicted probabilities
+            probs = torch.softmax(logits_for_softmax, dim=-1)
+
+            # Get predicted probabilities for token scaling, handling ignored targets
+            valid_target_mask = target_chunk != ignore_index
+            valid_targets = target_chunk[valid_target_mask]
+
+            if len(valid_targets) > 0:
+                # Gather probabilities only for valid targets
+                valid_probs = probs[valid_target_mask]
+                pred_probs_valid = torch.gather(valid_probs, -1, valid_targets.unsqueeze(-1)).squeeze(-1)
+
+                # Create full tensor with zeros for ignored targets
+                pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
+                pred_probs[valid_target_mask] = pred_probs_valid
+            else:
+                # All targets are ignored
+                pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
+
+            # Store the scaling factors
+            scaling_factors = pred_probs.detach()  # Detach to ensure no gradient flow
+
         # unreduced loss
         loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
-        n_non_ignore = (target_chunk != ignore_index).sum().item()
-
-        # when doing CE, use the upcasted precision
-        logits_chunk = logits_chunk.float()
+        z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
+        token_accuracy_1d_slice = token_accuracy_1d[start_idx:end_idx] if return_token_accuracy else None
 
         # ensure _input and target are contiguous
         logits_chunk = logits_chunk.contiguous()
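
For orientation, `use_token_scaling` above is equivalent to this eager-PyTorch sketch (a minimal illustration, not the kernel code): each token's loss, and later its gradient, is weighted by the detached softmax probability of its true class.

import torch
import torch.nn.functional as F

logits = torch.randn(5, 11)              # chunk_size x V (illustrative)
target = torch.tensor([0, 3, 7, 1, 10])

# detached predicted probability of each token's true class
pred_probs = torch.softmax(logits, dim=-1).gather(-1, target.unsqueeze(-1)).squeeze(-1).detach()

per_token_loss = F.cross_entropy(logits, target, reduction="none")
scaled_loss = per_token_loss * pred_probs  # what multiplying by scaling_factors achieves
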
@@ -86,70 +146,91 @@ def fused_linear_cross_entropy_forward(
             X_stride=logits_chunk.stride(-2),
             Y_ptr=target_chunk,
             Y_stride=target_chunk.stride(-1),  # always 1
+            weight_ptr=ce_weight,
             loss_ptr=loss_1d_slice,
-            z_loss_ptr=loss_1d_slice,  # dummy ptr, not used
+            z_loss_ptr=z_loss_1d_slice,
             loss_stride=loss_1d_slice.stride(-1),  # always 1
+            token_accuracy_ptr=token_accuracy_1d_slice,
+            token_accuracy_stride=token_accuracy_1d_slice.stride(-1)
+            if return_token_accuracy
+            else 0,  # always 1 if accuracy is enabled
             n_cols=V,
-            n_non_ignore=n_non_ignore,
+            n_non_ignore=total_n_non_ignore,
+            sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
+            weight_sum=ce_weight_sum,
             ignore_index=ignore_index,
             lse_square_scale=lse_square_scale,
             label_smoothing=label_smoothing,
             reduction=reduction,
-            RETURN_Z_LOSS=0,  # False
+            softcap=softcap,
+            RETURN_Z_LOSS=return_z_loss,
+            RETURN_TOKEN_ACCURACY=return_token_accuracy,
+            HAS_WEIGHT=True if ce_weight is not None else False,
+            HAS_SOFTCAPPING=True if softcap is not None else False,
+            HAS_GRADIENTS=input_requires_grad,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=32 if not is_hip() else 16,
         )
 
-        # gradient of logits_chunk is computed in-place by the above triton kernel.
-        # Following HuggingFace model source code, we do the forward and backward
-        # w.r.t. logits in fp32 for numerical stability especially as the num classes (vocab size) is huge.
-        # (reference: https://github.com/huggingface/transformers/blob/v4.42.4/src/transformers/models/llama/modeling_llama.py#L1194)
-        # Propagating to lm_head's backward, we'll switch back to the original dtype.
-        logits_chunk = logits_chunk.to(dtype)
-
-        # gradient of logits_chunk is computed in-place by the above triton kernel and is of shape: chunk_size x V
-        # thus grad_input[start_idx: end_idx] should be of shape: chunk_size x H
-        # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only
-        # on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens.
-        # Thus, we need an additional scaling factor of (n_non_ignore/total_n_non_ignore) to scale the gradients.
-
-        if reduction == "mean":
-            alpha = n_non_ignore / total_n_non_ignore if total_n_non_ignore > 0 else 0.0
-        else:
-            alpha = 1.0
+        # Apply token scaling if requested
+        if use_token_scaling:
+            loss_1d_slice = loss_1d_slice * scaling_factors
+            if return_z_loss:
+                z_loss_1d_slice = z_loss_1d_slice * scaling_factors
 
-        loss_1d[start_idx:end_idx] = loss_1d_slice * alpha
-        grad_logits_chunk = logits_chunk * alpha  # chunk_size x V
+        loss_1d[start_idx:end_idx] = loss_1d_slice
+        if return_z_loss:
+            z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
+        if return_token_accuracy:
+            token_accuracy_1d[start_idx:end_idx] = token_accuracy_1d_slice
+        grad_logits_chunk = logits_chunk  # chunk_size x V
 
-        grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
+        # Apply token scaling to gradients if requested
+        if use_token_scaling:
+            # Expand scaling factors to match gradient dimensions
+            scaling_factors_expanded = scaling_factors.unsqueeze(-1)  # chunk_size x 1
+            grad_logits_chunk = grad_logits_chunk * scaling_factors_expanded
 
-        if grad_weight is not None:
-            torch.addmm(
-                input=grad_weight,
-                mat1=logits_chunk.t(),
-                mat2=_input_chunk,
-                out=grad_weight,
-                alpha=alpha,
-                beta=1.0,
-            )
+        if input_requires_grad:
+            grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
 
-        if bias is not None:
+        if grad_weight is not None and input_requires_grad:
+            grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()
+
+        if bias is not None and input_requires_grad:
             torch.add(
                 input=grad_bias,
-                other=logits_chunk.sum(dim=0),
+                other=grad_logits_chunk.sum(dim=0),
                 out=grad_bias,
-                alpha=alpha,
+                alpha=1.0,
             )
 
-    loss = torch.sum(loss_1d)
-    return loss, grad_input, grad_weight, grad_bias
+    # Need extra calculations for backward if reduction=='none'. Not supporting reduction='none' now.
+    # if reduction == "none":
+    #     loss = loss_1d
+    #     z_loss = z_loss_1d if return_z_loss else None
 
+    if reduction == "none":
+        # Return per-token losses
+        loss = loss_1d
+        z_loss = z_loss_1d if return_z_loss else None
+        token_accuracy = token_accuracy_1d if return_token_accuracy else None
+    else:
+        loss = torch.sum(loss_1d)
+        z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+        # For accuracy, we compute the mean across all non-ignored tokens
+        token_accuracy = torch.sum(token_accuracy_1d) / total_n_non_ignore if return_token_accuracy else None
 
-def fused_linear_cross_entropy_backward(
-    grad_output, grad_input, grad_weight, grad_bias
-):
+    # Cast back to original dtype
+    grad_weight = grad_weight.to(weight.dtype) if grad_weight is not None else None
+    grad_bias = grad_bias.to(bias.dtype) if grad_bias is not None else None
+
+    return loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias
+
+
+def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
-    if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
+    if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
         # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
         # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
         BT, H = grad_input.shape
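
The per-chunk `alpha` rescaling disappears because the kernel is now handed the global `total_n_non_ignore` (and the ce_weight sums) and normalizes each token directly. For reduction="mean" the two schemes agree, as this small check illustrates (illustrative, not library code):

import torch

torch.manual_seed(0)
N = 12                       # total non-ignored tokens
per_token = torch.rand(N)    # per-token losses
chunks = per_token.split(5)  # chunks of 5, 5, 2

# old: kernel averages within each chunk, host rescales by alpha = n_chunk / N
old = sum((c.sum() / len(c)) * (len(c) / N) for c in chunks)
# new: kernel divides every token by the global N, host just sums
new = per_token.sum() / N
assert torch.allclose(old, new)
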
@@ -203,10 +284,16 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         weight,
         target,
         bias=None,
+        ce_weight=None,
         ignore_index=-100,
         lse_square_scale=0.0,
         label_smoothing=0.0,
         reduction="mean",
+        softcap=None,
+        return_z_loss: bool = False,
+        accum_dtype=None,
+        use_token_scaling: bool = False,
+        return_token_accuracy: bool = False,
     ):
         """
         Fusing the last linear layer with cross-entropy loss
@@ -221,19 +308,33 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         target: (B*T) where each value is in [0, V-1]
         weight: (V, H) where V is the number of classes
         bias: (V) where V is the number of classes
+        ce_weight: a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
         ignore_index: the index to ignore in the target
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         reduction: reduction to apply
+        accum_dtype (torch.dtype): the dtype of intermediate result buffers for weight and bias gradient accumulations.
+            Recommended to set `accum_dtype` to higher precision, e.g. `torch.float32`, if the training is unstable with original dtype. Default: `None`, performing accumulations in original dtype
+        use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
+            When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
+            Default: False.
+        return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
         """
-        loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
-            _input,
-            weight,
-            target,
-            bias,
-            ignore_index,
-            lse_square_scale,
-            label_smoothing,
-            reduction,
+
+        loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            ce_weight=ce_weight,
+            ignore_index=ignore_index,
+            lse_square_scale=lse_square_scale,
+            label_smoothing=label_smoothing,
+            reduction=reduction,
+            softcap=softcap,
+            return_z_loss=return_z_loss,
+            accum_dtype=accum_dtype,
+            use_token_scaling=use_token_scaling,
+            return_token_accuracy=return_token_accuracy,
         )
         # downcast to dtype and store for backward
         ctx.save_for_backward(
@@ -241,13 +342,34 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             grad_weight.detach() if grad_weight is not None else None,
             grad_bias.detach() if bias is not None else None,
         )
-        return loss
+        ctx.return_z_loss = return_z_loss
+        ctx.return_token_accuracy = return_token_accuracy
+        return loss, z_loss, token_accuracy
 
     @staticmethod
     @amp_custom_bwd
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output, grad_output2, grad_output3):
+        if ctx.return_z_loss:
+            del grad_output2  # z_loss is only for logging
+        if ctx.return_token_accuracy:
+            del grad_output3  # token_accuracy is only for metrics
         (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
         grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
            grad_output, grad_input, grad_weight, grad_bias
         )
-        return (grad_input, grad_weight, None, grad_bias, None, None, None, None)
+        return (
+            grad_input,
+            grad_weight,
+            None,
+            grad_bias,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,  # use_token_scaling
+            None,  # return_token_accuracy
+        )
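
Putting the new surface together, a minimal usage sketch of the updated autograd function (shapes and values are illustrative; `apply` takes positional arguments in the order of `forward` above, needs a supported GPU, and now returns a 3-tuple):

import torch
from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction

B, T, H, V = 2, 8, 64, 100
_input = torch.randn(B * T, H, device="cuda", requires_grad=True)
weight = torch.randn(V, H, device="cuda", requires_grad=True)
target = torch.randint(0, V, (B * T,), device="cuda")

loss, z_loss, token_accuracy = LigerFusedLinearCrossEntropyFunction.apply(
    _input, weight, target,
    None,           # bias
    None,           # ce_weight
    -100,           # ignore_index
    0.0,            # lse_square_scale
    0.0,            # label_smoothing
    "mean",         # reduction
    None,           # softcap
    False,          # return_z_loss
    torch.float32,  # accum_dtype
    False,          # use_token_scaling
    True,           # return_token_accuracy
)
loss.backward()
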
liger_kernel/ops/fused_linear_jsd.py

@@ -4,17 +4,16 @@ import torch
 import triton
 
 from liger_kernel.ops.jsd import _jsd_kernel
-from liger_kernel.ops.utils import (
-    amp_custom_bwd,
-    amp_custom_fwd,
-    element_mul_kernel,
-    is_hip,
-)
+from liger_kernel.ops.utils import amp_custom_bwd
+from liger_kernel.ops.utils import amp_custom_fwd
+from liger_kernel.ops.utils import element_mul_kernel
+from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2
 
 
 def fused_linear_jsd_forward(
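
The lower XPU cap feeds the block-size choice below (`BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))`); for a large vocabulary the effect is (illustrative arithmetic, assuming Triton is installed):

import triton

V = 128256  # illustrative vocab size
print(min(65536 // 2, triton.next_power_of_2(V)))  # 32768 on CUDA/ROCm
print(min(4096, triton.next_power_of_2(V)))        # 4096 on XPU
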
@@ -43,16 +42,10 @@ def fused_linear_jsd_forward(
     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
 
     inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
-    chunk_size = triton.next_power_of_2(
-        triton.cdiv(BT, inc_factor)
-    )  # (BT + inc_factor - 1) // inc_factor
+    chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))  # (BT + inc_factor - 1) // inc_factor
     num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size
 
-    grad_weight = (
-        torch.zeros_like(student_weight, device=device)
-        if student_weight.requires_grad
-        else None
-    )
+    grad_weight = torch.zeros_like(student_weight, device=device) if student_weight.requires_grad else None
     grad_input = torch.zeros_like(student_input)
     # we use fp32 for loss accumulator
     loss_1d = torch.zeros((BT, V), dtype=torch.float32, device=device)
@@ -73,12 +66,8 @@ def fused_linear_jsd_forward(
         # shape: chunk_size x V
         # For anything starting from logits to the final JSD loss, we do computation
         # in FP32 to avoid losing numerical stability.
-        student_logits_chunk = (student_input_chunk @ student_weight.t()).to(
-            torch.float32
-        )
-        teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(
-            torch.float32
-        )
+        student_logits_chunk = (student_input_chunk @ student_weight.t()).to(torch.float32)
+        teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(torch.float32)
         chunk_n_rows = student_logits_chunk.shape[0]
 
         # unreduced loss
@@ -104,9 +93,7 @@ def fused_linear_jsd_forward(
             dX_ptr=student_prob_chunk,
             dX_stride=student_prob_chunk.stride(-2),
             label_ptr=(
-                shift_labels[start_idx:end_idx]
-                if has_label
-                else torch.empty(1, device=device)
+                shift_labels[start_idx:end_idx] if has_label else torch.empty(1, device=device)
             ),  # dummy ptr if no label
             beta=jsd_beta,
             n_non_ignore=n_non_ignore,
@@ -121,9 +108,7 @@ def fused_linear_jsd_forward(
         student_logits_chunk = (
             student_prob_chunk
             - torch.softmax(student_logits_chunk, dim=-1)
-            * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(
-                student_prob_chunk.shape
-            )
+            * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(student_prob_chunk.shape)
         ) / temperature
         # now we traverse back to grad w.r.t. input to `lm_head` and grad
         # w.r.t. `lm_head` which should be computed in original dtype
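
The compacted expression above is the standard log-softmax backward: with g = dL/d(log_softmax(z/T)) (held in `student_prob_chunk` after the kernel), dL/dz = (g - softmax(z/T) * sum(g)) / T. A quick autograd check of that identity (illustrative, independent of the kernel):

import torch

z = torch.randn(4, 10, requires_grad=True)
T = 2.0
g = torch.randn(4, 10)  # stand-in for the upstream gradient

(torch.log_softmax(z / T, dim=-1) * g).sum().backward()
manual = (g - torch.softmax(z / T, dim=-1) * g.sum(dim=-1, keepdim=True)) / T
assert torch.allclose(z.grad, manual, atol=1e-6)
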
@@ -202,7 +187,7 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
         teacher_input (torch.tensor): input of the last projection layer in teacher model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension.
         teacher_weight (torch.tensor): the last projection layer in teacher model, with shape (V, H), where V is vocab size
         shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
-        jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
+        jsd_beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
         ignore_index (int): the index to ignore. Default: -100
         temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0`
 
@@ -211,9 +196,9 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
         """
         has_label = False
         if shift_labels is not None:
-            assert shift_labels.shape == (
-                teacher_input.shape[0],
-            ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+            assert shift_labels.shape == (teacher_input.shape[0],), (
+                f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+            )
             shift_labels = shift_labels.contiguous()
             has_label = True
 
@@ -239,7 +224,5 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
     @amp_custom_bwd
     def backward(ctx, grad_output):
         (grad_input, grad_weight) = ctx.saved_tensors
-        grad_input, grad_weight = fused_linear_jsd_backward(
-            grad_output, grad_input, grad_weight
-        )
+        grad_input, grad_weight = fused_linear_jsd_backward(grad_output, grad_input, grad_weight)
         return (grad_input, grad_weight, None, None, None, None, None, None)
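
For reference, a minimal usage sketch of the JSD function as documented above (shapes illustrative; `apply` is positional, matching the eight gradients returned by `backward`, and needs a supported GPU):

import torch
from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction

BT, H, V = 16, 64, 100
student_input = torch.randn(BT, H, device="cuda", requires_grad=True)
student_weight = torch.randn(V, H, device="cuda", requires_grad=True)
teacher_input = torch.randn(BT, H, device="cuda")
teacher_weight = torch.randn(V, H, device="cuda")
shift_labels = torch.randint(0, V, (BT,), device="cuda")

loss = LigerFusedLinearJSDFunction.apply(
    student_input, student_weight,
    teacher_input, teacher_weight,
    shift_labels,  # optional; pass None to skip label masking
    0.5,           # jsd_beta
    -100,          # ignore_index
    1.0,           # temperature
)
loss.backward()
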