liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of liger-kernel-nightly might be problematic.
- liger_kernel/__init__.py +0 -0
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +8 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/cpo_loss.py +157 -0
- liger_kernel/chunked_loss/dpo_loss.py +229 -0
- liger_kernel/chunked_loss/functional.py +17 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
- liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
- liger_kernel/chunked_loss/grpo_loss.py +304 -0
- liger_kernel/chunked_loss/jsd_loss.py +200 -0
- liger_kernel/chunked_loss/kto_loss.py +210 -0
- liger_kernel/chunked_loss/orpo_loss.py +144 -0
- liger_kernel/chunked_loss/simpo_loss.py +165 -0
- liger_kernel/env_report.py +21 -4
- liger_kernel/ops/cross_entropy.py +235 -84
- liger_kernel/ops/dyt.py +157 -0
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_add_rms_norm.py +412 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
- liger_kernel/ops/fused_linear_jsd.py +17 -34
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +7 -18
- liger_kernel/ops/group_norm.py +305 -0
- liger_kernel/ops/grpo_loss.py +310 -0
- liger_kernel/ops/jsd.py +46 -21
- liger_kernel/ops/kl_div.py +23 -19
- liger_kernel/ops/layer_norm.py +150 -86
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +386 -0
- liger_kernel/ops/qwen2vl_mrope.py +222 -0
- liger_kernel/ops/rms_norm.py +314 -84
- liger_kernel/ops/rope.py +32 -34
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +5 -9
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +8 -4
- liger_kernel/transformers/__init__.py +199 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +33 -20
- liger_kernel/transformers/dyt.py +22 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +291 -13
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
- liger_kernel/transformers/fused_linear_jsd.py +1 -4
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +50 -0
- liger_kernel/transformers/grpo_loss.py +98 -0
- liger_kernel/transformers/jsd.py +2 -7
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +77 -77
- liger_kernel/transformers/model/gemma2.py +283 -0
- liger_kernel/transformers/model/gemma3.py +331 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +128 -79
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +344 -0
- liger_kernel/transformers/model/loss_utils.py +95 -0
- liger_kernel/transformers/model/mistral.py +68 -64
- liger_kernel/transformers/model/mixtral.py +75 -91
- liger_kernel/transformers/model/mllama.py +63 -68
- liger_kernel/transformers/model/olmo2.py +141 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +432 -0
- liger_kernel/transformers/model/phi3.py +59 -213
- liger_kernel/transformers/model/qwen2.py +75 -72
- liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
- liger_kernel/transformers/model/qwen2_vl.py +78 -98
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2106 -289
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel/transformers/rms_norm.py +57 -6
- liger_kernel/transformers/rope.py +45 -2
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +23 -8
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/__init__.py +4 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- liger_kernel/utils.py +71 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
- liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
- liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
liger_kernel/ops/fused_linear_cross_entropy.py

@@ -2,12 +2,10 @@ import torch
 import triton
 
 from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel
-from liger_kernel.ops.utils import (
-    amp_custom_bwd,
-    amp_custom_fwd,
-    element_mul_kernel,
-    is_hip,
-)
+from liger_kernel.ops.utils import amp_custom_bwd
+from liger_kernel.ops.utils import amp_custom_fwd
+from liger_kernel.ops.utils import element_mul_kernel
+from liger_kernel.ops.utils import is_hip
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -19,15 +17,26 @@ def fused_linear_cross_entropy_forward(
     _input,
     weight,
     target,
+    ce_weight=None,
     bias=None,
     ignore_index=-100,
     lse_square_scale=0.0,
     label_smoothing=0.0,
     reduction="mean",
+    softcap=None,
+    return_z_loss=False,
+    accum_dtype=None,
+    use_token_scaling=False,
+    return_token_accuracy=False,
 ):
-    dtype = _input.dtype
+    assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_token_accuracy, bool), (
+        f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+    )
     device = _input.device
 
+    input_requires_grad = _input.requires_grad
+
     # inputs have shape: BT x H
     # materialized activations will have shape: BT x V
     # the increase in memory = BT x V
@@ -40,21 +49,43 @@ def fused_linear_cross_entropy_forward(
     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
 
     inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
-    chunk_size = triton.next_power_of_2(
-        triton.cdiv(BT, inc_factor)
-    )  # (BT + inc_factor - 1) // inc_factor
+    chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))  # (BT + inc_factor - 1) // inc_factor
     num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size
 
-    grad_weight = (
-        torch.zeros_like(weight, device=device) if weight.requires_grad else None
-    )
     grad_input = torch.zeros_like(_input, device=device)
-    grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
-    # we use fp32 for loss accumulator
-    loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
 
-    # NOTE: skip .item() here to avoid CUDA synchronization
-    total_n_non_ignore = (target != ignore_index).sum()
+    # we use fp32 for loss and gradients accumulator
+    if input_requires_grad:
+        if accum_dtype is None:
+            grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
+            grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
+        else:
+            grad_weight = torch.zeros_like(weight, dtype=accum_dtype, device=device) if weight.requires_grad else None
+            grad_bias = torch.zeros_like(bias, dtype=accum_dtype, device=device) if bias is not None else None
+    else:
+        grad_weight = None
+        grad_bias = None
+
+    loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
+    z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+    token_accuracy_1d = torch.zeros(BT, dtype=torch.float32, device=device) if return_token_accuracy else None
+
+    # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
+    target_mask = target != ignore_index
+    total_n_non_ignore = target_mask.sum().item()
+    total_sum_non_ignore_ce_weight = total_n_non_ignore
+    ce_weight_sum = 0.0
+    if ce_weight is not None:
+        assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}"
+        assert torch.is_floating_point(ce_weight), (
+            f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}"
+        )
+        total_sum_non_ignore_ce_weight = (
+            torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item()
+        )
+        ce_weight_sum = ce_weight.sum().item()
+        if ce_weight.stride(-1) != 1:
+            ce_weight = ce_weight.contiguous()
 
     for chunk_id in range(num_chunks):
         start_idx = chunk_id * chunk_size
@@ -65,16 +96,45 @@ def fused_linear_cross_entropy_forward(
         logits_chunk = _input_chunk @ weight.t()  # chunk_size x V
         if bias is not None:
             logits_chunk = logits_chunk + bias
+
         target_chunk = target[start_idx:end_idx]  # chunk_size,
 
         n_rows = logits_chunk.shape[0]
 
+        # Compute predicted probabilities for token scaling if needed
+        if use_token_scaling:
+            # Compute softmax probabilities for scaling
+            # We need to compute this before the cross entropy kernel modifies logits_chunk
+            logits_for_softmax = logits_chunk.detach().clone()  # Detach to avoid gradient flow
+            if softcap is not None:
+                logits_for_softmax = softcap * torch.tanh(logits_for_softmax / softcap)
+
+            # Compute softmax to get predicted probabilities
+            probs = torch.softmax(logits_for_softmax, dim=-1)
+
+            # Get predicted probabilities for token scaling, handling ignored targets
+            valid_target_mask = target_chunk != ignore_index
+            valid_targets = target_chunk[valid_target_mask]
+
+            if len(valid_targets) > 0:
+                # Gather probabilities only for valid targets
+                valid_probs = probs[valid_target_mask]
+                pred_probs_valid = torch.gather(valid_probs, -1, valid_targets.unsqueeze(-1)).squeeze(-1)
+
+                # Create full tensor with zeros for ignored targets
+                pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
+                pred_probs[valid_target_mask] = pred_probs_valid
+            else:
+                # All targets are ignored
+                pred_probs = torch.zeros_like(target_chunk, dtype=probs.dtype, device=probs.device)
+
+            # Store the scaling factors
+            scaling_factors = pred_probs.detach()  # Detach to ensure no gradient flow
+
         # unreduced loss
         loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
-        n_non_ignore = (target_chunk != ignore_index).sum().item()
-
-        # when doing CE, use the upcasted precision
-        logits_chunk = logits_chunk.float()
+        z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
+        token_accuracy_1d_slice = token_accuracy_1d[start_idx:end_idx] if return_token_accuracy else None
 
         # ensure _input and target are contiguous
         logits_chunk = logits_chunk.contiguous()
@@ -86,70 +146,91 @@ def fused_linear_cross_entropy_forward(
             X_stride=logits_chunk.stride(-2),
             Y_ptr=target_chunk,
             Y_stride=target_chunk.stride(-1),  # always 1
+            weight_ptr=ce_weight,
             loss_ptr=loss_1d_slice,
-            z_loss_ptr=loss_1d_slice,  # dummy ptr, not used
+            z_loss_ptr=z_loss_1d_slice,
             loss_stride=loss_1d_slice.stride(-1),  # always 1
+            token_accuracy_ptr=token_accuracy_1d_slice,
+            token_accuracy_stride=token_accuracy_1d_slice.stride(-1)
+            if return_token_accuracy
+            else 0,  # always 1 if accuracy is enabled
             n_cols=V,
-            n_non_ignore=n_non_ignore,
+            n_non_ignore=total_n_non_ignore,
+            sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
+            weight_sum=ce_weight_sum,
             ignore_index=ignore_index,
             lse_square_scale=lse_square_scale,
            label_smoothing=label_smoothing,
             reduction=reduction,
-            RETURN_Z_LOSS=0,  # False
+            softcap=softcap,
+            RETURN_Z_LOSS=return_z_loss,
+            RETURN_TOKEN_ACCURACY=return_token_accuracy,
+            HAS_WEIGHT=True if ce_weight is not None else False,
+            HAS_SOFTCAPPING=True if softcap is not None else False,
+            HAS_GRADIENTS=input_requires_grad,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=32 if not is_hip() else 16,
         )
 
-        #
-
-
-
-
-        logits_chunk = logits_chunk.to(dtype)
-
-        # gradient of logits_chunk is computed in-place by the above triton kernel and is of shape: chunk_size x V
-        # thus grad_input[start_idx: end_idx] should be of shape: chunk_size x H
-        # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only
-        # on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens.
-        # Thus, we need an additional scaling factor of (n_non_ignore/total_n_non_ignore) to scale the gradients.
-
-        if reduction == "mean":
-            alpha = n_non_ignore / total_n_non_ignore if total_n_non_ignore > 0 else 0.0
-        else:
-            alpha = 1.0
+        # Apply token scaling if requested
+        if use_token_scaling:
+            loss_1d_slice = loss_1d_slice * scaling_factors
+            if return_z_loss:
+                z_loss_1d_slice = z_loss_1d_slice * scaling_factors
 
-        loss_1d[start_idx:end_idx] = loss_1d_slice
-        grad_logits_chunk = logits_chunk  # chunk_size x V
+        loss_1d[start_idx:end_idx] = loss_1d_slice
+        if return_z_loss:
+            z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
+        if return_token_accuracy:
+            token_accuracy_1d[start_idx:end_idx] = token_accuracy_1d_slice
+        grad_logits_chunk = logits_chunk  # chunk_size x V
 
-        grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
+        # Apply token scaling to gradients if requested
+        if use_token_scaling:
+            # Expand scaling factors to match gradient dimensions
+            scaling_factors_expanded = scaling_factors.unsqueeze(-1)  # chunk_size x 1
+            grad_logits_chunk = grad_logits_chunk * scaling_factors_expanded
 
-        if grad_weight is not None:
-            torch.addmm(
-                input=grad_weight,
-                mat1=logits_chunk.t(),
-                mat2=_input_chunk,
-                out=grad_weight,
-                alpha=alpha,
-                beta=1.0,
-            )
+        if input_requires_grad:
+            grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
 
-        if bias is not None:
+        if grad_weight is not None and input_requires_grad:
+            grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()
+
+        if bias is not None and input_requires_grad:
             torch.add(
                 input=grad_bias,
-                other=logits_chunk.sum(dim=0),
+                other=grad_logits_chunk.sum(dim=0),
                 out=grad_bias,
-                alpha=alpha,
+                alpha=1.0,
             )
 
-    loss = torch.sum(loss_1d)
-    return loss, grad_input, grad_weight, grad_bias
+    # Need extra calculations for backward if reduction=='none'. Not supporting reduction='none' now.
+    # if reduction == "none":
+    #     loss = loss_1d
+    #     z_loss = z_loss_1d if return_z_loss else None
 
+    if reduction == "none":
+        # Return per-token losses
+        loss = loss_1d
+        z_loss = z_loss_1d if return_z_loss else None
+        token_accuracy = token_accuracy_1d if return_token_accuracy else None
+    else:
+        loss = torch.sum(loss_1d)
+        z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+        # For accuracy, we compute the mean across all non-ignored tokens
+        token_accuracy = torch.sum(token_accuracy_1d) / total_n_non_ignore if return_token_accuracy else None
 
-def fused_linear_cross_entropy_backward(
-    grad_output, grad_input, grad_weight, grad_bias
-):
+    # Cast back to original dtype
+    grad_weight = grad_weight.to(weight.dtype) if grad_weight is not None else None
+    grad_bias = grad_bias.to(bias.dtype) if grad_bias is not None else None
+
+    return loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias
+
+
+def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
-    if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
+    if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
         # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
         # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
         BT, H = grad_input.shape
@@ -203,10 +284,16 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         weight,
         target,
         bias=None,
+        ce_weight=None,
         ignore_index=-100,
         lse_square_scale=0.0,
         label_smoothing=0.0,
         reduction="mean",
+        softcap=None,
+        return_z_loss: bool = False,
+        accum_dtype=None,
+        use_token_scaling: bool = False,
+        return_token_accuracy: bool = False,
     ):
         """
         Fusing the last linear layer with cross-entropy loss
@@ -221,19 +308,33 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         target: (B*T) where each value is in [0, V-1]
         weight: (V, H) where V is the number of classes
         bias: (V) where V is the number of classes
+        ce_weight: a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
         ignore_index: the index to ignore in the target
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         reduction: reduction to apply
+        accum_dtype (torch.dtype): the dtype of intermediate result buffers for weight and bias gradient accumulations.
+            Recommended to set `accum_dtype` to higher precision, e.g. `torch.float32`, if the training is unstable with original dtype. Default: `None`, performing accumulations in original dtype
+        use_token_scaling (bool): whether to scale each token's loss by its predicted probability (detached).
+            When True, each token's loss is multiplied by the model's predicted probability for that token's true class.
+            Default: False.
+        return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
         """
-        loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
-            _input,
-            weight,
-            target,
-            bias,
-            ignore_index,
-            lse_square_scale,
-            label_smoothing,
-            reduction,
+
+        loss, z_loss, token_accuracy, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            ce_weight=ce_weight,
+            ignore_index=ignore_index,
+            lse_square_scale=lse_square_scale,
+            label_smoothing=label_smoothing,
+            reduction=reduction,
+            softcap=softcap,
+            return_z_loss=return_z_loss,
+            accum_dtype=accum_dtype,
+            use_token_scaling=use_token_scaling,
+            return_token_accuracy=return_token_accuracy,
         )
         # downcast to dtype and store for backward
         ctx.save_for_backward(
@@ -241,13 +342,34 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             grad_weight.detach() if grad_weight is not None else None,
             grad_bias.detach() if bias is not None else None,
         )
-        return loss
+        ctx.return_z_loss = return_z_loss
+        ctx.return_token_accuracy = return_token_accuracy
+        return loss, z_loss, token_accuracy
 
     @staticmethod
     @amp_custom_bwd
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output, grad_output2, grad_output3):
+        if ctx.return_z_loss:
+            del grad_output2  # z_loss is only for logging
+        if ctx.return_token_accuracy:
+            del grad_output3  # token_accuracy is only for metrics
         (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
         grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
             grad_output, grad_input, grad_weight, grad_bias
         )
-        return (
+        return (
+            grad_input,
+            grad_weight,
+            None,
+            grad_bias,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,  # use_token_scaling
+            None,  # return_token_accuracy
+        )
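The reworked forward in `fused_linear_cross_entropy.py` now returns a `(loss, z_loss, token_accuracy)` triple and exposes several new arguments (`ce_weight`, `softcap`, `return_z_loss`, `accum_dtype`, `use_token_scaling`, `return_token_accuracy`). The snippet below is only a minimal sketch of how the updated autograd `Function` might be driven directly; it assumes a CUDA device with Triton available, the tensor shapes and hyperparameter values are purely illustrative, and the positional argument order simply mirrors the `forward` signature shown in the hunks above.

```python
import torch

from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction

# Illustrative sizes: BT = B*T flattened tokens, H hidden size, V vocab size.
BT, H, V = 4096, 2048, 32000
device, dtype = "cuda", torch.bfloat16

_input = torch.randn(BT, H, device=device, dtype=dtype, requires_grad=True)
weight = torch.randn(V, H, device=device, dtype=dtype, requires_grad=True)
target = torch.randint(0, V, (BT,), device=device)

# apply() takes the arguments positionally, in the order of the forward()
# signature reconstructed above.
loss, z_loss, token_accuracy = LigerFusedLinearCrossEntropyFunction.apply(
    _input,
    weight,
    target,
    None,           # bias
    None,           # ce_weight: optional per-class rescaling weights of size V
    -100,           # ignore_index
    0.0,            # lse_square_scale
    0.0,            # label_smoothing
    "mean",         # reduction
    None,           # softcap
    True,           # return_z_loss
    torch.float32,  # accum_dtype for grad_weight / grad_bias accumulation
    False,          # use_token_scaling
    True,           # return_token_accuracy
)
loss.backward()
print(loss.item(), z_loss.item(), token_accuracy.item())
```

Passing `accum_dtype=torch.float32` follows the docstring's recommendation to accumulate `grad_weight`/`grad_bias` in higher precision when bf16 training is unstable; per the code comments in `backward`, `z_loss` and `token_accuracy` are only for logging and metrics, so only `loss` needs to participate in the backward pass.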
liger_kernel/ops/fused_linear_jsd.py

@@ -4,17 +4,16 @@ import torch
 import triton
 
 from liger_kernel.ops.jsd import _jsd_kernel
-from liger_kernel.ops.utils import (
-    amp_custom_bwd,
-    amp_custom_fwd,
-    element_mul_kernel,
-    is_hip,
-)
+from liger_kernel.ops.utils import amp_custom_bwd
+from liger_kernel.ops.utils import amp_custom_fwd
+from liger_kernel.ops.utils import element_mul_kernel
+from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2
 
 
 def fused_linear_jsd_forward(
@@ -43,16 +42,10 @@ def fused_linear_jsd_forward(
     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
 
     inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
-    chunk_size = triton.next_power_of_2(
-        triton.cdiv(BT, inc_factor)
-    )  # (BT + inc_factor - 1) // inc_factor
+    chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))  # (BT + inc_factor - 1) // inc_factor
     num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size
 
-    grad_weight = (
-        torch.zeros_like(student_weight, device=device)
-        if student_weight.requires_grad
-        else None
-    )
+    grad_weight = torch.zeros_like(student_weight, device=device) if student_weight.requires_grad else None
     grad_input = torch.zeros_like(student_input)
     # we use fp32 for loss accumulator
     loss_1d = torch.zeros((BT, V), dtype=torch.float32, device=device)
@@ -73,12 +66,8 @@ def fused_linear_jsd_forward(
         # shape: chunk_size x V
         # For anything starting from logits to the final JSD loss, we do computation
         # in FP32 to avoid losing numerical stability.
-        student_logits_chunk = (student_input_chunk @ student_weight.t()).to(
-            torch.float32
-        )
-        teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(
-            torch.float32
-        )
+        student_logits_chunk = (student_input_chunk @ student_weight.t()).to(torch.float32)
+        teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(torch.float32)
         chunk_n_rows = student_logits_chunk.shape[0]
 
         # unreduced loss
@@ -104,9 +93,7 @@ def fused_linear_jsd_forward(
             dX_ptr=student_prob_chunk,
             dX_stride=student_prob_chunk.stride(-2),
             label_ptr=(
-                shift_labels[start_idx:end_idx]
-                if has_label
-                else torch.empty(1, device=device)
+                shift_labels[start_idx:end_idx] if has_label else torch.empty(1, device=device)
             ),  # dummy ptr if no label
             beta=jsd_beta,
             n_non_ignore=n_non_ignore,
@@ -121,9 +108,7 @@ def fused_linear_jsd_forward(
         student_logits_chunk = (
             student_prob_chunk
             - torch.softmax(student_logits_chunk, dim=-1)
-            * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(
-                student_prob_chunk.shape
-            )
+            * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(student_prob_chunk.shape)
         ) / temperature
         # now we traverse back to grad w.r.t. input to `lm_head` and grad
         # w.r.t. `lm_head` which should be computed in original dtype
@@ -202,7 +187,7 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
         teacher_input (torch.tensor): input of the last projection layer in teacher model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension.
         teacher_weight (torch.tensor): the last projection layer in teacher model, with shape (V, H), where V is vocab size
         shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
-        jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
+        jsd_beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
         ignore_index (int): the index to ignore. Default: -100
         temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0`
 
@@ -211,9 +196,9 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
         """
         has_label = False
         if shift_labels is not None:
-            assert shift_labels.shape == (
-                teacher_input.shape[0],
-            ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+            assert shift_labels.shape == (teacher_input.shape[0],), (
+                f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+            )
             shift_labels = shift_labels.contiguous()
             has_label = True
 
@@ -239,7 +224,5 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
     @amp_custom_bwd
     def backward(ctx, grad_output):
         (grad_input, grad_weight) = ctx.saved_tensors
-        grad_input, grad_weight = fused_linear_jsd_backward(
-            grad_output, grad_input, grad_weight
-        )
+        grad_input, grad_weight = fused_linear_jsd_backward(grad_output, grad_input, grad_weight)
         return (grad_input, grad_weight, None, None, None, None, None, None)