liger-kernel 0.5.9__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
- liger_kernel/chunked_loss/dpo_loss.py +1 -1
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
- liger_kernel/chunked_loss/jsd_loss.py +2 -2
- liger_kernel/ops/dyt.py +111 -179
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +1 -1
- liger_kernel/ops/grpo_loss.py +310 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/rms_norm.py +265 -54
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/transformers/__init__.py +8 -0
- liger_kernel/transformers/dyt.py +5 -3
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +70 -0
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/grpo_loss.py +98 -0
- liger_kernel/transformers/model/gemma.py +25 -16
- liger_kernel/transformers/model/gemma2.py +27 -14
- liger_kernel/transformers/model/gemma3.py +62 -106
- liger_kernel/transformers/model/glm4.py +16 -13
- liger_kernel/transformers/model/llama.py +81 -18
- liger_kernel/transformers/model/llama4.py +108 -0
- liger_kernel/transformers/model/llava.py +95 -132
- liger_kernel/transformers/model/mistral.py +13 -14
- liger_kernel/transformers/model/mixtral.py +16 -15
- liger_kernel/transformers/model/mllama.py +16 -14
- liger_kernel/transformers/model/olmo2.py +16 -13
- liger_kernel/transformers/model/paligemma.py +8 -9
- liger_kernel/transformers/model/phi3.py +25 -16
- liger_kernel/transformers/model/qwen2.py +24 -15
- liger_kernel/transformers/model/qwen2_5_vl.py +41 -97
- liger_kernel/transformers/model/qwen2_vl.py +38 -106
- liger_kernel/transformers/model/qwen3.py +11 -9
- liger_kernel/transformers/model/qwen3_moe.py +132 -0
- liger_kernel/transformers/monkey_patch.py +424 -81
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/rms_norm.py +40 -4
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +21 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
- liger_kernel/utils.py +11 -0
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/METADATA +41 -21
- liger_kernel-0.6.0.dist-info/RECORD +97 -0
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/WHEEL +1 -1
- liger_kernel/transformers/gema3_rms.py +0 -8
- liger_kernel-0.5.9.dist-info/RECORD +0 -84
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/__init__.py
@@ -1,3 +1,4 @@
+from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss  # noqa:F401
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss  # noqa: F401
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss  # noqa: F401
 from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss  # noqa: F401
liger_kernel/chunked_loss/cosine_similarity_loss.py
@@ -0,0 +1,127 @@
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
+
+
+class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase):
+    @staticmethod
+    def distillation_loss_fn(student_logits, teacher_logits, beta=1.0):
+        """
+        Compute Cosine loss (Cosine Similarity Loss).
+        Args:
+            student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
+            teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
+            beta: Coefficient beta of generalized Cosine Similarity in the interval [0, 1]. Default: `1.0` (float): .
+        Returns:
+            torch.Tensor: cosine similarity loss
+        """
+        student_norm = F.normalize(student_logits, p=2, dim=-1)
+        teacher_norm = F.normalize(teacher_logits, p=2, dim=-1)
+
+        cosine_sim = F.cosine_similarity(student_norm, teacher_norm, dim=-1)
+        loss = beta * (1 - cosine_sim)
+        return loss.sum()
+
+    @classmethod
+    def forward(
+        cls,
+        ctx,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+        student_bias: torch.Tensor,
+        teacher_bias: torch.Tensor,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+        chunk_size: int = 1024,
+    ):
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            student_input=student_input,
+            student_weight=student_weight,
+            teacher_input=teacher_input,
+            teacher_weight=teacher_weight,
+            target=true_labels,
+            student_bias=student_bias,
+            teacher_bias=teacher_bias,
+            chunk_size=chunk_size,
+            weight_hard_loss=weight_hard_loss,
+            weight_soft_loss=weight_soft_loss,
+            beta=beta,
+            ignore_index=ignore_index,
+            temperature=temperature,
+            compiled=compiled,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]
+
+        return (
+            *grads,
+            None,  # teacher_bias
+            None,  # weight_hard_loss
+            None,  # weight_soft_loss
+            None,  # beta
+            None,  # ignore_index
+            None,  # temperature
+            None,  # compiled
+            None,  # chunk_size
+        )
+
+
+class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
+    def __init__(
+        self,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+        chunk_size: int = 1024,
+    ):
+        super().__init__()
+        assert temperature != 0, "Temperature cannot be 0."
+        self.weight_hard_loss = weight_hard_loss
+        self.weight_soft_loss = weight_soft_loss
+        self.ignore_index = ignore_index
+        self.temperature = temperature
+        self.compiled = compiled
+        self.beta = beta
+        self.chunk_size = chunk_size
+
+    def forward(
+        self,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+        student_bias: torch.Tensor = None,
+        teacher_bias: torch.Tensor = None,
+    ) -> torch.Tensor:
+        return LigerFusedLinearCosineSimilarityFunction.apply(
+            student_input,
+            student_weight,
+            teacher_input,
+            teacher_weight,
+            true_labels,
+            student_bias,
+            teacher_bias,
+            self.weight_hard_loss,
+            self.weight_soft_loss,
+            self.beta,
+            self.ignore_index,
+            self.temperature,
+            self.compiled,
+            self.chunk_size,
+        )
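A minimal usage sketch of the new LigerFusedLinearCosineSimilarityLoss module, assuming illustrative tensor sizes (batch, sequence length, hidden size, and vocabulary size below are hypothetical, not taken from the diff); the argument order follows the forward defined above.

import torch

from liger_kernel.chunked_loss import LigerFusedLinearCosineSimilarityLoss

# Hypothetical sizes for illustration only.
B, T, H, V = 2, 16, 64, 128
student_hidden = torch.randn(B * T, H, requires_grad=True)   # student_input
student_lm_head = torch.randn(V, H, requires_grad=True)      # student_weight
teacher_hidden = torch.randn(B * T, H)                       # teacher_input
teacher_lm_head = torch.randn(V, H)                          # teacher_weight
labels = torch.randint(0, V, (B * T,))                       # true_labels

loss_fn = LigerFusedLinearCosineSimilarityLoss(weight_hard_loss=0.5, weight_soft_loss=0.5, beta=0.5)
loss = loss_fn(student_hidden, student_lm_head, teacher_hidden, teacher_lm_head, labels)
loss.backward()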
liger_kernel/chunked_loss/dpo_loss.py
@@ -128,7 +128,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         compute_nll_loss: bool = False,
         compiled: bool = True,
         use_ref_model: bool = True,
-        average_log_prob: bool =
+        average_log_prob: bool = False,
         chunk_size: int = 1,
     ):
         """
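Since the average_log_prob default changes to False in 0.6.0, callers that rely on length-normalized log-probabilities can pin the flag explicitly; a hedged sketch using only arguments visible in the hunk above.

from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss

# Pass average_log_prob explicitly if your recipe expects averaged
# (length-normalized) per-token log-probabilities rather than sums.
dpo_loss = LigerFusedLinearDPOLoss(use_ref_model=True, average_log_prob=True)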
liger_kernel/chunked_loss/functional.py
@@ -1,3 +1,4 @@
+from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityFunction
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
 from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
@@ -9,6 +10,7 @@ from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
 liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
 liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
 liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
+liger_fused_linear_cosine = LigerFusedLinearCosineSimilarityFunction.apply
 liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
 liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
 liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
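The new functional alias mirrors the module API; a sketch with the same hypothetical sizes as above, passing positional arguments in the order of LigerFusedLinearCosineSimilarityFunction.forward.

import torch

from liger_kernel.chunked_loss.functional import liger_fused_linear_cosine

B, T, H, V = 2, 16, 64, 128  # hypothetical sizes
loss = liger_fused_linear_cosine(
    torch.randn(B * T, H, requires_grad=True),  # student_input
    torch.randn(V, H, requires_grad=True),      # student_weight
    torch.randn(B * T, H),                      # teacher_input
    torch.randn(V, H),                          # teacher_weight
    torch.randint(0, V, (B * T,)),              # true_labels
    None,                                       # student_bias
    None,                                       # teacher_bias
)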
liger_kernel/chunked_loss/fused_linear_preference.py
@@ -222,7 +222,6 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
            (_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)),
            (_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)),
            (_chosen_nll_target_chunks if nll_target is not None else [None] * len(_chosen_input_chunks)),
-            strict=False,
        ):
            input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
            ref_input_chunk = (
liger_kernel/chunked_loss/jsd_loss.py
@@ -150,8 +150,8 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         teacher_input: torch.Tensor,
         teacher_weight: torch.Tensor,
         true_labels: torch.LongTensor,
-        student_bias: torch.Tensor,
-        teacher_bias: torch.Tensor,
+        student_bias: torch.Tensor = None,
+        teacher_bias: torch.Tensor = None,
     ) -> torch.Tensor:
         """
         Compute the JSD distillation loss.
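With the bias parameters now defaulting to None, the JSD loss can be called without explicit bias tensors; a hedged sketch with hypothetical sizes (usable constructor defaults are assumed, they are not shown in this hunk).

import torch

from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss

B, T, H, V = 2, 16, 64, 128  # hypothetical sizes
jsd_loss = LigerFusedLinearJSDLoss()  # assumes the constructor has usable defaults
loss = jsd_loss(
    torch.randn(B * T, H, requires_grad=True),  # student_input
    torch.randn(V, H, requires_grad=True),      # student_weight
    torch.randn(B * T, H),                      # teacher_input
    torch.randn(V, H),                          # teacher_weight
    torch.randint(0, V, (B * T,)),              # true_labels
    # student_bias / teacher_bias omitted: they now default to None
)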
liger_kernel/ops/dyt.py
CHANGED
@@ -4,7 +4,6 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import infer_device
@@ -20,187 +19,126 @@ else:
     from triton.language.math import tanh
 
 
+# @triton.autotune([triton.Config({"BLOCK_N":bn}, num_stages=ns, num_warps=nw)
+# for bn in [1024, 2048, 4096]
+# for ns in [1,2,4]
+# for nw in [4, 8, 16, 32]
+# ],
+# key=['N'])
 @triton.jit
-def _dyt_fwd_kernel(
-
-
-
-
-
-
-
-
-
-
-
-    Reference:
-    https://arxiv.org/abs/2503.10622
-
-    Shapes:
-        - x: (BT, C)
-        - alpha: (1)
-        - gamma: (C)
-        - beta: (C)
-    """
-    row_idx = tl.program_id(0)
-    offsets = tl.arange(0, BLOCK_SIZE)
-    mask = offsets < n_cols
-
-    x_ptr += row_idx * x_row_stride
-    y_ptr += row_idx * y_row_stride
-
-    alpha = tl.load(alpha_ptr)
-    gamma = tl.load(gamma_ptr + offsets, mask=mask)
-    beta = tl.load(beta_ptr + offsets, mask=mask)
-    x = tl.load(x_ptr + offsets, mask=mask)
-    y = gamma * tanh((alpha * x).cast(tl.float32)) + beta
-    tl.store(y_ptr + offsets, y, mask=mask)
+def _dyt_fwd_kernel(X, Y, Alpha, Gamma, Beta, HAVE_BETA: tl.constexpr, N: tl.constexpr, BLOCK_N: tl.constexpr = 1024):
+    col = tl.cast(tl.program_id(0), tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
+    mask = col < N
+    row_id = tl.cast(tl.program_id(1), tl.int64)
+
+    X += row_id * N
+    Y += row_id * N
+    alpha = tl.load(Alpha).to(tl.float32)
+
+    gamma = tl.load(Gamma + col, mask=mask, other=0.0).to(tl.float32)
+
+    x = tl.load(X + col, mask=mask, other=0.0).to(tl.float32)
 
+    tanh_x = tanh(alpha * x)
+    y = tanh_x * gamma
+    if HAVE_BETA:
+        beta = tl.load(Beta + col, mask=mask, other=0.0).to(tl.float32)
+        y += beta
+    tl.store(Y + col, y, mask=mask)
 
+
+# @triton.autotune([triton.Config({"BLOCK_N":bn}, num_stages=ns, num_warps=nw)
+# for bn in [1024, 2048, 4096]
+# for ns in [1,2,4]
+# for nw in [4, 8, 16]
+# ],
+# key=['N'])
 @triton.jit
 def _dyt_bwd_kernel(
-
-    x_row_stride,
-    dy_ptr,
-    dy_row_stride,
-    dx_ptr,
-    dx_row_stride,
-    alpha_ptr,
-    dalpha_ptr,
-    gamma_ptr,
-    dgamma_ptr,
-    dgamma_row_stride,
-    n_cols,
-    n_rows,
-    ROWS_PER_PROGRAM: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
+    DY, DX, DA, DG, DB, X, Alpha, Gamma, HAVE_BETA: tl.constexpr, M, N: tl.constexpr, BLOCK_N: tl.constexpr = 1024
 ):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    dalpha = 0.0
-    dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
-
-    x_ptr += row_start * x_row_stride
-    dx_ptr += row_start * dx_row_stride
-    dy_ptr += row_start * dy_row_stride
-    alpha = tl.load(alpha_ptr)
-    gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)
-
-    for _ in tl.range(row_start, row_end):
-        dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0)
-        x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
-        tanh_ax = tanh((alpha * x).cast(tl.float32))
-        sech2_ax = 1 - tanh_ax * tanh_ax
-
-        dx = dy * gamma * sech2_ax * alpha
-        dalpha += tl.sum(dy * gamma * sech2_ax * x)
-        dgamma += dy * tanh_ax
-        tl.store(dx_ptr + offsets, dx, mask=mask)
-
-        dy_ptr += dy_row_stride
-        x_ptr += x_row_stride
-        dx_ptr += dx_row_stride
-
-    tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
-    tl.store(dalpha_ptr + pid, dalpha)
-
-    pass
+    col = tl.cast(tl.program_id(0), tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
+    mask = col < N
+    start_row_id = tl.cast(tl.program_id(1), tl.int64)
+
+    alpha = tl.load(Alpha).to(tl.float32)
+    da = 0.0
+    gamma = tl.load(Gamma + col, mask=mask, other=0.0).to(tl.float32)
+    dg = tl.zeros((BLOCK_N,), dtype=tl.float32)
+    if HAVE_BETA:
+        db = tl.zeros((BLOCK_N,), dtype=tl.float32)
+    for row_id in range(start_row_id, M, tl.num_programs(1)):
+        x = tl.load(X + row_id * N + col, mask=mask, other=0.0).to(tl.float32)
+        dy = tl.load(DY + row_id * N + col, mask=mask, other=0.0).to(tl.float32)
+        tanh_x = tanh(alpha * x)
+        if HAVE_BETA:
+            db += dy
+        dg += dy * tanh_x
+        tmp = (1 - tanh_x * tanh_x) * dy * gamma
+        da += tl.sum(x * tmp, 0)
+        dx = alpha * tmp
+        tl.store(DX + row_id * N + col, dx, mask=mask)
+
+    tl.store(DG + start_row_id * N + col, dg, mask=mask)
+    if HAVE_BETA:
+        tl.store(DB + start_row_id * N + col, db, mask=mask)
+    tl.store(DA + start_row_id * tl.cdiv(N, 512) + tl.program_id(0), da)
 
 
 def liger_dyt_fwd(x, alpha, gamma, beta):
-
-
-
-
+    assert x.is_contiguous()
+    HAVE_BETA = True if beta is not None else False
+    input_shape = x.shape
+    x = x.view(-1, input_shape[-1])
+    M, N = x.shape
+
     y = torch.empty_like(x)
-
-
-
-
-
-
-
-
-
-
-
-
+
+    if N >= 4096:
+        kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 2048), "num_warps": 4, "num_stages": 1}
+    else:
+        kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 1024), "num_warps": 4, "num_stages": 1}
+
+    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), M)
+    _dyt_fwd_kernel[(grid)](
+        x,
+        y,
+        alpha,
+        gamma,
+        beta,
+        HAVE_BETA,
+        N,
+        **kwargs,
     )
-    return y.view(
-
-
-def liger_dyt_bwd(dy, x, alpha, gamma):
-
-
-
-
-
-
-    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-    sm_count = 1
+    return y.view(input_shape)
+
+
+def liger_dyt_bwd(dy, x, alpha, gamma, beta):
+    assert dy.is_contiguous()
+    input_shape = x.shape
+    x = x.view(-1, input_shape[-1])
+    M, N = x.shape
+    HAVE_BETA = True if beta is not None else False
+
     device = infer_device()
     if device == "cuda":
-
+        NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count
     elif device == "xpu":
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        dy_ptr=dy,
-        dy_row_stride=dy.stride(0),
-        dx_ptr=dx,
-        dx_row_stride=dx.stride(0),
-        alpha_ptr=alpha,
-        dalpha_ptr=_dalpha,
-        gamma_ptr=gamma,
-        dgamma_ptr=_dgamma,
-        dgamma_row_stride=_dgamma.stride(0),
-        n_cols=n_cols,
-        n_rows=n_rows,
-        ROWS_PER_PROGRAM=rows_per_program,
-        BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=num_warps,
-    )
-    dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype)
-    dgamma = _dgamma.sum(dim=0).to(dtype)
-    dbeta = dy.sum(dim=0).to(dtype)
-    return dx.view(*shape), dalpha, dgamma, dbeta
+        NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count
+
+    da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device)
+    dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device)
+    db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None
+    dx = torch.empty_like(dy)
+
+    kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 1024), "num_warps": 8, "num_stages": 2}
+    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), NUM_SMS)
+    _dyt_bwd_kernel[grid](dy, dx, da, dg, db, x, alpha, gamma, HAVE_BETA, M, N, **kwargs)
+    if HAVE_BETA:
+        db = db.sum(0).to(x.dtype)
+    dg = dg.sum(0).to(gamma.dtype)
+    da = da.sum().to(x.dtype).unsqueeze(0)
+    return dx.view(input_shape), da, dg, db
 
 
 class LigerDyTFunction(torch.autograd.Function):
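For orientation, a plain-PyTorch reference of what the rewritten forward kernel computes, Dynamic Tanh (y = gamma * tanh(alpha * x), plus beta when it is provided); this is a sketch for sanity-checking the Triton path, not code from the package.

import torch

def dyt_reference(x, alpha, gamma, beta=None):
    # Eager-mode Dynamic Tanh matching _dyt_fwd_kernel:
    # y = gamma * tanh(alpha * x), with an optional additive beta.
    y = torch.tanh(alpha * x) * gamma
    if beta is not None:
        y = y + beta
    return y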
@@ -208,18 +146,12 @@ class LigerDyTFunction(torch.autograd.Function):
     @ensure_contiguous
     def forward(ctx, x, alpha, gamma, beta):
         y = liger_dyt_fwd(x, alpha, gamma, beta)
-        ctx.save_for_backward(x, alpha, gamma)
+        ctx.save_for_backward(x, alpha, gamma, beta)
         return y
 
     @staticmethod
     @ensure_contiguous
-    def backward(ctx,
-        x, alpha, gamma = ctx.saved_tensors
-        dx, dalpha, dgamma, dbeta = liger_dyt_bwd(
-
-            x,
-            alpha,
-            gamma,
-        )
-
-        return (dx, dalpha, dgamma, dbeta)
+    def backward(ctx, dy):
+        x, alpha, gamma, beta = ctx.saved_tensors
+        dx, dalpha, dgamma, dbeta = liger_dyt_bwd(dy, x, alpha, gamma, beta)
+        return dx, dalpha, dgamma, dbeta