liger-kernel-nightly 0.5.5.dev20250402185702__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl

This diff compares publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of liger-kernel-nightly may be problematic.
Files changed (115)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +61 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +36 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  8. liger_kernel/chunked_loss/grpo_loss.py +76 -5
  9. liger_kernel/chunked_loss/jsd_loss.py +46 -15
  10. liger_kernel/ops/__init__.py +141 -0
  11. liger_kernel/ops/backends/README.md +151 -0
  12. liger_kernel/ops/backends/__init__.py +13 -0
  13. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  14. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  15. liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
  16. liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
  17. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  18. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  19. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  20. liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
  21. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  22. liger_kernel/ops/backends/registry.py +61 -0
  23. liger_kernel/ops/cross_entropy.py +134 -65
  24. liger_kernel/ops/dyt.py +115 -180
  25. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  26. liger_kernel/ops/fused_linear_cross_entropy.py +117 -23
  27. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  28. liger_kernel/ops/geglu.py +6 -4
  29. liger_kernel/ops/group_norm.py +7 -7
  30. liger_kernel/ops/grpo_loss.py +312 -0
  31. liger_kernel/ops/jsd.py +2 -1
  32. liger_kernel/ops/kl_div.py +9 -5
  33. liger_kernel/ops/layer_norm.py +146 -78
  34. liger_kernel/ops/llama4_rope.py +225 -0
  35. liger_kernel/ops/multi_token_attention.py +207 -0
  36. liger_kernel/ops/poly_norm.py +390 -0
  37. liger_kernel/ops/rms_norm.py +398 -99
  38. liger_kernel/ops/rope.py +1 -1
  39. liger_kernel/ops/softmax.py +201 -0
  40. liger_kernel/ops/sparsemax.py +179 -0
  41. liger_kernel/ops/swiglu.py +1 -1
  42. liger_kernel/ops/tiled_mlp.py +136 -0
  43. liger_kernel/ops/utils.py +14 -0
  44. liger_kernel/transformers/__init__.py +208 -17
  45. liger_kernel/transformers/auto_model.py +21 -0
  46. liger_kernel/transformers/cross_entropy.py +9 -4
  47. liger_kernel/transformers/dyt.py +6 -4
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -1
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +122 -20
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -1
  57. liger_kernel/transformers/group_norm.py +1 -1
  58. liger_kernel/transformers/grpo_loss.py +153 -0
  59. liger_kernel/transformers/jsd.py +1 -1
  60. liger_kernel/transformers/kl_div.py +1 -1
  61. liger_kernel/transformers/layer_norm.py +1 -1
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/exaone4.py +136 -0
  64. liger_kernel/transformers/model/falcon_h1.py +122 -0
  65. liger_kernel/transformers/model/gemma.py +57 -27
  66. liger_kernel/transformers/model/gemma2.py +65 -28
  67. liger_kernel/transformers/model/gemma3.py +331 -0
  68. liger_kernel/transformers/model/glm4.py +141 -0
  69. liger_kernel/transformers/model/glm4v.py +163 -0
  70. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  71. liger_kernel/transformers/model/gpt_oss.py +211 -0
  72. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  73. liger_kernel/transformers/model/internvl.py +157 -0
  74. liger_kernel/transformers/model/llama.py +109 -27
  75. liger_kernel/transformers/model/llama4.py +121 -0
  76. liger_kernel/transformers/model/llava.py +111 -136
  77. liger_kernel/transformers/model/loss_utils.py +50 -12
  78. liger_kernel/transformers/model/mistral.py +51 -34
  79. liger_kernel/transformers/model/mixtral.py +50 -29
  80. liger_kernel/transformers/model/mllama.py +46 -24
  81. liger_kernel/transformers/model/olmo2.py +47 -22
  82. liger_kernel/transformers/model/olmo3.py +142 -0
  83. liger_kernel/transformers/model/output_classes.py +147 -0
  84. liger_kernel/transformers/model/paligemma.py +50 -14
  85. liger_kernel/transformers/model/phi3.py +47 -172
  86. liger_kernel/transformers/model/qwen2.py +55 -23
  87. liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
  88. liger_kernel/transformers/model/qwen2_vl.py +59 -108
  89. liger_kernel/transformers/model/qwen3.py +136 -0
  90. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  91. liger_kernel/transformers/model/qwen3_next.py +146 -0
  92. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  93. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  94. liger_kernel/transformers/model/smollm3.py +199 -0
  95. liger_kernel/transformers/model/smolvlm.py +158 -0
  96. liger_kernel/transformers/monkey_patch.py +2018 -244
  97. liger_kernel/transformers/multi_token_attention.py +64 -0
  98. liger_kernel/transformers/poly_norm.py +42 -0
  99. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  100. liger_kernel/transformers/rms_norm.py +54 -6
  101. liger_kernel/transformers/rope.py +45 -1
  102. liger_kernel/transformers/softmax.py +12 -0
  103. liger_kernel/transformers/sparsemax.py +16 -0
  104. liger_kernel/transformers/swiglu.py +39 -1
  105. liger_kernel/transformers/tiled_mlp.py +125 -0
  106. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  107. liger_kernel/transformers/tvd.py +1 -1
  108. liger_kernel/utils.py +63 -0
  109. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +73 -39
  110. liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
  111. liger_kernel_nightly-0.5.5.dev20250402185702.dist-info/RECORD +0 -80
  112. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
  115. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
liger_kernel/ops/dyt.py CHANGED
@@ -4,12 +4,13 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import infer_device
+from liger_kernel.utils import get_npu_multi_processor_count
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import tanh
@@ -20,187 +21,127 @@ else:
     from triton.language.math import tanh
 
 
+# @triton.autotune([triton.Config({"BLOCK_N":bn}, num_stages=ns, num_warps=nw)
+#                   for bn in [1024, 2048, 4096]
+#                   for ns in [1,2,4]
+#                   for nw in [4, 8, 16, 32]
+#                   ],
+#                  key=['N'])
 @triton.jit
-def _dyt_fwd_kernel(
-    x_ptr,
-    x_row_stride,
-    alpha_ptr,
-    gamma_ptr,
-    beta_ptr,
-    y_ptr,
-    y_row_stride,
-    n_cols,
-    BLOCK_SIZE: tl.constexpr,
-):
-    """
-    Reference:
-    https://arxiv.org/abs/2503.10622
-
-    Shapes:
-        - x: (BT, C)
-        - alpha: (1)
-        - gamma: (C)
-        - beta: (C)
-    """
-    row_idx = tl.program_id(0)
-    offsets = tl.arange(0, BLOCK_SIZE)
-    mask = offsets < n_cols
-
-    x_ptr += row_idx * x_row_stride
-    y_ptr += row_idx * y_row_stride
-
-    alpha = tl.load(alpha_ptr)
-    gamma = tl.load(gamma_ptr + offsets, mask=mask)
-    beta = tl.load(beta_ptr + offsets, mask=mask)
-    x = tl.load(x_ptr + offsets, mask=mask)
-    y = gamma * tanh((alpha * x).cast(tl.float32)) + beta
-    tl.store(y_ptr + offsets, y, mask=mask)
+def _dyt_fwd_kernel(X, Y, Alpha, Gamma, Beta, HAVE_BETA: tl.constexpr, N: tl.constexpr, BLOCK_N: tl.constexpr = 1024):
+    col = tl.cast(tl.program_id(0), tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
+    mask = col < N
+    row_id = tl.cast(tl.program_id(1), tl.int64)
+
+    X += row_id * N
+    Y += row_id * N
+    alpha = tl.load(Alpha).to(tl.float32)
+
+    gamma = tl.load(Gamma + col, mask=mask, other=0.0).to(tl.float32)
 
+    x = tl.load(X + col, mask=mask, other=0.0).to(tl.float32)
 
+    tanh_x = tanh(alpha * x)
+    y = tanh_x * gamma
+    if HAVE_BETA:
+        beta = tl.load(Beta + col, mask=mask, other=0.0).to(tl.float32)
+        y += beta
+    tl.store(Y + col, y, mask=mask)
+
+
+# @triton.autotune([triton.Config({"BLOCK_N":bn}, num_stages=ns, num_warps=nw)
+#                   for bn in [1024, 2048, 4096]
+#                   for ns in [1,2,4]
+#                   for nw in [4, 8, 16]
+#                   ],
+#                  key=['N'])
 @triton.jit
 def _dyt_bwd_kernel(
-    x_ptr,
-    x_row_stride,
-    dy_ptr,
-    dy_row_stride,
-    dx_ptr,
-    dx_row_stride,
-    alpha_ptr,
-    dalpha_ptr,
-    gamma_ptr,
-    dgamma_ptr,
-    dgamma_row_stride,
-    n_cols,
-    n_rows,
-    ROWS_PER_PROGRAM: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
+    DY, DX, DA, DG, DB, X, Alpha, Gamma, HAVE_BETA: tl.constexpr, M, N: tl.constexpr, BLOCK_N: tl.constexpr = 1024
 ):
-    """
-    Reference:
-    https://arxiv.org/abs/2503.10622
-
-    Shapes:
-        - x: (BT, C)
-        - alpha: (1)
-        - gamma: (C)
-        - dx: (BT, C)
-        - dy: (BT, C)
-        - dgamma: (sm_count, C)
-        - dalpha: (sm_count,)
-    """
-    # d(gamma * tanh(alpha * x) + beta) / dx
-    # = gamma * (1 - tanh^2(alpha * x)) * alpha
-    # d(gamma * tanh(alpha * x) + beta) / dalpha
-    # = gamma * (1 - tanh^2(alpha * x)) * x
-    # d(gamma * tanh(alpha * x) + beta) / dgamma
-    # = tanh(alpha * x)
-    # d(gamma * tanh(alpha * x)) / dbeta = 1
-    pid = tl.program_id(0)
-
-    row_start = pid * ROWS_PER_PROGRAM
-    row_end = min((pid + 1) * ROWS_PER_PROGRAM, n_rows)
-    offsets = tl.arange(0, BLOCK_SIZE)
-    mask = offsets < n_cols
-
-    dalpha = 0.0
-    dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
-
-    x_ptr += row_start * x_row_stride
-    dx_ptr += row_start * dx_row_stride
-    dy_ptr += row_start * dy_row_stride
-    alpha = tl.load(alpha_ptr)
-    gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)
-
-    for _ in tl.range(row_start, row_end):
-        dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0)
-        x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
-        tanh_ax = tanh((alpha * x).cast(tl.float32))
-        sech2_ax = 1 - tanh_ax * tanh_ax
-
-        dx = dy * gamma * sech2_ax * alpha
-        dalpha += tl.sum(dy * gamma * sech2_ax * x)
-        dgamma += dy * tanh_ax
-        tl.store(dx_ptr + offsets, dx, mask=mask)
-
-        dy_ptr += dy_row_stride
-        x_ptr += x_row_stride
-        dx_ptr += dx_row_stride
-
-    tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
-    tl.store(dalpha_ptr + pid, dalpha)
-
-    pass
+    col = tl.cast(tl.program_id(0), tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
+    mask = col < N
+    start_row_id = tl.cast(tl.program_id(1), tl.int64)
+
+    alpha = tl.load(Alpha).to(tl.float32)
+    da = 0.0
+    gamma = tl.load(Gamma + col, mask=mask, other=0.0).to(tl.float32)
+    dg = tl.zeros((BLOCK_N,), dtype=tl.float32)
+    if HAVE_BETA:
+        db = tl.zeros((BLOCK_N,), dtype=tl.float32)
+    for row_id in range(start_row_id, M, tl.num_programs(1)):
+        x = tl.load(X + row_id * N + col, mask=mask, other=0.0).to(tl.float32)
+        dy = tl.load(DY + row_id * N + col, mask=mask, other=0.0).to(tl.float32)
+        tanh_x = tanh(alpha * x)
+        if HAVE_BETA:
+            db += dy
+        dg += dy * tanh_x
+        tmp = (1 - tanh_x * tanh_x) * dy * gamma
+        da += tl.sum(x * tmp, 0)
+        dx = alpha * tmp
+        tl.store(DX + row_id * N + col, dx, mask=mask)
+
+    tl.store(DG + start_row_id * N + col, dg, mask=mask)
+    if HAVE_BETA:
+        tl.store(DB + start_row_id * N + col, db, mask=mask)
+    tl.store(DA + start_row_id * tl.cdiv(N, 512) + tl.program_id(0), da)
 
 
 def liger_dyt_fwd(x, alpha, gamma, beta):
-    shape = x.shape
-    dim = shape[-1]
-    x = x.view(-1, dim)
-    n_rows, n_cols = x.shape
+    assert x.is_contiguous()
+    HAVE_BETA = True if beta is not None else False
+    input_shape = x.shape
+    x = x.view(-1, input_shape[-1])
+    M, N = x.shape
+
     y = torch.empty_like(x)
-    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-    _dyt_fwd_kernel[(n_rows,)](
-        x_ptr=x,
-        alpha_ptr=alpha,
-        gamma_ptr=gamma,
-        beta_ptr=beta,
-        y_ptr=y,
-        x_row_stride=x.stride(0),
-        y_row_stride=y.stride(0),
-        n_cols=n_cols,
-        BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=num_warps,
+
+    if N >= 4096:
+        kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 2048), "num_warps": 4, "num_stages": 1}
+    else:
+        kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 1024), "num_warps": 4, "num_stages": 1}
+
+    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), M)
+    _dyt_fwd_kernel[(grid)](
+        x,
+        y,
+        alpha,
+        gamma,
+        beta,
+        HAVE_BETA,
+        N,
+        **kwargs,
     )
-    return y.view(*shape)
-
-
-def liger_dyt_bwd(dy, x, alpha, gamma):
-    shape = dy.shape
-    dtype = x.dtype
-    dim = shape[-1]
-    dy = dy.view(-1, dim)
-    x = x.view(-1, dim)
-    n_rows, n_cols = dy.shape
-    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-    sm_count = 1
+    return y.view(input_shape)
+
+
+def liger_dyt_bwd(dy, x, alpha, gamma, beta):
+    assert dy.is_contiguous()
+    input_shape = x.shape
+    x = x.view(-1, input_shape[-1])
+    M, N = x.shape
+    HAVE_BETA = True if beta is not None else False
+
     device = infer_device()
     if device == "cuda":
-        sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
+        NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count
     elif device == "xpu":
-        sm_count = torch.xpu.get_device_properties(x.device).gpu_subslice_count
-    if n_cols > BLOCK_SIZE:
-        raise RuntimeError(
-            f"Feature dimension {dim} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
-        )
-
-    dx = torch.empty_like(x, dtype=torch.float32)
-    _dalpha = torch.empty((sm_count,), dtype=torch.float32, device=x.device)
-    _dgamma = torch.empty((sm_count, n_cols), dtype=torch.float32, device=x.device)
-
-    grid = (sm_count,)
-    rows_per_program = triton.cdiv(n_rows, sm_count)
-    _dyt_bwd_kernel[grid](
-        x_ptr=x,
-        x_row_stride=x.stride(0),
-        dy_ptr=dy,
-        dy_row_stride=dy.stride(0),
-        dx_ptr=dx,
-        dx_row_stride=dx.stride(0),
-        alpha_ptr=alpha,
-        dalpha_ptr=_dalpha,
-        gamma_ptr=gamma,
-        dgamma_ptr=_dgamma,
-        dgamma_row_stride=_dgamma.stride(0),
-        n_cols=n_cols,
-        n_rows=n_rows,
-        ROWS_PER_PROGRAM=rows_per_program,
-        BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=num_warps,
-    )
-    dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype)
-    dgamma = _dgamma.sum(dim=0).to(dtype)
-    dbeta = dy.sum(dim=0).to(dtype)
-    return dx.view(*shape), dalpha, dgamma, dbeta
+        NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count
+    elif device == "npu":
+        NUM_SMS = get_npu_multi_processor_count()
+    da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device)
+    dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device)
+    db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None
+    dx = torch.empty_like(dy)
+
+    kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 1024), "num_warps": 8, "num_stages": 2}
+    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), NUM_SMS)
+    _dyt_bwd_kernel[grid](dy, dx, da, dg, db, x, alpha, gamma, HAVE_BETA, M, N, **kwargs)
+    if HAVE_BETA:
+        db = db.sum(0).to(x.dtype)
+    dg = dg.sum(0).to(gamma.dtype)
+    da = da.sum().to(x.dtype).unsqueeze(0)
+    return dx.view(input_shape), da, dg, db
 
 
 class LigerDyTFunction(torch.autograd.Function):
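
For reference, the rewritten _dyt_bwd_kernel above drops the inline derivative comments the old kernel carried. Restated from those removed comments, the gradients it accumulates for y = gamma * tanh(alpha * x) + beta are:

\[
\frac{\partial y}{\partial x} = \gamma\,\alpha\,\bigl(1 - \tanh^2(\alpha x)\bigr), \qquad
\frac{\partial y}{\partial \alpha} = \gamma\,x\,\bigl(1 - \tanh^2(\alpha x)\bigr), \qquad
\frac{\partial y}{\partial \gamma} = \tanh(\alpha x), \qquad
\frac{\partial y}{\partial \beta} = 1
\]

In the kernel these appear as tmp = (1 - tanh_x * tanh_x) * dy * gamma, with dx = alpha * tmp, da accumulating tl.sum(x * tmp, 0), dg accumulating dy * tanh_x, and db accumulating dy; the per-program partials written to DA, DG, and DB are then reduced on the host in liger_dyt_bwd with sum().
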
@@ -208,18 +149,12 @@ class LigerDyTFunction(torch.autograd.Function):
     @ensure_contiguous
     def forward(ctx, x, alpha, gamma, beta):
         y = liger_dyt_fwd(x, alpha, gamma, beta)
-        ctx.save_for_backward(x, alpha, gamma)
+        ctx.save_for_backward(x, alpha, gamma, beta)
         return y
 
     @staticmethod
     @ensure_contiguous
-    def backward(ctx, grad_output):
-        x, alpha, gamma = ctx.saved_tensors
-        dx, dalpha, dgamma, dbeta = liger_dyt_bwd(
-            grad_output,
-            x,
-            alpha,
-            gamma,
-        )
-
-        return (dx, dalpha, dgamma, dbeta)
+    def backward(ctx, dy):
+        x, alpha, gamma, beta = ctx.saved_tensors
+        dx, dalpha, dgamma, dbeta = liger_dyt_bwd(dy, x, alpha, gamma, beta)
+        return dx, dalpha, dgamma, dbeta
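
As a usage note (not part of the package diff), the sketch below shows how the updated LigerDyTFunction could be sanity-checked against a plain PyTorch reference of the same formula, y = gamma * tanh(alpha * x) + beta. The shapes follow the removed docstring ((BT, C) input, scalar alpha, per-channel gamma and beta); the test sizes, tolerances, and the dyt_reference helper are illustrative assumptions, and a CUDA device with Triton is assumed. When beta is None, the new HAVE_BETA path skips the bias load entirely and backward returns None for dbeta.

# Hypothetical sanity check for the reworked DyT kernels; not shipped with liger-kernel.
import torch

from liger_kernel.ops.dyt import LigerDyTFunction


def dyt_reference(x, alpha, gamma, beta):
    # Plain PyTorch DyT: y = gamma * tanh(alpha * x), plus beta when it is provided.
    y = gamma * torch.tanh(alpha * x)
    return y + beta if beta is not None else y


if torch.cuda.is_available():
    BT, C = 8, 1024
    x = torch.randn(BT, C, device="cuda", requires_grad=True)
    alpha = torch.full((1,), 0.5, device="cuda", requires_grad=True)
    gamma = torch.ones(C, device="cuda", requires_grad=True)
    beta = torch.zeros(C, device="cuda", requires_grad=True)

    # Forward path exercises _dyt_fwd_kernel via liger_dyt_fwd.
    y_kernel = LigerDyTFunction.apply(x, alpha, gamma, beta)
    y_ref = dyt_reference(x.detach(), alpha.detach(), gamma.detach(), beta.detach())
    torch.testing.assert_close(y_kernel, y_ref, rtol=1e-3, atol=1e-3)

    # Backward path exercises _dyt_bwd_kernel via liger_dyt_bwd.
    y_kernel.sum().backward()
    print(x.grad.shape, alpha.grad.shape, gamma.grad.shape, beta.grad.shape)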