liger-kernel-nightly 0.0.1.dev20240819184814__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +366 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +307 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +63 -0
  18. liger_kernel/ops/__init__.py +141 -0
  19. liger_kernel/ops/backends/README.md +151 -0
  20. liger_kernel/ops/backends/__init__.py +13 -0
  21. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  22. liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
  23. liger_kernel/ops/backends/registry.py +61 -0
  24. liger_kernel/ops/cross_entropy.py +383 -114
  25. liger_kernel/ops/dyt.py +160 -0
  26. liger_kernel/ops/experimental/embedding.py +141 -0
  27. liger_kernel/ops/experimental/mm_int8int2.py +349 -0
  28. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  29. liger_kernel/ops/fused_linear_cross_entropy.py +346 -132
  30. liger_kernel/ops/fused_linear_jsd.py +228 -0
  31. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  32. liger_kernel/ops/geglu.py +66 -64
  33. liger_kernel/ops/group_norm.py +306 -0
  34. liger_kernel/ops/grpo_loss.py +312 -0
  35. liger_kernel/ops/jsd.py +201 -0
  36. liger_kernel/ops/kl_div.py +262 -0
  37. liger_kernel/ops/layer_norm.py +320 -0
  38. liger_kernel/ops/llama4_rope.py +225 -0
  39. liger_kernel/ops/multi_token_attention.py +207 -0
  40. liger_kernel/ops/poly_norm.py +390 -0
  41. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  42. liger_kernel/ops/rms_norm.py +484 -88
  43. liger_kernel/ops/rope.py +122 -117
  44. liger_kernel/ops/softmax.py +201 -0
  45. liger_kernel/ops/sparsemax.py +179 -0
  46. liger_kernel/ops/swiglu.py +68 -65
  47. liger_kernel/ops/tiled_mlp.py +136 -0
  48. liger_kernel/ops/tvd.py +207 -0
  49. liger_kernel/ops/utils.py +82 -3
  50. liger_kernel/transformers/__init__.py +218 -6
  51. liger_kernel/transformers/auto_model.py +38 -0
  52. liger_kernel/transformers/cross_entropy.py +52 -7
  53. liger_kernel/transformers/dyt.py +22 -0
  54. liger_kernel/transformers/experimental/__init__.py +5 -0
  55. liger_kernel/transformers/experimental/embedding.py +26 -0
  56. liger_kernel/transformers/fsdp.py +55 -0
  57. liger_kernel/transformers/functional.py +301 -0
  58. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  59. liger_kernel/transformers/fused_linear_cross_entropy.py +59 -10
  60. liger_kernel/transformers/fused_linear_jsd.py +95 -0
  61. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  62. liger_kernel/transformers/geglu.py +6 -7
  63. liger_kernel/transformers/group_norm.py +50 -0
  64. liger_kernel/transformers/grpo_loss.py +153 -0
  65. liger_kernel/transformers/jsd.py +70 -0
  66. liger_kernel/transformers/kl_div.py +12 -0
  67. liger_kernel/transformers/layer_norm.py +24 -0
  68. liger_kernel/transformers/llama4_rope.py +93 -0
  69. liger_kernel/transformers/model/falcon_h1.py +122 -0
  70. liger_kernel/transformers/model/gemma.py +261 -0
  71. liger_kernel/transformers/model/gemma2.py +283 -0
  72. liger_kernel/transformers/model/gemma3.py +332 -0
  73. liger_kernel/transformers/model/glm4.py +141 -0
  74. liger_kernel/transformers/model/glm4v.py +163 -0
  75. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  76. liger_kernel/transformers/model/gpt_oss.py +211 -0
  77. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  78. liger_kernel/transformers/model/internvl.py +157 -0
  79. liger_kernel/transformers/model/llama.py +221 -41
  80. liger_kernel/transformers/model/llama4.py +121 -0
  81. liger_kernel/transformers/model/llava.py +344 -0
  82. liger_kernel/transformers/model/loss_utils.py +95 -0
  83. liger_kernel/transformers/model/mistral.py +145 -0
  84. liger_kernel/transformers/model/mixtral.py +293 -0
  85. liger_kernel/transformers/model/mllama.py +269 -0
  86. liger_kernel/transformers/model/olmo2.py +141 -0
  87. liger_kernel/transformers/model/olmo3.py +142 -0
  88. liger_kernel/transformers/model/output_classes.py +147 -0
  89. liger_kernel/transformers/model/paligemma.py +433 -0
  90. liger_kernel/transformers/model/phi3.py +120 -0
  91. liger_kernel/transformers/model/qwen2.py +259 -0
  92. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  93. liger_kernel/transformers/model/qwen2_vl.py +159 -0
  94. liger_kernel/transformers/model/qwen3.py +136 -0
  95. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  96. liger_kernel/transformers/model/qwen3_next.py +146 -0
  97. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  98. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  99. liger_kernel/transformers/model/smollm3.py +199 -0
  100. liger_kernel/transformers/model/smolvlm.py +158 -0
  101. liger_kernel/transformers/monkey_patch.py +2816 -21
  102. liger_kernel/transformers/multi_token_attention.py +64 -0
  103. liger_kernel/transformers/poly_norm.py +42 -0
  104. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  105. liger_kernel/transformers/rms_norm.py +75 -5
  106. liger_kernel/transformers/rope.py +47 -3
  107. liger_kernel/transformers/softmax.py +12 -0
  108. liger_kernel/transformers/sparsemax.py +16 -0
  109. liger_kernel/transformers/swiglu.py +62 -6
  110. liger_kernel/transformers/tiled_mlp.py +133 -0
  111. liger_kernel/transformers/trainer/__init__.py +4 -0
  112. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  113. liger_kernel/transformers/trainer_integration.py +2 -45
  114. liger_kernel/transformers/tvd.py +13 -0
  115. liger_kernel/triton/__init__.py +1 -3
  116. liger_kernel/triton/monkey_patch.py +1 -5
  117. liger_kernel/utils.py +96 -0
  118. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/METADATA +447 -0
  119. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/NOTICE +58 -0
  120. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
  121. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +1 -1
  122. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/METADATA +0 -21
  123. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/NOTICE +0 -4
  124. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/RECORD +0 -27
  125. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
  126. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
liger_kernel/ops/fused_linear_cross_entropy.py
@@ -1,7 +1,11 @@
  import torch
  import triton

- from liger_kernel.ops.cross_entropy import element_mul, liger_cross_entropy_kernel
+ from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel
+ from liger_kernel.ops.utils import amp_custom_bwd
+ from liger_kernel.ops.utils import amp_custom_fwd
+ from liger_kernel.ops.utils import element_mul_kernel
+ from liger_kernel.ops.utils import is_hip

  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
  # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -9,153 +13,363 @@ from liger_kernel.ops.cross_entropy import element_mul, liger_cross_entropy_kernel
  MAX_FUSED_SIZE = 65536 // 2


- class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
-     @staticmethod
-     def forward(ctx, _input, linear, target, ignore_index):
-         """
-         Fusing the last linear layer with cross-entropy loss
-         Reference: https://github.com/mgmalek/efficient_cross_entropy
+ def fused_linear_cross_entropy_forward(
+     _input,
+     weight,
+     target,
+     ce_weight=None,
+     bias=None,
+     ignore_index=-100,
+     lse_square_scale=0.0,
+     label_smoothing=0.0,
+     reduction="mean",
+     softcap=None,
+     return_z_loss=False,
+     accum_dtype=None,
+     use_token_scaling=False,
+     return_token_accuracy=False,
+ ):
+     assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+     assert isinstance(return_token_accuracy, bool), (
+         f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+     )
+     device = _input.device

-         Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding
-         the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can
-         compute the gradient at the forward pass. By doing so, we don't have to store the _input and target
-         for the backward pass.
+     input_requires_grad = _input.requires_grad

-         _input: (B*T, H) where B is batch size, T is sequence length, H is hidden dimension.
-         target: (B*T) where each value is in [0, V-1]
-         linear: linear projection matrix of shape V x H.
-         ignore_index: the index to ignore in the target
-         """
-         dtype = (
-             torch.get_autocast_gpu_dtype()
-             if torch.is_autocast_enabled()
-             else _input.dtype
+     # inputs have shape: BT x H
+     # materialized activations will have shape: BT x V
+     # the increase in memory = BT x V
+     # reduction can be achieved by partitioning the number of tokens BT into smaller chunks.
+     # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be:
+     # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor
+     # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048
+     BT, H = _input.shape
+     V = weight.shape[0]
+     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+
+     inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
+     chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))  # (BT + inc_factor - 1) // inc_factor
+     num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size
+
+     grad_input = torch.zeros_like(_input, device=device)
+
+     # we use fp32 for loss and gradients accumulator
+     if input_requires_grad:
+         if accum_dtype is None:
+             grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
+             grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
+         else:
+             grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
+             grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
+     else:
+         grad_weight = None
+         grad_bias = None
+
+     loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
+     z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+     token_accuracy_1d = torch.zeros(BT, dtype=torch.float32, device=device) if return_token_accuracy else None
+
+     # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
+     target_mask = target != ignore_index
+     total_n_non_ignore = target_mask.sum().item()
+     total_sum_non_ignore_ce_weight = total_n_non_ignore
+     ce_weight_sum = 0.0
+     if ce_weight is not None:
+         assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}"
+         assert torch.is_floating_point(ce_weight), (
+             f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}"
          )
-         device = _input.device
-
-         # inputs have shape: BT x H
-         # materialized activations will have shape: BT x V
-         # the increase in memory = BT x V
-         # reduction can be achieved by paritioning the number of tokens BT into smaller chunks.
-         # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be:
-         # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor
-         # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048
-         BT, H = _input.shape
-         V = linear.shape[0]
-         BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-
-         inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
-         chunk_size = triton.next_power_of_2(
-             triton.cdiv(BT, inc_factor)
-         )  # (BT + inc_factor - 1) // inc_factor
-         num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size
-
-         grad_linear = torch.zeros_like(linear, device=device)
-         grad_input = torch.zeros_like(_input, device=device)
-
-         # we use fp32 for loss accumulator
-         loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
-
-         total_n_non_ignore = (target != ignore_index).sum().item()
-
-         for chunk_id in range(num_chunks):
-             start_idx = chunk_id * chunk_size
-             end_idx = min((chunk_id + 1) * chunk_size, BT)
-             _input_chunk = _input[start_idx:end_idx]  # chunk_size x H
-
-             # when doing matmul, use the original precision
-             logits_chunk = _input_chunk @ linear.t()  # chunk_size x V
-             target_chunk = target[start_idx:end_idx]  # chunk_size,
-
-             n_rows = logits_chunk.shape[0]
-
-             # unreduced loss
-             loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
-             n_non_ignore = (target_chunk != ignore_index).sum().item()
-
-             # when doing CE, use the upcasted precision
-             logits_chunk = logits_chunk.float()
-
-             # ensure _input and target are contiguous
-             logits_chunk = logits_chunk.contiguous()
-             target_chunk = target_chunk.contiguous()
-
-             # Here we calculate the gradient of logits_chunk in place so we can save memory.
-             liger_cross_entropy_kernel[(n_rows,)](
-                 X_ptr=logits_chunk,
-                 X_stride=logits_chunk.stride(-2),
-                 Y_ptr=target_chunk,
-                 Y_stride=target_chunk.stride(-1),  # always 1
-                 loss_ptr=loss_1d_slice,
-                 loss_stride=loss_1d_slice.stride(-1),  # always 1
-                 n_cols=V,
-                 n_non_ignore=n_non_ignore,
-                 ignore_index=ignore_index,
-                 BLOCK_SIZE=BLOCK_SIZE,
-                 num_warps=32,
-             )
+         total_sum_non_ignore_ce_weight = (
+             torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item()
+         )
+         ce_weight_sum = ce_weight.sum().item()
+         if ce_weight.stride(-1) != 1:
+             ce_weight = ce_weight.contiguous()
+
+     for chunk_id in range(num_chunks):
+         start_idx = chunk_id * chunk_size
+         end_idx = min((chunk_id + 1) * chunk_size, BT)
+         _input_chunk = _input[start_idx:end_idx]  # chunk_size x H
+
+         # when doing matmul, use the original precision
+         logits_chunk = _input_chunk @ weight.t()  # chunk_size x V
+         if bias is not None:
+             logits_chunk = logits_chunk + bias
+
+         target_chunk = target[start_idx:end_idx]  # chunk_size,
+
+         n_rows = logits_chunk.shape[0]
+
+         # Compute predicted probabilities for token scaling if needed
+         if use_token_scaling:
+             # Compute softmax probabilities for scaling
+             # We need to compute this before the cross entropy kernel modifies logits_chunk
+             logits_for_softmax = logits_chunk.detach().clone()  # Detach to avoid gradient flow
+             if softcap is not None:
+                 logits_for_softmax = softcap * torch.tanh(logits_for_softmax / softcap)
+
+             # Compute softmax to get predicted probabilities
+             probs = torch.softmax(logits_for_softmax, dim=-1)
+
+             # Get predicted probabilities for token scaling, handling ignored targets
+             valid_target_mask = target_chunk != ignore_index
+             valid_targets = target_chunk[valid_target_mask]
+
+             if len(valid_targets) > 0:
+                 # Gather probabilities only for valid targets
+                 valid_probs = probs[valid_target_mask]
+                 pred_probs_valid = torch.gather(valid_probs, -1, valid_targets.unsqueeze(-1)).squeeze(-1)
+
+                 # Create full tensor with zeros for ignored targets
+                 pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
+                 pred_probs[valid_target_mask] = pred_probs_valid
+             else:
+                 # All targets are ignored
+                 pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
+
+             # Store the scaling factors
+             scaling_factors = pred_probs.detach()  # Detach to ensure no gradient flow
+
+         # unreduced loss
+         loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
+         z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
+         token_accuracy_1d_slice = token_accuracy_1d[start_idx:end_idx] if return_token_accuracy else None
+
+         # ensure _input and target are contiguous
+         logits_chunk = logits_chunk.contiguous()
+         target_chunk = target_chunk.contiguous()
+
+         # Here we calculate the gradient of logits_chunk in place so we can save memory.
+         liger_cross_entropy_kernel[(n_rows,)](
+             X_ptr=logits_chunk,
+             X_stride=logits_chunk.stride(-2),
+             Y_ptr=target_chunk,
+             Y_stride=target_chunk.stride(-1),  # always 1
+             weight_ptr=ce_weight,
+             loss_ptr=loss_1d_slice,
+             z_loss_ptr=z_loss_1d_slice,
+             loss_stride=loss_1d_slice.stride(-1),  # always 1
+             token_accuracy_ptr=token_accuracy_1d_slice,
+             token_accuracy_stride=token_accuracy_1d_slice.stride(-1)
+             if return_token_accuracy
+             else 0,  # always 1 if accuracy is enabled
+             n_cols=V,
+             n_non_ignore=total_n_non_ignore,
+             sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
+             weight_sum=ce_weight_sum,
+             ignore_index=ignore_index,
+             lse_square_scale=lse_square_scale,
+             label_smoothing=label_smoothing,
+             reduction=reduction,
+             softcap=softcap,
+             RETURN_Z_LOSS=return_z_loss,
+             RETURN_TOKEN_ACCURACY=return_token_accuracy,
+             HAS_WEIGHT=True if ce_weight is not None else False,
+             HAS_SOFTCAPPING=True if softcap is not None else False,
+             HAS_GRADIENTS=input_requires_grad,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=32 if not is_hip() else 16,
+         )
+
+         # Apply token scaling if requested
+         if use_token_scaling:
+             loss_1d_slice = loss_1d_slice * scaling_factors
+             if return_z_loss:
+                 z_loss_1d_slice = z_loss_1d_slice * scaling_factors
+
+         loss_1d[start_idx:end_idx] = loss_1d_slice
+         if return_z_loss:
+             z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
+         if return_token_accuracy:
+             token_accuracy_1d[start_idx:end_idx] = token_accuracy_1d_slice
+         grad_logits_chunk = logits_chunk  # chunk_size x V
+
+         # Apply token scaling to gradients if requested
+         if use_token_scaling:
+             # Expand scaling factors to match gradient dimensions
+             scaling_factors_expanded = scaling_factors.unsqueeze(-1)  # chunk_size x 1
+             grad_logits_chunk = grad_logits_chunk * scaling_factors_expanded

-             # gradient of logits_chunk is computed inplace by the above triton kernel.
-             # Following HuggingFace model source code, we do the forward and backward
-             # w.r.t. logits in fp32 for numerical stability especially as the num classes (vocab size) os huge.
-             # (reference: https://github.com/huggingface/transformers/blob/v4.42.4/src/transformers/models/llama/modeling_llama.py#L1194)
-             # Propagating to lm_head's backward, we'll switch back to the original dtype.
-             logits_chunk = logits_chunk.to(dtype)
-
-             # gradient of logits_chunk is computed inplace by the above triton kernel and is of shape: chunk_size x V
-             # thus grad_input[start_idx: end_idx] should be of shape: chunk_size x H
-             # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only
-             # on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens.
-             # Thus, we need an additional scaling factor of (n_non_ignore/total_n_non_ignore) to scale the gradients.
-             grad_logits_chunk = logits_chunk * (n_non_ignore / total_n_non_ignore)
-             grad_input[start_idx:end_idx] = grad_logits_chunk @ linear
-
-             torch.addmm(
-                 input=grad_linear,
-                 mat1=logits_chunk.t(),
-                 mat2=_input_chunk,
-                 out=grad_linear,
-                 alpha=n_non_ignore / total_n_non_ignore,
-                 beta=1.0,
+         if input_requires_grad:
+             grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
+
+         if grad_weight is not None and input_requires_grad:
+             grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()
+
+         if bias is not None and input_requires_grad:
+             torch.add(
+                 input=grad_bias,
+                 other=grad_logits_chunk.sum(dim=0),
+                 out=grad_bias,
+                 alpha=1.0,
              )

-         loss = torch.sum(loss_1d) / total_n_non_ignore
+     # Need extra calculations for backward if reduction=='none'. Not supporting reduction='none' now.
+     # if reduction == "none":
+     # loss = loss_1d
+     # z_loss = z_loss_1d if return_z_loss else None

-         # downcast to dtype and store for backward
-         ctx.save_for_backward(grad_input.detach(), grad_linear.detach())
-         return loss
+     if reduction == "none":
+         # Return per-token losses
+         loss = loss_1d
+         z_loss = z_loss_1d if return_z_loss else None
+         token_accuracy = token_accuracy_1d if return_token_accuracy else None
+     else:
+         loss = torch.sum(loss_1d)
+         z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+         # For accuracy, we compute the mean across all non-ignored tokens
+         token_accuracy = torch.sum(token_accuracy_1d) / total_n_non_ignore if return_token_accuracy else None

-     @staticmethod
-     def backward(ctx, grad_output):
-         (grad_input, grad_linear) = ctx.saved_tensors
-         # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
-         if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
-             # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
-             # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
-             BT, H = grad_input.shape
-             n_rows = BT
-             BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))
-
-             element_mul[(n_rows,)](
-                 grad_input,
-                 grad_input.stride(-2),
+     # Cast back to original dtype
+     grad_weight = grad_weight.to(weight.dtype) if grad_weight is not None else None
+     grad_bias = grad_bias.to(bias.dtype) if grad_bias is not None else None
+
+     return loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias
+
+
+ def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
+     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
+     if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+         # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
+         # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
+         BT, H = grad_input.shape
+         n_rows = BT
+         BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))
+
+         element_mul_kernel[(n_rows,)](
+             grad_input,
+             grad_input.stride(-2),
+             grad_output,
+             H,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=32 if not is_hip() else 16,
+         )
+
+         # handle grad_weight
+         if grad_weight is not None:
+             V, H = grad_weight.shape
+             n_rows = V
+
+             element_mul_kernel[(n_rows,)](
+                 grad_weight,
+                 grad_weight.stride(-2),
                  grad_output,
                  H,
                  BLOCK_SIZE=BLOCK_SIZE,
-                 num_warps=32,
+                 num_warps=32 if not is_hip() else 16,
              )

-             # handle grad_linear
-             V, H = grad_linear.shape
+         if grad_bias is not None:
+             V = grad_bias.shape[0]
              n_rows = V

-             element_mul[(n_rows,)](
-                 grad_linear,
-                 grad_linear.stride(-2),
+             element_mul_kernel[(n_rows,)](
+                 grad_bias,
+                 grad_bias.stride(-1),
                  grad_output,
-                 H,
+                 1,
                  BLOCK_SIZE=BLOCK_SIZE,
-                 num_warps=32,
+                 num_warps=32 if not is_hip() else 16,
              )
+     return grad_input, grad_weight, grad_bias

-         return (grad_input, grad_linear, None, None)
+
+ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
+     @staticmethod
+     @amp_custom_fwd
+     def forward(
+         ctx,
+         _input,
+         weight,
+         target,
+         bias=None,
+         ce_weight=None,
+         ignore_index=-100,
+         lse_square_scale=0.0,
+         label_smoothing=0.0,
+         reduction="mean",
+         softcap=None,
+         return_z_loss: bool = False,
+         accum_dtype=None,
+         use_token_scaling: bool = False,
+         return_token_accuracy: bool = False,
+     ):
+         """
+         Fusing the last linear layer with cross-entropy loss
+         Reference: https://github.com/mgmalek/efficient_cross_entropy
+
+         Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding
+         the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can
+         compute the gradient at the forward pass. By doing so, we don't have to store the _input and target
+         for the backward pass.
+
+         _input: (B*T, H) where B is batch size, T is sequence length, H is hidden dimension.
+         target: (B*T) where each value is in [0, V-1]
+         weight: (V, H) where V is the number of classes
+         bias: (V) where V is the number of classes
+         ce_weight: a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
+         ignore_index: the index to ignore in the target
+         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
+         reduction: reduction to apply
+         accum_dtype (torch.dtype): the dtype of intermediate result buffers for weight and bias gradient accumulations.
+             Recommended to set `accum_dtype` to higher precision, e.g. `torch.float32`, if the training is unstable with original dtype. Default: `None`, performing accumulations in original dtype
+         use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
+             When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
+             Default: False.
+         return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
+         """
+
+         loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
+             _input=_input,
+             weight=weight,
+             target=target,
+             bias=bias,
+             ce_weight=ce_weight,
+             ignore_index=ignore_index,
+             lse_square_scale=lse_square_scale,
+             label_smoothing=label_smoothing,
+             reduction=reduction,
+             softcap=softcap,
+             return_z_loss=return_z_loss,
+             accum_dtype=accum_dtype,
+             use_token_scaling=use_token_scaling,
+             return_token_accuracy=return_token_accuracy,
+         )
+         # downcast to dtype and store for backward
+         ctx.save_for_backward(
+             grad_input.detach(),
+             grad_weight.detach() if grad_weight is not None else None,
+             grad_bias.detach() if bias is not None else None,
+         )
+         ctx.return_z_loss = return_z_loss
+         ctx.return_token_accuracy = return_token_accuracy
+         return loss, z_loss, token_accuracy
+
+     @staticmethod
+     @amp_custom_bwd
+     def backward(ctx, grad_output, grad_output2, grad_output3):
+         if ctx.return_z_loss:
+             del grad_output2  # z_loss is only for logging
+         if ctx.return_token_accuracy:
+             del grad_output3  # token_accuracy is only for metrics
+         (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
+         grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
+             grad_output, grad_input, grad_weight, grad_bias
+         )
+         return (
+             grad_input,
+             grad_weight,
+             None,
+             grad_bias,
+             None,
+             None,
+             None,
+             None,
+             None,
+             None,
+             None,
+             None,
+             None,  # use_token_scaling
+             None,  # return_token_accuracy
+         )
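
The chunking scheme documented in `fused_linear_cross_entropy_forward` above keeps each materialized logits chunk roughly as large as the BT x H input by splitting the BT token dimension. Below is a minimal sketch, using only standard Python, that reproduces the sizing arithmetic from the comments in the diff (BT = 4096*4, V = 32000, H = 4096), followed by a commented, hedged invocation of the new autograd function. The invocation assumes a CUDA device with Triton and this liger-kernel build installed; the tensor shapes and names are illustrative only.

import math

# Chunk sizing mirroring the comments in fused_linear_cross_entropy_forward:
# keep chunk_size x V comparable to BT x H by splitting the BT token dimension.
BT, H, V = 4096 * 4, 4096, 32000
inc_factor = -(-V // H)                                        # ceil(V / H), same as triton.cdiv(V, H) -> 8
chunk_size = 1 << math.ceil(math.log2(-(-BT // inc_factor)))   # next power of 2 of ceil(BT / inc_factor) -> 2048
num_chunks = -(-BT // chunk_size)                              # 8 chunks, each materializing 2048 x 32000 logits
print(inc_factor, chunk_size, num_chunks)                      # 8 2048 8

# Hedged usage sketch of the new interface (requires a GPU, Triton, and liger-kernel):
# import torch
# from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
#
# x = torch.randn(BT, H, device="cuda", dtype=torch.bfloat16, requires_grad=True)
# lm_head_weight = torch.randn(V, H, device="cuda", dtype=torch.bfloat16, requires_grad=True)
# target = torch.randint(0, V, (BT,), device="cuda")
# # positional args follow the forward signature shown above; remaining options keep their defaults
# loss, z_loss, token_accuracy = LigerFusedLinearCrossEntropyFunction.apply(x, lm_head_weight, target)
# loss.backward()

Because the gradients of `_input`, `weight`, and `bias` are already produced during the forward pass, the backward pass reduces to the `element_mul_kernel` rescaling by `grad_output` shown in `fused_linear_cross_entropy_backward`, and that rescaling is skipped entirely when `grad_output` is 1.0.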