liger-kernel-nightly 0.0.1.dev20240819184814__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +366 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +307 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +63 -0
  18. liger_kernel/ops/__init__.py +141 -0
  19. liger_kernel/ops/backends/README.md +151 -0
  20. liger_kernel/ops/backends/__init__.py +13 -0
  21. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  22. liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
  23. liger_kernel/ops/backends/registry.py +61 -0
  24. liger_kernel/ops/cross_entropy.py +383 -114
  25. liger_kernel/ops/dyt.py +160 -0
  26. liger_kernel/ops/experimental/embedding.py +141 -0
  27. liger_kernel/ops/experimental/mm_int8int2.py +349 -0
  28. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  29. liger_kernel/ops/fused_linear_cross_entropy.py +346 -132
  30. liger_kernel/ops/fused_linear_jsd.py +228 -0
  31. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  32. liger_kernel/ops/geglu.py +66 -64
  33. liger_kernel/ops/group_norm.py +306 -0
  34. liger_kernel/ops/grpo_loss.py +312 -0
  35. liger_kernel/ops/jsd.py +201 -0
  36. liger_kernel/ops/kl_div.py +262 -0
  37. liger_kernel/ops/layer_norm.py +320 -0
  38. liger_kernel/ops/llama4_rope.py +225 -0
  39. liger_kernel/ops/multi_token_attention.py +207 -0
  40. liger_kernel/ops/poly_norm.py +390 -0
  41. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  42. liger_kernel/ops/rms_norm.py +484 -88
  43. liger_kernel/ops/rope.py +122 -117
  44. liger_kernel/ops/softmax.py +201 -0
  45. liger_kernel/ops/sparsemax.py +179 -0
  46. liger_kernel/ops/swiglu.py +68 -65
  47. liger_kernel/ops/tiled_mlp.py +136 -0
  48. liger_kernel/ops/tvd.py +207 -0
  49. liger_kernel/ops/utils.py +82 -3
  50. liger_kernel/transformers/__init__.py +218 -6
  51. liger_kernel/transformers/auto_model.py +38 -0
  52. liger_kernel/transformers/cross_entropy.py +52 -7
  53. liger_kernel/transformers/dyt.py +22 -0
  54. liger_kernel/transformers/experimental/__init__.py +5 -0
  55. liger_kernel/transformers/experimental/embedding.py +26 -0
  56. liger_kernel/transformers/fsdp.py +55 -0
  57. liger_kernel/transformers/functional.py +301 -0
  58. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  59. liger_kernel/transformers/fused_linear_cross_entropy.py +59 -10
  60. liger_kernel/transformers/fused_linear_jsd.py +95 -0
  61. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  62. liger_kernel/transformers/geglu.py +6 -7
  63. liger_kernel/transformers/group_norm.py +50 -0
  64. liger_kernel/transformers/grpo_loss.py +153 -0
  65. liger_kernel/transformers/jsd.py +70 -0
  66. liger_kernel/transformers/kl_div.py +12 -0
  67. liger_kernel/transformers/layer_norm.py +24 -0
  68. liger_kernel/transformers/llama4_rope.py +93 -0
  69. liger_kernel/transformers/model/falcon_h1.py +122 -0
  70. liger_kernel/transformers/model/gemma.py +261 -0
  71. liger_kernel/transformers/model/gemma2.py +283 -0
  72. liger_kernel/transformers/model/gemma3.py +332 -0
  73. liger_kernel/transformers/model/glm4.py +141 -0
  74. liger_kernel/transformers/model/glm4v.py +163 -0
  75. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  76. liger_kernel/transformers/model/gpt_oss.py +211 -0
  77. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  78. liger_kernel/transformers/model/internvl.py +157 -0
  79. liger_kernel/transformers/model/llama.py +221 -41
  80. liger_kernel/transformers/model/llama4.py +121 -0
  81. liger_kernel/transformers/model/llava.py +344 -0
  82. liger_kernel/transformers/model/loss_utils.py +95 -0
  83. liger_kernel/transformers/model/mistral.py +145 -0
  84. liger_kernel/transformers/model/mixtral.py +293 -0
  85. liger_kernel/transformers/model/mllama.py +269 -0
  86. liger_kernel/transformers/model/olmo2.py +141 -0
  87. liger_kernel/transformers/model/olmo3.py +142 -0
  88. liger_kernel/transformers/model/output_classes.py +147 -0
  89. liger_kernel/transformers/model/paligemma.py +433 -0
  90. liger_kernel/transformers/model/phi3.py +120 -0
  91. liger_kernel/transformers/model/qwen2.py +259 -0
  92. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  93. liger_kernel/transformers/model/qwen2_vl.py +159 -0
  94. liger_kernel/transformers/model/qwen3.py +136 -0
  95. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  96. liger_kernel/transformers/model/qwen3_next.py +146 -0
  97. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  98. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  99. liger_kernel/transformers/model/smollm3.py +199 -0
  100. liger_kernel/transformers/model/smolvlm.py +158 -0
  101. liger_kernel/transformers/monkey_patch.py +2816 -21
  102. liger_kernel/transformers/multi_token_attention.py +64 -0
  103. liger_kernel/transformers/poly_norm.py +42 -0
  104. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  105. liger_kernel/transformers/rms_norm.py +75 -5
  106. liger_kernel/transformers/rope.py +47 -3
  107. liger_kernel/transformers/softmax.py +12 -0
  108. liger_kernel/transformers/sparsemax.py +16 -0
  109. liger_kernel/transformers/swiglu.py +62 -6
  110. liger_kernel/transformers/tiled_mlp.py +133 -0
  111. liger_kernel/transformers/trainer/__init__.py +4 -0
  112. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  113. liger_kernel/transformers/trainer_integration.py +2 -45
  114. liger_kernel/transformers/tvd.py +13 -0
  115. liger_kernel/triton/__init__.py +1 -3
  116. liger_kernel/triton/monkey_patch.py +1 -5
  117. liger_kernel/utils.py +96 -0
  118. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/METADATA +447 -0
  119. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/NOTICE +58 -0
  120. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
  121. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +1 -1
  122. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/METADATA +0 -21
  123. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/NOTICE +0 -4
  124. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/RECORD +0 -27
  125. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
  126. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,27 @@
+import operator
+
+from typing import Optional
+
 import torch
 import triton
 import triton.language as tl
 
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import element_mul_kernel
+from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device
+from liger_kernel.utils import is_npu_available
+
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import tanh
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import tanh
+else:
+    from triton.language.math import tanh
+
 
 @triton.jit
 def liger_cross_entropy_kernel(
@@ -9,12 +29,27 @@ def liger_cross_entropy_kernel(
     X_stride,
     Y_ptr,
     Y_stride,
+    weight_ptr,
     loss_ptr,
+    z_loss_ptr,
     loss_stride,
+    token_accuracy_ptr,
+    token_accuracy_stride,
     n_cols,
     n_non_ignore,
+    sum_non_ignore_weight,
+    weight_sum,
     ignore_index,
+    lse_square_scale: tl.constexpr,
+    label_smoothing: tl.constexpr,
+    reduction: tl.constexpr,  # set it as constexpr since reduction is always known at compile time
+    softcap,
+    RETURN_Z_LOSS: tl.constexpr,
+    RETURN_TOKEN_ACCURACY: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
+    HAS_WEIGHT: tl.constexpr,
+    HAS_SOFTCAPPING: tl.constexpr,
+    HAS_GRADIENTS: tl.constexpr,
 ):
     """
     This kernel computes both cross entropy loss and the gradient of the input.
@@ -25,12 +60,27 @@ def liger_cross_entropy_kernel(
     X_stride (int): The stride of the input tensor.
     Y_ptr: Pointer to target tensor.
     Y_stride (int): The stride of the target tensor.
+    weight_ptr: Pointer to weight tensor.
     loss_ptr: Pointer to tensor to store the loss.
+    z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
     loss_stride (int): The stride of the loss tensor.
+    token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
+    token_accuracy_stride (int): The stride of the token accuracy tensor.
     n_cols (int): The number of columns in the input tensor.
-    n_non_ignore (int): The number of non-ignored elements in the batch.
+    n_non_ignore (float): The number of non-ignored elements in the batch.
+    sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
+    weight_sum (float): The sum of weight tensor.
     ignore_index (int): The index to ignore in the target.
+    label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
+    lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
+    reduction (str): The string for the reduction to apply
+    softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
+    RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
+    RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
     BLOCK_SIZE (int): The block size for Triton operations.
+    HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
+    HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
+    HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
     """
 
     # https://github.com/triton-lang/triton/issues/1058
@@ -49,102 +99,325 @@ def liger_cross_entropy_kernel(
         for i in range(0, n_cols, BLOCK_SIZE):
             X_offsets = i + tl.arange(0, BLOCK_SIZE)
             tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
+        # For ignored tokens, set token accuracy to 0
+        if RETURN_TOKEN_ACCURACY:
+            token_accuracy_ptr += program_id * token_accuracy_stride
+            tl.store(token_accuracy_ptr, 0.0)
         return
 
     loss_ptr += program_id * loss_stride
+    if RETURN_Z_LOSS:
+        z_loss_ptr += program_id * loss_stride
+    if RETURN_TOKEN_ACCURACY:
+        token_accuracy_ptr += program_id * token_accuracy_stride
+
+    if HAS_WEIGHT:
+        weight_y = tl.load(weight_ptr + y).cast(tl.float32)
 
     # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
     # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
 
-    # 3. [Oneline softmax] first pass: find max + sum
+    # 3. [Online softmax] first pass: find max + sum
     m = float("-inf")  # m is the max value. use the notation from the paper
     d = 0.0  # d is the sum. use the notation from the paper
-    ori_X_y = tl.load(
-        X_ptr + y
-    )  # we need to store the original value of X_y for the loss calculation
+    argmax_idx = 0  # Track the index of the maximum value for token accuracy computation
+    ori_X_y = tl.load(X_ptr + y).cast(tl.float32)  # we need to store the original value of X_y for the loss calculation
+    if HAS_SOFTCAPPING:
+        ori_X_y = softcap * tanh(ori_X_y / softcap)
+
+    # Label smoothing is a general case of normal cross entropy
+    # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
+    scaled_x_sum = 0.0
+    eps = label_smoothing / n_cols
 
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(
-            X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
-        )
+            X_ptr + X_offsets,
+            mask=X_offsets < n_cols,
+            other=float("-inf"),
+            # Ensure float32 precision for softmax calculation
+        ).cast(tl.float32)
+        if HAS_SOFTCAPPING:
+            X_block = softcap * tanh(X_block / softcap)
         block_max = tl.max(X_block)
+
+        # Track argmax for accuracy computation
+        if RETURN_TOKEN_ACCURACY and block_max > m:
+            # Find the index of the maximum value in this block
+            is_max_mask = X_block == block_max
+            # Mask out invalid indices with a value larger than n_cols
+            masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
+            # Get the first (smallest) index where max occurs
+            argmax_idx = tl.min(masked_offsets)
+
+        if label_smoothing > 0:
+            # scale X beforehand to avoid overflow
+            if HAS_WEIGHT:
+                weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0))
+            else:
+                scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
         m_new = tl.maximum(m, block_max)
         d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
         m = m_new
 
-    # 4. [Oneline softmax] second pass: calculate the gradients
+    # log (sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X)))))
+    #                    = log (e^(max(X)) * sum(e ^ (X_i - max(X))))
+    #                    = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d
+    lse = m + tl.log(d)
+
+    # 4. [Online Softmax] Second pass: compute gradients
+    # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N)
     # dx_y = (softmax(x_y) - 1) / N
     # dx_i = softmax(x_i) / N, i != y
-    # N is the number of non ingored elements in the batch
-    for i in range(0, n_cols, BLOCK_SIZE):
-        X_offsets = i + tl.arange(0, BLOCK_SIZE)
-        X_block = tl.load(
-            X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
-        )
-        X_block = (tl.exp(X_block - m) / d) / (n_non_ignore)
-        tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
+    # For label smoothing:
+    # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
+    # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
+    #      = dx_i - (1 - label_smoothing) / N
+    # With Z loss:
+    # dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y
+    # dx_y = dx_i - (1 - label_smoothing) / N
+    # For 'sum' reduction, no normalization is applied:
+    # dx_y = softmax(x_y) - 1
+    # dx_i = softmax(x_i), for i ≠ y
+    if HAS_GRADIENTS:
+        for i in range(0, n_cols, BLOCK_SIZE):
+            X_offsets = i + tl.arange(0, BLOCK_SIZE)
+            X_block = tl.load(
+                X_ptr + X_offsets,
+                mask=X_offsets < n_cols,
+                other=float("-inf"),
+                # Ensure float32 precision for softmax calculation
+            ).cast(tl.float32)
+            if HAS_SOFTCAPPING:
+                intermediate = tanh(X_block / softcap)
+                X_block = softcap * intermediate
+
+            if not HAS_WEIGHT:
+                # softmax(x_i)
+                X_block = tl.exp(X_block - m) / d
+                # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+                X_block += 2 * lse_square_scale * lse * X_block
+                # smoothing term
+                X_block += -eps
+                # special handle dx_y
+                X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+                # reduction scale
+                if reduction == "mean":
+                    X_block = X_block / n_non_ignore
+            else:
+                weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                softmax_X = tl.exp(X_block - m) / d
+                # derivative of original_loss
+                dloss_ori = (1 - label_smoothing) * softmax_X
+                # specially handle dx_y
+                dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+                dloss_ori = dloss_ori * weight_y
+                # derivative of smooth_loss
+                dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+                # derivative of z-loss
+                dz_loss = 2 * lse_square_scale * lse * softmax_X
+                # reduction scale
+                if reduction == "mean":
+                    dloss_ori = dloss_ori / sum_non_ignore_weight
+                    dloss_smooth = dloss_smooth / sum_non_ignore_weight
+                    # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+                    dz_loss = dz_loss / n_non_ignore
+                # derivative of total_loss
+                X_block = dloss_ori + dloss_smooth + dz_loss
+
+            # chain rule softcapping
+            # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+            if HAS_SOFTCAPPING:
+                X_block = X_block * (1 - intermediate * intermediate)
+
+            tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
 
     # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
-    # ttps://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
+    # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
     tl.debug_barrier()
 
     # 5. Calculate the loss
 
     # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
     #      = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
+    #      = X_y - m - log d = X_y - lse
     # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
     # So we can safely calculate log (softmax(X_y)) without overflow
-    loss = -(ori_X_y - m - tl.log(d))
-
-    # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - 1) / N`
-    X_y = tl.load(X_ptr + y)
-    X_y += -1 / (n_non_ignore)
+    loss = lse - ori_X_y
+    if HAS_WEIGHT:
+        loss = weight_y * loss
+
+    # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
+    # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
+    #          = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
+    # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
+    #          = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd))
+    # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
+    # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
+    # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
+    if label_smoothing > 0:
+        if HAS_WEIGHT:
+            smooth_loss = scaled_x_sum + eps * lse * weight_sum
+        else:
+            smooth_loss = scaled_x_sum + label_smoothing * lse
+        loss = loss * (1 - label_smoothing) + smooth_loss
+
+    # An auxiliary loss, z_loss
+    # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
+    z_loss = lse_square_scale * lse * lse
+    # Normalize the loss by the number of non-ignored elements if reduction is "mean"
+    if reduction == "mean":
+        if HAS_WEIGHT:
+            loss = loss / sum_non_ignore_weight
+        else:
+            loss = loss / n_non_ignore
+        # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+        z_loss = z_loss / n_non_ignore
+    loss += z_loss
 
     tl.store(loss_ptr, loss)
-    tl.store(X_ptr + y, X_y)
+    if RETURN_Z_LOSS:
+        tl.store(z_loss_ptr, z_loss)
+    if RETURN_TOKEN_ACCURACY:
+        # Store 1.0 if prediction is correct, 0.0 otherwise
+        is_correct = 1.0 if argmax_idx == y else 0.0
+        tl.store(token_accuracy_ptr, is_correct)
 
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2  # the best size we found by manually tuning
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2  # the best size we found by manually tuning
 
 
-@triton.jit
-def element_mul(
-    X_ptr,
-    X_stride,
-    grad_output_ptr,
-    n_cols,
-    BLOCK_SIZE: tl.constexpr,
+def cross_entropy_forward(
+    _input,
+    target,
+    weight,
+    ignore_index,
+    lse_square_scale,
+    label_smoothing,
+    reduction,
+    softcap,
+    return_z_loss,
+    return_token_accuracy=False,
 ):
-    """
-    This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr.
-    The multiplication is performed in-place on the tensor pointed by X_ptr.
-
-    Parameters:
-    X_ptr: Pointer to the input tensor.
-    X_stride (int): The stride of the input tensor.
-    grad_output_ptr: Pointer to the gradient output value.
-    n_cols (int): The number of columns in the input tensor.
-    BLOCK_SIZE (int): The block size for Triton operations.
-    """
-
-    # Get the program ID and convert it to int64 to avoid overflow
-    program_id = tl.program_id(0).to(tl.int64)
-
-    # Locate the start index
-    X_ptr += program_id * X_stride
+    assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_token_accuracy, bool), (
+        f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+    )
+
+    BT, V = _input.shape
+    n_rows = BT
+
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+
+    # unreduced loss
+    loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
+    z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+    token_accuracy_1d = (
+        torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
+    )
+
+    target_mask = target != ignore_index
+    n_non_ignore = target_mask.sum().item()
+    assert (target * target_mask).max() < _input.shape[-1], (
+        f"Target {target.max()} is out of bounds. Expected < {_input.shape[-1]}"
+    )
+    assert (target * target_mask).min() >= 0, f"Target {target.min()} is out of bounds. Expected >= 0"
+    sum_non_ignore_weight = n_non_ignore
+    weight_sum = 0.0
+    if weight is not None:
+        assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
+        assert torch.is_floating_point(weight), (
+            f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
+        )
+        sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
+        weight_sum = weight.sum().item()
+        # ensure weight is contiguous
+        if weight.stride(-1) != 1:
+            weight = weight.contiguous()
+
+    # ensure _input and target are contiguous in the last dimension
+    if _input.stride(-1) != 1:
+        _input = _input.contiguous()
+    if target.stride(-1) != 1:
+        target = target.contiguous()
+
+    # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory
+    liger_cross_entropy_kernel[(n_rows,)](
+        X_ptr=_input,
+        X_stride=_input.stride(-2),
+        Y_ptr=target,
+        Y_stride=target.stride(-1),  # always 1
+        weight_ptr=weight,  # dummy if None
+        loss_ptr=loss_1d,
+        z_loss_ptr=z_loss_1d,
+        loss_stride=loss_1d.stride(-1),  # always 1
+        token_accuracy_ptr=token_accuracy_1d,
+        token_accuracy_stride=token_accuracy_1d.stride(-1)
+        if return_token_accuracy
+        else 0,  # always 1 if accuracy is enabled
+        n_cols=V,
+        n_non_ignore=n_non_ignore,
+        sum_non_ignore_weight=sum_non_ignore_weight,
+        ignore_index=ignore_index,
+        weight_sum=weight_sum,
+        lse_square_scale=lse_square_scale,
+        label_smoothing=label_smoothing,
+        reduction=reduction,
+        softcap=softcap,
+        RETURN_Z_LOSS=return_z_loss,
+        RETURN_TOKEN_ACCURACY=return_token_accuracy,
+        BLOCK_SIZE=BLOCK_SIZE,
+        HAS_WEIGHT=True if weight is not None else False,
+        HAS_SOFTCAPPING=True if softcap is not None else False,
+        HAS_GRADIENTS=_input.requires_grad,
+        # TODO: 32 seems to give the best performance
+        # Performance is quite sensitive to num_warps
+        num_warps=32 if not is_hip() else 16,
+    )
+
+    if reduction == "none":
+        loss = loss_1d
+        z_loss = z_loss_1d if return_z_loss else None
+        token_accuracy = token_accuracy_1d if return_token_accuracy else None
+    else:
+        loss = torch.sum(loss_1d)
+        z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+        # For accuracy, we compute the mean across all non-ignored tokens
+        token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None
+
+    return loss, z_loss, token_accuracy, _input
+
+
+def cross_entropy_backward(_input, grad_output):
+    # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
+    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+        pass
+    # If reduction is 'none'
+    elif grad_output.ndim > 0:
+        _input = _input * grad_output.unsqueeze(dim=1)
+    # If reduction is ['mean', 'sum'], grad_output is just a scalar
+    # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
+    # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
+    else:
+        BT, V = _input.shape
+        n_rows = BT
+        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
 
-    # Load the gradient output value
-    grad_output = tl.load(grad_output_ptr)
+        element_mul_kernel[(n_rows,)](
+            _input,
+            _input.stride(-2),
+            grad_output,
+            V,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=32 if not is_hip() else 16,
+        )
 
-    # Perform the element-wise multiplication
-    for i in range(0, n_cols, BLOCK_SIZE):
-        X_offsets = i + tl.arange(0, BLOCK_SIZE)
-        X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
-        tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
+    return _input
 
 
 class LigerCrossEntropyFunction(torch.autograd.Function):
@@ -154,7 +427,19 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, _input, target, ignore_index):
+    def forward(
+        ctx,
+        _input: torch.Tensor,
+        target: torch.Tensor,
+        weight: Optional[torch.FloatTensor],
+        ignore_index: int = -100,
+        lse_square_scale: float = 0.0,
+        label_smoothing: float = 0.0,
+        reduction: str = "mean",
+        softcap: Optional[float] = None,
+        return_z_loss: bool = False,
+        return_token_accuracy: bool = False,
+    ):
         """
         The forward pass of the Liger Cross Entropy loss.
 
@@ -162,87 +447,71 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         ctx : The context object.
         _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
         target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
+        weight(Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
         ignore_index (int): The index to ignore in the target.
+        lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
+        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
+        reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
+        softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
+        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy) instead of (loss, None, None). Default: `False`
+        return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
 
         Returns:
-        tensor: The computed loss.
+        tuple: A tuple with the computed losses and accuracy: (loss, z_loss, token_accuracy). z_loss and token_accuracy are None if not requested.
         """
-        BT, V = _input.shape
-        n_rows = BT
-
-        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+        input_requires_grad = _input.requires_grad
 
-        # unreduced loss
-        loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
-
-        n_non_ignore = (target != ignore_index).sum().item()
-
-        # ensure _input and target are contiguous in the last dimension
-        if _input.stride(-1) != 1:
-            _input = _input.contiguous()
-        if target.stride(-1) != 1:
-            target = target.contiguous()
-
-        # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory
-        liger_cross_entropy_kernel[(n_rows,)](
-            X_ptr=_input,
-            X_stride=_input.stride(-2),
-            Y_ptr=target,
-            Y_stride=target.stride(-1),  # always 1
-            loss_ptr=loss_1d,
-            loss_stride=loss_1d.stride(-1),  # always 1
-            n_cols=V,
-            n_non_ignore=n_non_ignore,
-            ignore_index=ignore_index,
-            BLOCK_SIZE=BLOCK_SIZE,
-            # TODO: 32 seems to give the best performance
-            # Performance is quite sentitive to num_warps
-            num_warps=32,
+        loss, z_loss, token_accuracy, _input = cross_entropy_forward(
+            _input,
+            target,
+            weight,
+            ignore_index,
+            lse_square_scale,
+            label_smoothing,
+            reduction,
+            softcap,
+            return_z_loss,
+            return_token_accuracy,
         )
-
-        loss = torch.sum(loss_1d) / n_non_ignore
-
         # TODO: investigation
         # If we don't detach the _input tensor, the memory will double
         # Not sure why but seems that there will be a time both grad and value exist but in different location
-        ctx.save_for_backward(_input.detach())
-        return loss
+        if input_requires_grad:
+            ctx.save_for_backward(_input.detach())
+        ctx.return_z_loss = return_z_loss
+        ctx.return_token_accuracy = return_token_accuracy
+
+        return loss, z_loss, token_accuracy
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output, grad_output2, grad_output3):
         """
         The backward pass of the Liger Cross Entropy loss.
 
         Parameters:
         ctx : The context object with saved tensors.
         grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
-
+        grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
+        grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
         Returns:
         tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
         """
-        (_input,) = ctx.saved_tensors
-        # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
-        if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
-            pass
-
-        # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
-        # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
-        else:
-            BT, V = _input.shape
-            n_rows = BT
-            BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-
-            element_mul[(n_rows,)](
-                _input,
-                _input.stride(-2),
-                grad_output,
-                V,
-                BLOCK_SIZE=BLOCK_SIZE,
-                num_warps=32,
-            )
+        if ctx.return_z_loss:
+            del grad_output2  # z_loss is only for logging
+        if ctx.return_token_accuracy:
+            del grad_output3  # token_accuracy is only for metrics
 
+        (_input,) = ctx.saved_tensors
+        _input = cross_entropy_backward(_input, grad_output)
         return (
             _input,
             None,
            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
         )
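
For orientation, the call below is a minimal usage sketch inferred from the new forward signature shown in this diff; the module path comes from the file list above, while the shapes, the CUDA device, and the chosen argument values are illustrative assumptions rather than part of the released package. torch.autograd.Function.apply is typically called with positional arguments only, so every parameter of the widened signature is passed explicitly:

import torch

from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

# Hypothetical shapes for illustration: 8 tokens flattened from (batch, seq), vocab size 32000.
logits = torch.randn(8, 32000, device="cuda", requires_grad=True)
target = torch.randint(0, 32000, (8,), device="cuda")

# Positional order per the new signature:
# (_input, target, weight, ignore_index, lse_square_scale, label_smoothing,
#  reduction, softcap, return_z_loss, return_token_accuracy)
loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
    logits, target, None, -100, 0.0, 0.0, "mean", None, True, True
)
loss.backward()  # only `loss` carries gradients; z_loss and token_accuracy are metrics

The nine extra None entries in the updated backward return tuple line up with these added forward arguments, which is why the tuple in the diff grows to ten elements.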