liger-kernel 0.5.9__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +1 -1
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  6. liger_kernel/chunked_loss/jsd_loss.py +2 -2
  7. liger_kernel/ops/dyt.py +111 -179
  8. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  9. liger_kernel/ops/geglu.py +1 -1
  10. liger_kernel/ops/grpo_loss.py +310 -0
  11. liger_kernel/ops/multi_token_attention.py +207 -0
  12. liger_kernel/ops/rms_norm.py +265 -54
  13. liger_kernel/ops/softmax.py +201 -0
  14. liger_kernel/ops/sparsemax.py +179 -0
  15. liger_kernel/ops/swiglu.py +1 -1
  16. liger_kernel/transformers/__init__.py +8 -0
  17. liger_kernel/transformers/dyt.py +5 -3
  18. liger_kernel/transformers/fsdp.py +55 -0
  19. liger_kernel/transformers/functional.py +70 -0
  20. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  21. liger_kernel/transformers/grpo_loss.py +98 -0
  22. liger_kernel/transformers/model/gemma.py +25 -16
  23. liger_kernel/transformers/model/gemma2.py +27 -14
  24. liger_kernel/transformers/model/gemma3.py +62 -106
  25. liger_kernel/transformers/model/glm4.py +16 -13
  26. liger_kernel/transformers/model/llama.py +81 -18
  27. liger_kernel/transformers/model/llama4.py +108 -0
  28. liger_kernel/transformers/model/llava.py +95 -132
  29. liger_kernel/transformers/model/mistral.py +13 -14
  30. liger_kernel/transformers/model/mixtral.py +16 -15
  31. liger_kernel/transformers/model/mllama.py +16 -14
  32. liger_kernel/transformers/model/olmo2.py +16 -13
  33. liger_kernel/transformers/model/paligemma.py +8 -9
  34. liger_kernel/transformers/model/phi3.py +25 -16
  35. liger_kernel/transformers/model/qwen2.py +24 -15
  36. liger_kernel/transformers/model/qwen2_5_vl.py +41 -97
  37. liger_kernel/transformers/model/qwen2_vl.py +38 -106
  38. liger_kernel/transformers/model/qwen3.py +11 -9
  39. liger_kernel/transformers/model/qwen3_moe.py +132 -0
  40. liger_kernel/transformers/monkey_patch.py +424 -81
  41. liger_kernel/transformers/multi_token_attention.py +64 -0
  42. liger_kernel/transformers/rms_norm.py +40 -4
  43. liger_kernel/transformers/softmax.py +12 -0
  44. liger_kernel/transformers/sparsemax.py +16 -0
  45. liger_kernel/transformers/swiglu.py +21 -0
  46. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  47. liger_kernel/utils.py +11 -0
  48. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/METADATA +41 -21
  49. liger_kernel-0.6.0.dist-info/RECORD +97 -0
  50. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/WHEEL +1 -1
  51. liger_kernel/transformers/gema3_rms.py +0 -8
  52. liger_kernel-0.5.9.dist-info/RECORD +0 -84
  53. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/licenses/LICENSE +0 -0
  54. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/licenses/NOTICE +0 -0
  55. {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/__init__.py CHANGED
@@ -1,3 +1,4 @@
+ from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss # noqa:F401
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
  from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss # noqa: F401
liger_kernel/chunked_loss/cosine_similarity_loss.py ADDED
@@ -0,0 +1,127 @@
+ import torch
+ import torch.nn.functional as F
+
+ from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
+
+
+ class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase):
+     @staticmethod
+     def distillation_loss_fn(student_logits, teacher_logits, beta=1.0):
+         """
+         Compute the cosine similarity distillation loss.
+         Args:
+             student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
+             teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
+             beta (float): Coefficient of the cosine similarity term, in the interval [0, 1]. Default: `1.0`.
+         Returns:
+             torch.Tensor: cosine similarity loss
+         """
+         student_norm = F.normalize(student_logits, p=2, dim=-1)
+         teacher_norm = F.normalize(teacher_logits, p=2, dim=-1)
+
+         cosine_sim = F.cosine_similarity(student_norm, teacher_norm, dim=-1)
+         loss = beta * (1 - cosine_sim)
+         return loss.sum()
+
+     @classmethod
+     def forward(
+         cls,
+         ctx,
+         student_input: torch.Tensor,
+         student_weight: torch.Tensor,
+         teacher_input: torch.Tensor,
+         teacher_weight: torch.Tensor,
+         true_labels: torch.LongTensor,
+         student_bias: torch.Tensor,
+         teacher_bias: torch.Tensor,
+         weight_hard_loss: float = 0.5,
+         weight_soft_loss: float = 0.5,
+         beta: float = 0.5,
+         ignore_index: int = -100,
+         temperature: float = 1.0,
+         compiled: bool = True,
+         chunk_size: int = 1024,
+     ):
+         return super().forward(
+             cls=cls,
+             ctx=ctx,
+             student_input=student_input,
+             student_weight=student_weight,
+             teacher_input=teacher_input,
+             teacher_weight=teacher_weight,
+             target=true_labels,
+             student_bias=student_bias,
+             teacher_bias=teacher_bias,
+             chunk_size=chunk_size,
+             weight_hard_loss=weight_hard_loss,
+             weight_soft_loss=weight_soft_loss,
+             beta=beta,
+             ignore_index=ignore_index,
+             temperature=temperature,
+             compiled=compiled,
+         )
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]
+
+         return (
+             *grads,
+             None,  # teacher_bias
+             None,  # weight_hard_loss
+             None,  # weight_soft_loss
+             None,  # beta
+             None,  # ignore_index
+             None,  # temperature
+             None,  # compiled
+             None,  # chunk_size
+         )
+
+
+ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
+     def __init__(
+         self,
+         weight_hard_loss: float = 0.5,
+         weight_soft_loss: float = 0.5,
+         beta: float = 0.5,
+         ignore_index: int = -100,
+         temperature: float = 1.0,
+         compiled: bool = True,
+         chunk_size: int = 1024,
+     ):
+         super().__init__()
+         assert temperature != 0, "Temperature cannot be 0."
+         self.weight_hard_loss = weight_hard_loss
+         self.weight_soft_loss = weight_soft_loss
+         self.ignore_index = ignore_index
+         self.temperature = temperature
+         self.compiled = compiled
+         self.beta = beta
+         self.chunk_size = chunk_size
+
+     def forward(
+         self,
+         student_input: torch.Tensor,
+         student_weight: torch.Tensor,
+         teacher_input: torch.Tensor,
+         teacher_weight: torch.Tensor,
+         true_labels: torch.LongTensor,
+         student_bias: torch.Tensor = None,
+         teacher_bias: torch.Tensor = None,
+     ) -> torch.Tensor:
+         return LigerFusedLinearCosineSimilarityFunction.apply(
+             student_input,
+             student_weight,
+             teacher_input,
+             teacher_weight,
+             true_labels,
+             student_bias,
+             teacher_bias,
+             self.weight_hard_loss,
+             self.weight_soft_loss,
+             self.beta,
+             self.ignore_index,
+             self.temperature,
+             self.compiled,
+             self.chunk_size,
+         )
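
For orientation, here is a minimal usage sketch of the new LigerFusedLinearCosineSimilarityLoss module added above. It mirrors how the other Liger chunked distillation losses are called (e.g. LigerFusedLinearJSDLoss, whose signature appears later in this diff); the tensor sizes below are illustrative assumptions, and dtype/device handling is omitted.

import torch

from liger_kernel.chunked_loss import LigerFusedLinearCosineSimilarityLoss

# Illustrative sizes (assumptions): hidden dim 512, vocab 1000, 8 flattened (batch * seq_len) tokens.
H, V, BT = 512, 1000, 8

loss_fn = LigerFusedLinearCosineSimilarityLoss(weight_hard_loss=0.5, weight_soft_loss=0.5, beta=0.5)

student_input = torch.randn(BT, H, requires_grad=True)   # student hidden states
student_weight = torch.randn(V, H, requires_grad=True)   # student lm_head weight
teacher_input = torch.randn(BT, H)                       # teacher hidden states
teacher_weight = torch.randn(V, H)                       # teacher lm_head weight
true_labels = torch.randint(0, V, (BT,))                 # hard-label targets for the hard-loss term

loss = loss_fn(student_input, student_weight, teacher_input, teacher_weight, true_labels)
loss.backward()
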
liger_kernel/chunked_loss/dpo_loss.py CHANGED
@@ -128,7 +128,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
          compute_nll_loss: bool = False,
          compiled: bool = True,
          use_ref_model: bool = True,
-         average_log_prob: bool = True,
+         average_log_prob: bool = False,
          chunk_size: int = 1,
      ):
          """
liger_kernel/chunked_loss/functional.py CHANGED
@@ -1,3 +1,4 @@
+ from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityFunction
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
  from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
@@ -9,6 +10,7 @@ from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
  liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
  liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
  liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
+ liger_fused_linear_cosine = LigerFusedLinearCosineSimilarityFunction.apply
  liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
  liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
  liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
liger_kernel/chunked_loss/fused_linear_preference.py CHANGED
@@ -222,7 +222,6 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
              (_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)),
              (_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)),
              (_chosen_nll_target_chunks if nll_target is not None else [None] * len(_chosen_input_chunks)),
-             strict=False,
          ):
              input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
              ref_input_chunk = (
liger_kernel/chunked_loss/jsd_loss.py CHANGED
@@ -150,8 +150,8 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
          teacher_input: torch.Tensor,
          teacher_weight: torch.Tensor,
          true_labels: torch.LongTensor,
-         student_bias: torch.Tensor,
-         teacher_bias: torch.Tensor,
+         student_bias: torch.Tensor = None,
+         teacher_bias: torch.Tensor = None,
      ) -> torch.Tensor:
          """
          Compute the JSD distillation loss.
liger_kernel/ops/dyt.py CHANGED
@@ -4,7 +4,6 @@ import torch
  import triton
  import triton.language as tl
 
- from liger_kernel.ops.utils import calculate_settings
  from liger_kernel.ops.utils import compare_version
  from liger_kernel.ops.utils import ensure_contiguous
  from liger_kernel.ops.utils import infer_device
@@ -20,187 +19,126 @@ else:
      from triton.language.math import tanh
 
 
+ # @triton.autotune([triton.Config({"BLOCK_N":bn}, num_stages=ns, num_warps=nw)
+ # for bn in [1024, 2048, 4096]
+ # for ns in [1,2,4]
+ # for nw in [4, 8, 16, 32]
+ # ],
+ # key=['N'])
  @triton.jit
- def _dyt_fwd_kernel(
-     x_ptr,
-     x_row_stride,
-     alpha_ptr,
-     gamma_ptr,
-     beta_ptr,
-     y_ptr,
-     y_row_stride,
-     n_cols,
-     BLOCK_SIZE: tl.constexpr,
- ):
-     """
-     Reference:
-     https://arxiv.org/abs/2503.10622
-
-     Shapes:
-         - x: (BT, C)
-         - alpha: (1)
-         - gamma: (C)
-         - beta: (C)
-     """
-     row_idx = tl.program_id(0)
-     offsets = tl.arange(0, BLOCK_SIZE)
-     mask = offsets < n_cols
-
-     x_ptr += row_idx * x_row_stride
-     y_ptr += row_idx * y_row_stride
-
-     alpha = tl.load(alpha_ptr)
-     gamma = tl.load(gamma_ptr + offsets, mask=mask)
-     beta = tl.load(beta_ptr + offsets, mask=mask)
-     x = tl.load(x_ptr + offsets, mask=mask)
-     y = gamma * tanh((alpha * x).cast(tl.float32)) + beta
-     tl.store(y_ptr + offsets, y, mask=mask)
+ def _dyt_fwd_kernel(X, Y, Alpha, Gamma, Beta, HAVE_BETA: tl.constexpr, N: tl.constexpr, BLOCK_N: tl.constexpr = 1024):
+     col = tl.cast(tl.program_id(0), tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
+     mask = col < N
+     row_id = tl.cast(tl.program_id(1), tl.int64)
+
+     X += row_id * N
+     Y += row_id * N
+     alpha = tl.load(Alpha).to(tl.float32)
+
+     gamma = tl.load(Gamma + col, mask=mask, other=0.0).to(tl.float32)
+
+     x = tl.load(X + col, mask=mask, other=0.0).to(tl.float32)
 
+     tanh_x = tanh(alpha * x)
+     y = tanh_x * gamma
+     if HAVE_BETA:
+         beta = tl.load(Beta + col, mask=mask, other=0.0).to(tl.float32)
+         y += beta
+     tl.store(Y + col, y, mask=mask)
 
+
+ # @triton.autotune([triton.Config({"BLOCK_N":bn}, num_stages=ns, num_warps=nw)
+ # for bn in [1024, 2048, 4096]
+ # for ns in [1,2,4]
+ # for nw in [4, 8, 16]
+ # ],
+ # key=['N'])
  @triton.jit
  def _dyt_bwd_kernel(
-     x_ptr,
-     x_row_stride,
-     dy_ptr,
-     dy_row_stride,
-     dx_ptr,
-     dx_row_stride,
-     alpha_ptr,
-     dalpha_ptr,
-     gamma_ptr,
-     dgamma_ptr,
-     dgamma_row_stride,
-     n_cols,
-     n_rows,
-     ROWS_PER_PROGRAM: tl.constexpr,
-     BLOCK_SIZE: tl.constexpr,
+     DY, DX, DA, DG, DB, X, Alpha, Gamma, HAVE_BETA: tl.constexpr, M, N: tl.constexpr, BLOCK_N: tl.constexpr = 1024
  ):
-     """
-     Reference:
-     https://arxiv.org/abs/2503.10622
-
-     Shapes:
-         - x: (BT, C)
-         - alpha: (1)
-         - gamma: (C)
-         - dx: (BT, C)
-         - dy: (BT, C)
-         - dgamma: (sm_count, C)
-         - dalpha: (sm_count,)
-     """
-     # d(gamma * tanh(alpha * x) + beta) / dx
-     # = gamma * (1 - tanh^2(alpha * x)) * alpha
-     # d(gamma * tanh(alpha * x) + beta) / dalpha
-     # = gamma * (1 - tanh^2(alpha * x)) * x
-     # d(gamma * tanh(alpha * x) + beta) / dgamma
-     # = tanh(alpha * x)
-     # d(gamma * tanh(alpha * x)) / dbeta = 1
-     pid = tl.program_id(0)
-
-     row_start = pid * ROWS_PER_PROGRAM
-     row_end = min((pid + 1) * ROWS_PER_PROGRAM, n_rows)
-     offsets = tl.arange(0, BLOCK_SIZE)
-     mask = offsets < n_cols
-
-     dalpha = 0.0
-     dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
-
-     x_ptr += row_start * x_row_stride
-     dx_ptr += row_start * dx_row_stride
-     dy_ptr += row_start * dy_row_stride
-     alpha = tl.load(alpha_ptr)
-     gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)
-
-     for _ in tl.range(row_start, row_end):
-         dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0)
-         x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
-         tanh_ax = tanh((alpha * x).cast(tl.float32))
-         sech2_ax = 1 - tanh_ax * tanh_ax
-
-         dx = dy * gamma * sech2_ax * alpha
-         dalpha += tl.sum(dy * gamma * sech2_ax * x)
-         dgamma += dy * tanh_ax
-         tl.store(dx_ptr + offsets, dx, mask=mask)
-
-         dy_ptr += dy_row_stride
-         x_ptr += x_row_stride
-         dx_ptr += dx_row_stride
-
-     tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
-     tl.store(dalpha_ptr + pid, dalpha)
-
-     pass
+     col = tl.cast(tl.program_id(0), tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
+     mask = col < N
+     start_row_id = tl.cast(tl.program_id(1), tl.int64)
+
+     alpha = tl.load(Alpha).to(tl.float32)
+     da = 0.0
+     gamma = tl.load(Gamma + col, mask=mask, other=0.0).to(tl.float32)
+     dg = tl.zeros((BLOCK_N,), dtype=tl.float32)
+     if HAVE_BETA:
+         db = tl.zeros((BLOCK_N,), dtype=tl.float32)
+     for row_id in range(start_row_id, M, tl.num_programs(1)):
+         x = tl.load(X + row_id * N + col, mask=mask, other=0.0).to(tl.float32)
+         dy = tl.load(DY + row_id * N + col, mask=mask, other=0.0).to(tl.float32)
+         tanh_x = tanh(alpha * x)
+         if HAVE_BETA:
+             db += dy
+         dg += dy * tanh_x
+         tmp = (1 - tanh_x * tanh_x) * dy * gamma
+         da += tl.sum(x * tmp, 0)
+         dx = alpha * tmp
+         tl.store(DX + row_id * N + col, dx, mask=mask)
+
+     tl.store(DG + start_row_id * N + col, dg, mask=mask)
+     if HAVE_BETA:
+         tl.store(DB + start_row_id * N + col, db, mask=mask)
+     tl.store(DA + start_row_id * tl.cdiv(N, 512) + tl.program_id(0), da)
 
 
  def liger_dyt_fwd(x, alpha, gamma, beta):
-     shape = x.shape
-     dim = shape[-1]
-     x = x.view(-1, dim)
-     n_rows, n_cols = x.shape
+     assert x.is_contiguous()
+     HAVE_BETA = True if beta is not None else False
+     input_shape = x.shape
+     x = x.view(-1, input_shape[-1])
+     M, N = x.shape
+
      y = torch.empty_like(x)
-     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-     _dyt_fwd_kernel[(n_rows,)](
-         x_ptr=x,
-         alpha_ptr=alpha,
-         gamma_ptr=gamma,
-         beta_ptr=beta,
-         y_ptr=y,
-         x_row_stride=x.stride(0),
-         y_row_stride=y.stride(0),
-         n_cols=n_cols,
-         BLOCK_SIZE=BLOCK_SIZE,
-         num_warps=num_warps,
+
+     if N >= 4096:
+         kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 2048), "num_warps": 4, "num_stages": 1}
+     else:
+         kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 1024), "num_warps": 4, "num_stages": 1}
+
+     grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), M)
+     _dyt_fwd_kernel[(grid)](
+         x,
+         y,
+         alpha,
+         gamma,
+         beta,
+         HAVE_BETA,
+         N,
+         **kwargs,
      )
-     return y.view(*shape)
-
-
- def liger_dyt_bwd(dy, x, alpha, gamma):
-     shape = dy.shape
-     dtype = x.dtype
-     dim = shape[-1]
-     dy = dy.view(-1, dim)
-     x = x.view(-1, dim)
-     n_rows, n_cols = dy.shape
-     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-     sm_count = 1
+     return y.view(input_shape)
+
+
+ def liger_dyt_bwd(dy, x, alpha, gamma, beta):
+     assert dy.is_contiguous()
+     input_shape = x.shape
+     x = x.view(-1, input_shape[-1])
+     M, N = x.shape
+     HAVE_BETA = True if beta is not None else False
+
      device = infer_device()
      if device == "cuda":
-         sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
+         NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count
      elif device == "xpu":
-         sm_count = torch.xpu.get_device_properties(x.device).gpu_subslice_count
-     if n_cols > BLOCK_SIZE:
-         raise RuntimeError(
-             f"Feature dimension {dim} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
-         )
-
-     dx = torch.empty_like(x, dtype=torch.float32)
-     _dalpha = torch.empty((sm_count,), dtype=torch.float32, device=x.device)
-     _dgamma = torch.empty((sm_count, n_cols), dtype=torch.float32, device=x.device)
-
-     grid = (sm_count,)
-     rows_per_program = triton.cdiv(n_rows, sm_count)
-     _dyt_bwd_kernel[grid](
-         x_ptr=x,
-         x_row_stride=x.stride(0),
-         dy_ptr=dy,
-         dy_row_stride=dy.stride(0),
-         dx_ptr=dx,
-         dx_row_stride=dx.stride(0),
-         alpha_ptr=alpha,
-         dalpha_ptr=_dalpha,
-         gamma_ptr=gamma,
-         dgamma_ptr=_dgamma,
-         dgamma_row_stride=_dgamma.stride(0),
-         n_cols=n_cols,
-         n_rows=n_rows,
-         ROWS_PER_PROGRAM=rows_per_program,
-         BLOCK_SIZE=BLOCK_SIZE,
-         num_warps=num_warps,
-     )
-     dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype)
-     dgamma = _dgamma.sum(dim=0).to(dtype)
-     dbeta = dy.sum(dim=0).to(dtype)
-     return dx.view(*shape), dalpha, dgamma, dbeta
+         NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count
+
+     da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device)
+     dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device)
+     db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None
+     dx = torch.empty_like(dy)
+
+     kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 1024), "num_warps": 8, "num_stages": 2}
+     grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), NUM_SMS)
+     _dyt_bwd_kernel[grid](dy, dx, da, dg, db, x, alpha, gamma, HAVE_BETA, M, N, **kwargs)
+     if HAVE_BETA:
+         db = db.sum(0).to(x.dtype)
+     dg = dg.sum(0).to(gamma.dtype)
+     da = da.sum().to(x.dtype).unsqueeze(0)
+     return dx.view(input_shape), da, dg, db
 
 
  class LigerDyTFunction(torch.autograd.Function):
@@ -208,18 +146,12 @@ class LigerDyTFunction(torch.autograd.Function):
      @ensure_contiguous
      def forward(ctx, x, alpha, gamma, beta):
          y = liger_dyt_fwd(x, alpha, gamma, beta)
-         ctx.save_for_backward(x, alpha, gamma)
+         ctx.save_for_backward(x, alpha, gamma, beta)
          return y
 
      @staticmethod
      @ensure_contiguous
-     def backward(ctx, grad_output):
-         x, alpha, gamma = ctx.saved_tensors
-         dx, dalpha, dgamma, dbeta = liger_dyt_bwd(
-             grad_output,
-             x,
-             alpha,
-             gamma,
-         )
-
-         return (dx, dalpha, dgamma, dbeta)
+     def backward(ctx, dy):
+         x, alpha, gamma, beta = ctx.saved_tensors
+         dx, dalpha, dgamma, dbeta = liger_dyt_bwd(dy, x, alpha, gamma, beta)
+         return dx, dalpha, dgamma, dbeta
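
For checking the math in the rewritten DyT kernels, the forward and backward passes they implement correspond to the plain-PyTorch reference below (Dynamic Tanh, https://arxiv.org/abs/2503.10622: y = gamma * tanh(alpha * x), plus beta when a bias is present). This is an illustrative sketch, not code from the package; the new backward kernel additionally accumulates da/dg/db per SM and reduces them on the host, but the per-element math matches the derivative comments removed above.

import torch

def dyt_forward_reference(x, alpha, gamma, beta=None):
    # y = gamma * tanh(alpha * x) (+ beta when a bias is provided)
    y = gamma * torch.tanh(alpha * x)
    return y + beta if beta is not None else y

def dyt_backward_reference(dy, x, alpha, gamma, beta=None):
    # Gradients of y = gamma * tanh(alpha * x) + beta:
    #   dx     = dy * gamma * (1 - tanh^2(alpha * x)) * alpha
    #   dalpha = sum(dy * gamma * (1 - tanh^2(alpha * x)) * x)
    #   dgamma = sum over rows of dy * tanh(alpha * x)
    #   dbeta  = sum over rows of dy
    t = torch.tanh(alpha * x)
    sech2 = 1 - t * t
    dx = dy * gamma * sech2 * alpha
    dalpha = (dy * gamma * sech2 * x).sum().reshape(1)
    dgamma = (dy * t).reshape(-1, x.shape[-1]).sum(dim=0)
    dbeta = dy.reshape(-1, x.shape[-1]).sum(dim=0) if beta is not None else None
    return dx, dalpha, dgamma, dbeta
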