liger-kernel 0.5.8__py3-none-any.whl → 0.5.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/dpo_loss.py +8 -1
- liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
- liger_kernel/chunked_loss/jsd_loss.py +2 -2
- liger_kernel/ops/cross_entropy.py +4 -1
- liger_kernel/ops/dyt.py +113 -179
- liger_kernel/ops/fused_linear_cross_entropy.py +4 -3
- liger_kernel/ops/grpo_loss.py +310 -0
- liger_kernel/ops/sparsemax.py +167 -0
- liger_kernel/transformers/__init__.py +11 -0
- liger_kernel/transformers/dyt.py +5 -3
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +8 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +1 -2
- liger_kernel/transformers/grpo_loss.py +98 -0
- liger_kernel/transformers/model/gemma.py +8 -12
- liger_kernel/transformers/model/gemma2.py +8 -10
- liger_kernel/transformers/model/gemma3.py +3 -9
- liger_kernel/transformers/model/glm4.py +119 -0
- liger_kernel/transformers/model/llama.py +64 -15
- liger_kernel/transformers/model/llava.py +0 -8
- liger_kernel/transformers/model/mistral.py +8 -10
- liger_kernel/transformers/model/mixtral.py +8 -12
- liger_kernel/transformers/model/mllama.py +8 -11
- liger_kernel/transformers/model/olmo2.py +8 -10
- liger_kernel/transformers/model/paligemma.py +0 -8
- liger_kernel/transformers/model/phi3.py +8 -12
- liger_kernel/transformers/model/qwen2.py +8 -12
- liger_kernel/transformers/model/qwen2_5_vl.py +3 -7
- liger_kernel/transformers/model/qwen2_vl.py +3 -7
- liger_kernel/transformers/model/qwen3.py +112 -0
- liger_kernel/transformers/model/qwen3_moe.py +128 -0
- liger_kernel/transformers/monkey_patch.py +243 -13
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +21 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
- liger_kernel/utils.py +11 -0
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/METADATA +36 -20
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/RECORD +42 -34
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/WHEEL +1 -1
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.5.8.dist-info → liger_kernel-0.5.10.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/dpo_loss.py
CHANGED

@@ -68,6 +68,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         compute_nll_loss=False,
         compiled=True,
         use_ref_model=True,
+        average_log_prob=False,
         chunk_size=1,
     ):
         """
@@ -85,6 +86,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
             compute_nll_loss (bool): Whether to compute the NLL loss
             compiled (bool): Whether to use torch compile
             use_ref_model (bool): Whether to use a reference model
+            average_log_prob (bool): Whether to average the log probability per non-masked token
             chunk_size (int): Size of chunks for processing.
         Returns:
             torch.Tensor: Computed loss
@@ -104,13 +106,14 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
             ref_input=ref_input,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
+            average_log_prob=average_log_prob,
             chunk_size=chunk_size,
         )
 
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None, None, None, None
 
 
 class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -125,6 +128,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         compute_nll_loss: bool = False,
         compiled: bool = True,
         use_ref_model: bool = True,
+        average_log_prob: bool = False,
         chunk_size: int = 1,
     ):
         """
@@ -134,6 +138,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
             compute_nll_loss (bool): Whether to compute the NLL loss.
             compiled (bool): Whether to use the torch compiled kernel.
             use_ref_model (bool): Whether to use a reference model for the DPO loss.
+            average_log_prob (bool): Whether to average the log probability per non-masked token.
             chunk_size (int): Size of chunks for processing.
         """
         super().__init__()
@@ -142,6 +147,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
         self.use_ref_model = use_ref_model
+        self.average_log_prob = average_log_prob
         self.chunk_size = chunk_size
 
     def forward(
@@ -167,5 +173,6 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
             self.compute_nll_loss,
             self.compiled,
             self.use_ref_model,
+            self.average_log_prob,
             self.chunk_size,
         )
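The new `average_log_prob` flag only changes how per-token log-probabilities are aggregated into one value per sequence before the preference loss is applied (and `backward` returns one more `None` to match the extra forward argument). A standalone sketch of the distinction, with illustrative names and shapes that are not taken from the package:

```python
import torch


def sequence_log_prob(per_token_logps: torch.Tensor, mask: torch.Tensor, average: bool) -> torch.Tensor:
    """Aggregate per-token log-probs into one value per sequence.

    average=False sums over non-masked tokens (the previous behaviour);
    average=True divides that sum by the number of non-masked tokens,
    which is what the new `average_log_prob` flag toggles.
    """
    summed = (per_token_logps * mask).sum(dim=-1)
    if average:
        return summed / mask.sum(dim=-1).clamp_min(1)
    return summed


# Toy per-token log-probs for 2 sequences of length 5 over an 11-token vocabulary.
logps = torch.log_softmax(torch.randn(2, 5, 11), dim=-1).max(dim=-1).values
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.float32)
print(sequence_log_prob(logps, mask, average=True))
```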
liger_kernel/chunked_loss/fused_linear_preference.py
CHANGED

@@ -222,7 +222,6 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
            (_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)),
            (_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)),
            (_chosen_nll_target_chunks if nll_target is not None else [None] * len(_chosen_input_chunks)),
-           strict=False,
        ):
            input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
            ref_input_chunk = (
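Dropping `strict=False` is a compatibility fix rather than a behaviour change: the `strict` keyword for `zip()` only exists on Python 3.10+, and the default, non-strict `zip()` already stops at the shortest iterable, which is what the chunked loop relies on. For example:

```python
# zip() without `strict` silently stops at the shortest iterable instead of raising,
# and it works on Python versions older than 3.10 where the keyword does not exist.
pairs = list(zip([1, 2, 3], ["a", "b"]))
print(pairs)  # [(1, 'a'), (2, 'b')], no error for the length mismatch
```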
liger_kernel/chunked_loss/jsd_loss.py
CHANGED

@@ -150,8 +150,8 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         teacher_input: torch.Tensor,
         teacher_weight: torch.Tensor,
         true_labels: torch.LongTensor,
-        student_bias: torch.Tensor,
-        teacher_bias: torch.Tensor,
+        student_bias: torch.Tensor = None,
+        teacher_bias: torch.Tensor = None,
     ) -> torch.Tensor:
         """
         Compute the JSD distillation loss.
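The only change here is that the two bias arguments become optional. For orientation, one common unfused formulation of a Jensen-Shannon distillation objective is sketched below; the fused module additionally folds the lm_head projections (which is where `student_bias` and `teacher_bias` enter) into the same pass and may weight the terms differently, so treat this purely as an illustrative reference, not the library's exact math:

```python
import math

import torch
import torch.nn.functional as F


def naive_jsd(student_logits: torch.Tensor, teacher_logits: torch.Tensor, beta: float = 0.5) -> torch.Tensor:
    """Generalized Jensen-Shannon divergence between the two softmax distributions."""
    student_logp = F.log_softmax(student_logits, dim=-1)
    teacher_logp = F.log_softmax(teacher_logits, dim=-1)
    # Mixture M = beta * P_student + (1 - beta) * P_teacher, computed in log space.
    mixture_logp = torch.logsumexp(
        torch.stack([student_logp + math.log(beta), teacher_logp + math.log(1.0 - beta)]),
        dim=0,
    )
    kl_student = F.kl_div(mixture_logp, student_logp, log_target=True, reduction="batchmean")
    kl_teacher = F.kl_div(mixture_logp, teacher_logp, log_target=True, reduction="batchmean")
    return beta * kl_student + (1.0 - beta) * kl_teacher


loss = naive_jsd(torch.randn(4, 32), torch.randn(4, 32))
print(loss)
```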
liger_kernel/ops/cross_entropy.py
CHANGED

@@ -351,7 +351,10 @@ def cross_entropy_backward(_input, grad_output):
     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
     if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
         pass
-
+    # If reduction is 'none'
+    elif grad_output.ndim > 0:
+        _input = _input * grad_output.unsqueeze(dim=1)
+    # If reduction is ['mean', 'sum'], grad_output is just a scalar
     # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
     # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
     else:
liger_kernel/ops/dyt.py
CHANGED
@@ -4,7 +4,8 @@ import torch
 import triton
 import triton.language as tl
 
-from
+from triton.language.extra.libdevice import tanh
+
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import infer_device
@@ -20,187 +21,126 @@ else:
     from triton.language.math import tanh
 
 
+# @triton.autotune([triton.Config({"BLOCK_N":bn}, num_stages=ns, num_warps=nw)
+#                   for bn in [1024, 2048, 4096]
+#                   for ns in [1,2,4]
+#                   for nw in [4, 8, 16, 32]
+#                   ],
+#                  key=['N'])
 @triton.jit
-def _dyt_fwd_kernel(
-):
-    """
-    Reference:
-    https://arxiv.org/abs/2503.10622
-
-    Shapes:
-        - x: (BT, C)
-        - alpha: (1)
-        - gamma: (C)
-        - beta: (C)
-    """
-    row_idx = tl.program_id(0)
-    offsets = tl.arange(0, BLOCK_SIZE)
-    mask = offsets < n_cols
-
-    x_ptr += row_idx * x_row_stride
-    y_ptr += row_idx * y_row_stride
-
-    alpha = tl.load(alpha_ptr)
-    gamma = tl.load(gamma_ptr + offsets, mask=mask)
-    beta = tl.load(beta_ptr + offsets, mask=mask)
-    x = tl.load(x_ptr + offsets, mask=mask)
-    y = gamma * tanh((alpha * x).cast(tl.float32)) + beta
-    tl.store(y_ptr + offsets, y, mask=mask)
+def _dyt_fwd_kernel(X, Y, Alpha, Gamma, Beta, HAVE_BETA: tl.constexpr, N: tl.constexpr, BLOCK_N: tl.constexpr = 1024):
+    col = tl.cast(tl.program_id(0), tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
+    mask = col < N
+    row_id = tl.cast(tl.program_id(1), tl.int64)
+
+    X += row_id * N
+    Y += row_id * N
+    alpha = tl.load(Alpha).to(tl.float32)
+
+    gamma = tl.load(Gamma + col, mask=mask, other=0.0).to(tl.float32)
 
+    x = tl.load(X + col, mask=mask, other=0.0).to(tl.float32)
 
+    tanh_x = tanh(alpha * x)
+    y = tanh_x * gamma
+    if HAVE_BETA:
+        beta = tl.load(Beta + col, mask=mask, other=0.0).to(tl.float32)
+        y += beta
+    tl.store(Y + col, y, mask=mask)
+
+
+# @triton.autotune([triton.Config({"BLOCK_N":bn}, num_stages=ns, num_warps=nw)
+#                   for bn in [1024, 2048, 4096]
+#                   for ns in [1,2,4]
+#                   for nw in [4, 8, 16]
+#                   ],
+#                  key=['N'])
 @triton.jit
 def _dyt_bwd_kernel(
-    x_row_stride,
-    dy_ptr,
-    dy_row_stride,
-    dx_ptr,
-    dx_row_stride,
-    alpha_ptr,
-    dalpha_ptr,
-    gamma_ptr,
-    dgamma_ptr,
-    dgamma_row_stride,
-    n_cols,
-    n_rows,
-    ROWS_PER_PROGRAM: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
+    DY, DX, DA, DG, DB, X, Alpha, Gamma, HAVE_BETA: tl.constexpr, M, N: tl.constexpr, BLOCK_N: tl.constexpr = 1024
 ):
-    dalpha = 0.0
-    dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
-
-    x_ptr += row_start * x_row_stride
-    dx_ptr += row_start * dx_row_stride
-    dy_ptr += row_start * dy_row_stride
-    alpha = tl.load(alpha_ptr)
-    gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)
-
-    for _ in tl.range(row_start, row_end):
-        dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0)
-        x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
-        tanh_ax = tanh((alpha * x).cast(tl.float32))
-        sech2_ax = 1 - tanh_ax * tanh_ax
-
-        dx = dy * gamma * sech2_ax * alpha
-        dalpha += tl.sum(dy * gamma * sech2_ax * x)
-        dgamma += dy * tanh_ax
-        tl.store(dx_ptr + offsets, dx, mask=mask)
-
-        dy_ptr += dy_row_stride
-        x_ptr += x_row_stride
-        dx_ptr += dx_row_stride
-
-    tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
-    tl.store(dalpha_ptr + pid, dalpha)
-
-    pass
+    col = tl.cast(tl.program_id(0), tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
+    mask = col < N
+    start_row_id = tl.cast(tl.program_id(1), tl.int64)
+
+    alpha = tl.load(Alpha).to(tl.float32)
+    da = 0.0
+    gamma = tl.load(Gamma + col, mask=mask, other=0.0).to(tl.float32)
+    dg = tl.zeros((BLOCK_N,), dtype=tl.float32)
+    if HAVE_BETA:
+        db = tl.zeros((BLOCK_N,), dtype=tl.float32)
+    for row_id in range(start_row_id, M, tl.num_programs(1)):
+        x = tl.load(X + row_id * N + col, mask=mask, other=0.0).to(tl.float32)
+        dy = tl.load(DY + row_id * N + col, mask=mask, other=0.0).to(tl.float32)
+        tanh_x = tanh(alpha * x)
+        if HAVE_BETA:
+            db += dy
+        dg += dy * tanh_x
+        tmp = (1 - tanh_x * tanh_x) * dy * gamma
+        da += tl.sum(x * tmp, 0)
+        dx = alpha * tmp
+        tl.store(DX + row_id * N + col, dx, mask=mask)
+
+    tl.store(DG + start_row_id * N + col, dg, mask=mask)
+    if HAVE_BETA:
+        tl.store(DB + start_row_id * N + col, db, mask=mask)
+    tl.store(DA + start_row_id * tl.cdiv(N, 512) + tl.program_id(0), da)
 
 
 def liger_dyt_fwd(x, alpha, gamma, beta):
+    assert x.is_contiguous()
+    HAVE_BETA = True if beta is not None else False
+    input_shape = x.shape
+    x = x.view(-1, input_shape[-1])
+    M, N = x.shape
+
     y = torch.empty_like(x)
+
+    if N >= 4096:
+        kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 2048), "num_warps": 4, "num_stages": 1}
+    else:
+        kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 1024), "num_warps": 4, "num_stages": 1}
+
+    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), M)
+    _dyt_fwd_kernel[(grid)](
+        x,
+        y,
+        alpha,
+        gamma,
+        beta,
+        HAVE_BETA,
+        N,
+        **kwargs,
     )
-    return y.view(
-
-
-def liger_dyt_bwd(dy, x, alpha, gamma):
-    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-    sm_count = 1
+    return y.view(input_shape)
+
+
+def liger_dyt_bwd(dy, x, alpha, gamma, beta):
+    assert dy.is_contiguous()
+    input_shape = x.shape
+    x = x.view(-1, input_shape[-1])
+    M, N = x.shape
+    HAVE_BETA = True if beta is not None else False
+
     device = infer_device()
     if device == "cuda":
+        NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count
     elif device == "xpu":
-        dy_ptr=dy,
-        dy_row_stride=dy.stride(0),
-        dx_ptr=dx,
-        dx_row_stride=dx.stride(0),
-        alpha_ptr=alpha,
-        dalpha_ptr=_dalpha,
-        gamma_ptr=gamma,
-        dgamma_ptr=_dgamma,
-        dgamma_row_stride=_dgamma.stride(0),
-        n_cols=n_cols,
-        n_rows=n_rows,
-        ROWS_PER_PROGRAM=rows_per_program,
-        BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=num_warps,
-    )
-    dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype)
-    dgamma = _dgamma.sum(dim=0).to(dtype)
-    dbeta = dy.sum(dim=0).to(dtype)
-    return dx.view(*shape), dalpha, dgamma, dbeta
+        NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count
+
+    da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device)
+    dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device)
+    db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None
+    dx = torch.empty_like(dy)
+
+    kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 1024), "num_warps": 8, "num_stages": 2}
+    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), NUM_SMS)
+    _dyt_bwd_kernel[grid](dy, dx, da, dg, db, x, alpha, gamma, HAVE_BETA, M, N, **kwargs)
+    if HAVE_BETA:
+        db = db.sum(0).to(x.dtype)
+    dg = dg.sum(0).to(gamma.dtype)
+    da = da.sum().to(x.dtype).unsqueeze(0)
+    return dx.view(input_shape), da, dg, db
 
 
 class LigerDyTFunction(torch.autograd.Function):
@@ -208,18 +148,12 @@ class LigerDyTFunction(torch.autograd.Function):
     @ensure_contiguous
     def forward(ctx, x, alpha, gamma, beta):
         y = liger_dyt_fwd(x, alpha, gamma, beta)
-        ctx.save_for_backward(x, alpha, gamma)
+        ctx.save_for_backward(x, alpha, gamma, beta)
         return y
 
     @staticmethod
     @ensure_contiguous
-    def backward(ctx,
-        x, alpha, gamma = ctx.saved_tensors
-        dx, dalpha, dgamma, dbeta = liger_dyt_bwd(
-            x,
-            alpha,
-            gamma,
-        )
-
-        return (dx, dalpha, dgamma, dbeta)
+    def backward(ctx, dy):
+        x, alpha, gamma, beta = ctx.saved_tensors
+        dx, dalpha, dgamma, dbeta = liger_dyt_bwd(dy, x, alpha, gamma, beta)
+        return dx, dalpha, dgamma, dbeta
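The kernels above implement Dynamic Tanh (DyT, arXiv:2503.10622): `y = gamma * tanh(alpha * x) + beta`, with a scalar `alpha`, a per-channel `gamma`, and an optional per-channel `beta` (the `HAVE_BETA` path). A minimal unfused PyTorch sketch, with illustrative shapes, that the fused forward and backward should match numerically:

```python
from typing import Optional

import torch


def dyt_reference(x: torch.Tensor, alpha: torch.Tensor, gamma: torch.Tensor,
                  beta: Optional[torch.Tensor]) -> torch.Tensor:
    # Dynamic Tanh: scalar alpha, per-channel gamma, optional per-channel beta.
    y = torch.tanh(alpha * x) * gamma
    return y if beta is None else y + beta


x = torch.randn(2, 16, 64, requires_grad=True)  # (batch, seq, hidden), illustrative shape
alpha = torch.tensor(0.5, requires_grad=True)   # scalar, corresponds to Alpha in the kernel
gamma = torch.ones(64, requires_grad=True)      # per-channel scale
beta = torch.zeros(64, requires_grad=True)      # per-channel shift (the HAVE_BETA case)

out = dyt_reference(x, alpha, gamma, beta)
out.sum().backward()  # autograd should agree with the dx/dalpha/dgamma/dbeta from the fused backward
```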
liger_kernel/ops/fused_linear_cross_entropy.py
CHANGED

@@ -143,9 +143,10 @@ def fused_linear_cross_entropy_forward(
             alpha=1.0,
         )
 
-        if reduction
-
-
+        # Need extra calculations for backward if reduction=='none'. Not supporting reduction='none' now.
+        # if reduction == "none":
+        #     loss = loss_1d
+        #     z_loss = z_loss_1d if return_z_loss else None
 
         else:
             loss = torch.sum(loss_1d)
|