liger-kernel-nightly 0.5.5.dev20250322021112__py3-none-any.whl → 0.5.5.dev20250324181221__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic. Click here for more details.

@@ -0,0 +1,225 @@
1
+ import operator
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from liger_kernel.ops.utils import calculate_settings
8
+ from liger_kernel.ops.utils import compare_version
9
+ from liger_kernel.ops.utils import ensure_contiguous
10
+ from liger_kernel.ops.utils import infer_device
11
+
12
+ if compare_version("triton", operator.ge, "3.0.0"):
13
+ try:
14
+ # typical import path with dispatch available
15
+ from triton.language.extra.libdevice import tanh
16
+ except ModuleNotFoundError:
17
+ # for working with NGC containers
18
+ from triton.language.extra.cuda.libdevice import tanh
19
+ else:
20
+ from triton.language.math import tanh
21
+
22
+
23
@triton.jit
def _dyt_fwd_kernel(
    x_ptr,
    x_row_stride,
    alpha_ptr,
    gamma_ptr,
    beta_ptr,
    y_ptr,
    y_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    DyT forward: y = gamma * tanh(alpha * x) + beta, one program per row.

    Reference:
    https://arxiv.org/abs/2503.10622

    Shapes:
        - x: (BT, C)
        - alpha: (1)
        - gamma: (C)
        - beta: (C)
    """
    pid = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    # BLOCK_SIZE may exceed n_cols; lanes past the row end are masked out.
    in_bounds = cols < n_cols

    # Row base pointers for the row this program owns.
    x_row_ptr = x_ptr + pid * x_row_stride
    y_row_ptr = y_ptr + pid * y_row_stride

    alpha = tl.load(alpha_ptr)
    x = tl.load(x_row_ptr + cols, mask=in_bounds)
    gamma = tl.load(gamma_ptr + cols, mask=in_bounds)
    beta = tl.load(beta_ptr + cols, mask=in_bounds)

    # The product alpha * x is cast to float32 before tanh is applied.
    squashed = tanh((alpha * x).cast(tl.float32))
    tl.store(y_row_ptr + cols, gamma * squashed + beta, mask=in_bounds)
58
+
59
+
60
@triton.jit
def _dyt_bwd_kernel(
    x_ptr,
    x_row_stride,
    dy_ptr,
    dy_row_stride,
    dx_ptr,
    dx_row_stride,
    alpha_ptr,
    dalpha_ptr,
    gamma_ptr,
    dgamma_ptr,
    dgamma_row_stride,
    n_cols,
    n_rows,
    ROWS_PER_PROGRAM: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    DyT backward for y = gamma * tanh(alpha * x) + beta.

    Program `pid` walks rows [pid * ROWS_PER_PROGRAM,
    min((pid + 1) * ROWS_PER_PROGRAM, n_rows)). For each row it writes dx, and
    it accumulates partial dalpha / dgamma sums that are stored in slot `pid`
    of the dalpha / dgamma buffers once the loop is done. A program whose row
    range is empty still stores zeros into its slot.

    Reference:
    https://arxiv.org/abs/2503.10622

    Shapes:
        - x: (BT, C)
        - alpha: (1)
        - gamma: (C)
        - dx: (BT, C)
        - dy: (BT, C)
        - dgamma: (sm_count, C)
        - dalpha: (sm_count,)
    """
    # d(gamma * tanh(alpha * x) + beta) / dx
    # = gamma * (1 - tanh^2(alpha * x)) * alpha
    # d(gamma * tanh(alpha * x) + beta) / dalpha
    # = gamma * (1 - tanh^2(alpha * x)) * x
    # d(gamma * tanh(alpha * x) + beta) / dgamma
    # = tanh(alpha * x)
    # d(gamma * tanh(alpha * x)) / dbeta = 1
    pid = tl.program_id(0)

    row_start = pid * ROWS_PER_PROGRAM
    row_end = min((pid + 1) * ROWS_PER_PROGRAM, n_rows)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_cols

    # Per-program partial sums, accumulated across this program's rows.
    dalpha = 0.0
    dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    # Move the row pointers to the first row owned by this program.
    x_ptr += row_start * x_row_stride
    dx_ptr += row_start * dx_row_stride
    dy_ptr += row_start * dy_row_stride
    alpha = tl.load(alpha_ptr)
    gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)

    # Builtin range is enough here: none of tl.range's extra compiler options
    # are used, and it avoids depending on tl.range in environments that take
    # the pre-3.0.0 Triton import fallback at the top of this module.
    for _ in range(row_start, row_end):
        # Masked lanes load 0.0 so they add nothing to the reductions below.
        dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0)
        x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
        tanh_ax = tanh((alpha * x).cast(tl.float32))
        sech2_ax = 1 - tanh_ax * tanh_ax

        dx = dy * gamma * sech2_ax * alpha
        dalpha += tl.sum(dy * gamma * sech2_ax * x)
        dgamma += dy * tanh_ax
        tl.store(dx_ptr + offsets, dx, mask=mask)

        # Advance all three pointers to the next row.
        dy_ptr += dy_row_stride
        x_ptr += x_row_stride
        dx_ptr += dx_row_stride

    # Write this program's partial sums; the host reduces over the pid axis.
    tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
    tl.store(dalpha_ptr + pid, dalpha)
133
+
134
+
135
def liger_dyt_fwd(x, alpha, gamma, beta):
    """
    Launch the DyT forward kernel and return y with the same shape as x.

    x is flattened to 2D as (-1, last_dim) and one kernel program is launched
    per flattened row. The output buffer is allocated with `empty_like`, so it
    shares the dtype and device of x.

    Args:
        x: input tensor; must be viewable as (-1, x.shape[-1]).
        alpha: scale tensor, passed to the kernel as `alpha_ptr`.
        gamma: per-feature tensor, passed to the kernel as `gamma_ptr`.
        beta: per-feature tensor, passed to the kernel as `beta_ptr`.

    Returns:
        Tensor with the original shape of x.
    """
    orig_shape = x.shape
    x_2d = x.view(-1, orig_shape[-1])
    n_rows, n_cols = x_2d.shape
    out = torch.empty_like(x_2d)

    # Block size / warp count are derived from the feature dimension.
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    grid = (n_rows,)
    _dyt_fwd_kernel[grid](
        x_ptr=x_2d,
        alpha_ptr=alpha,
        gamma_ptr=gamma,
        beta_ptr=beta,
        y_ptr=out,
        x_row_stride=x_2d.stride(0),
        y_row_stride=out.stride(0),
        n_cols=n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return out.view(*orig_shape)
155
+
156
+
157
def liger_dyt_bwd(dy, x, alpha, gamma):
    """
    Launch the DyT backward kernel and reduce its per-program partial sums.

    dy and x are flattened to 2D as (-1, last_dim). The kernel is launched on
    a 1D grid of `sm_count` programs, each handling a contiguous chunk of
    `cdiv(n_rows, sm_count)` rows and writing its partial dalpha / dgamma into
    its own slot of a float32 scratch buffer. Those buffers are summed over
    the program axis here, and dbeta is the column-wise sum of dy.

    Args:
        dy: upstream gradient; must be viewable as (-1, dy.shape[-1]).
        x: forward input; must be viewable with the same last dimension.
        alpha: scale tensor, passed to the kernel as `alpha_ptr`.
        gamma: per-feature tensor, passed to the kernel as `gamma_ptr`.

    Returns:
        Tuple (dx, dalpha, dgamma, dbeta): dx has the shape of dy, dalpha has
        shape (1,), dgamma and dbeta have shape (n_cols,). All four are in
        the dtype of x.

    Raises:
        RuntimeError: if the feature dimension exceeds BLOCK_SIZE.
    """
    shape = dy.shape
    dtype = x.dtype
    dim = shape[-1]
    dy = dy.view(-1, dim)
    x = x.view(-1, dim)
    n_rows, n_cols = dy.shape
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    # Grid size: multiprocessor / subslice count of x's device when the
    # inferred backend is cuda / xpu, otherwise a single program.
    sm_count = 1
    device = infer_device()
    if device == "cuda":
        sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    elif device == "xpu":
        sm_count = torch.xpu.get_device_properties(x.device).gpu_subslice_count
    if n_cols > BLOCK_SIZE:
        raise RuntimeError(
            f"Feature dimension {dim} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
        )

    # dx is allocated in x's dtype, matching how the forward allocates y and
    # how dalpha / dgamma / dbeta are cast below; the kernel's store converts
    # the computed values to the buffer's element type.
    dx = torch.empty_like(x)
    # Per-program partial sums stay in float32 until after the reduction.
    _dalpha = torch.empty((sm_count,), dtype=torch.float32, device=x.device)
    _dgamma = torch.empty((sm_count, n_cols), dtype=torch.float32, device=x.device)

    grid = (sm_count,)
    rows_per_program = triton.cdiv(n_rows, sm_count)
    _dyt_bwd_kernel[grid](
        x_ptr=x,
        x_row_stride=x.stride(0),
        dy_ptr=dy,
        dy_row_stride=dy.stride(0),
        dx_ptr=dx,
        dx_row_stride=dx.stride(0),
        alpha_ptr=alpha,
        dalpha_ptr=_dalpha,
        gamma_ptr=gamma,
        dgamma_ptr=_dgamma,
        dgamma_row_stride=_dgamma.stride(0),
        n_cols=n_cols,
        n_rows=n_rows,
        ROWS_PER_PROGRAM=rows_per_program,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    # Reduce the per-program partials over the program axis.
    dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype)
    dgamma = _dgamma.sum(dim=0).to(dtype)
    # d/dbeta is 1, so dbeta is just dy summed over rows.
    dbeta = dy.sum(dim=0).to(dtype)
    return dx.view(*shape), dalpha, dgamma, dbeta
204
+
205
+
206
class LigerDyTFunction(torch.autograd.Function):
    """
    Autograd wrapper around `liger_dyt_fwd` / `liger_dyt_bwd`.

    forward takes (x, alpha, gamma, beta) and returns y; backward returns one
    gradient per forward input, in the same order.
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, x, alpha, gamma, beta):
        # beta is not saved: the backward pass only needs x, alpha and gamma.
        ctx.save_for_backward(x, alpha, gamma)
        return liger_dyt_fwd(x, alpha, gamma, beta)

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output):
        x, alpha, gamma = ctx.saved_tensors
        grads = liger_dyt_bwd(grad_output, x, alpha, gamma)
        dx, dalpha, dgamma, dbeta = grads
        return dx, dalpha, dgamma, dbeta
@@ -1,5 +1,6 @@
1
1
  from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM # noqa: F401
2
2
  from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss # noqa: F401
3
+ from liger_kernel.transformers.dyt import LigerDyT # noqa: F401
3
4
  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss # noqa: F401
4
5
  from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD # noqa: F401
5
6
  from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401
@@ -0,0 +1,20 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from liger_kernel.ops.dyt import LigerDyTFunction
5
+
6
+
7
class LigerDyT(nn.Module):
    """
    Module wrapper that applies `LigerDyTFunction` with learnable parameters.

    Parameters created at construction:
        - alpha: shape (1,), filled with `init_alpha`
        - gamma: shape (hidden_size,), filled with ones
        - beta: shape (hidden_size,), filled with zeros
    """

    def __init__(self, hidden_size, init_alpha=0.5):
        super().__init__()
        self.hidden_size = hidden_size
        self.init_alpha = init_alpha

        # Built from a float ones-tensor so alpha is floating point even if
        # init_alpha is passed as an int.
        alpha_init = torch.ones(1) * init_alpha
        gamma_init = torch.ones(hidden_size)
        beta_init = torch.zeros(hidden_size)

        # Registration order (alpha, gamma, beta) is kept as-is.
        self.alpha = nn.Parameter(alpha_init)
        self.gamma = nn.Parameter(gamma_init)
        self.beta = nn.Parameter(beta_init)

    def forward(self, x):
        params = (self.alpha, self.gamma, self.beta)
        return LigerDyTFunction.apply(x, *params)

    def extra_repr(self):
        return f"{self.hidden_size}, init_alpha={self.init_alpha}"
@@ -1,6 +1,7 @@
1
1
  from typing import Optional
2
2
 
3
3
  from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
4
+ from liger_kernel.ops.dyt import LigerDyTFunction
4
5
  from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
5
6
  from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
6
7
  from liger_kernel.ops.geglu import LigerGELUMulFunction
@@ -192,3 +193,7 @@ def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
192
193
 
193
194
  def liger_swiglu(a, b):
194
195
  return LigerSiLUMulFunction.apply(a, b)
196
+
197
+
198
def liger_dyt(x, alpha, gamma, beta):
    """Functional entry point: apply `LigerDyTFunction` to (x, alpha, gamma, beta)."""
    dyt_apply = LigerDyTFunction.apply
    return dyt_apply(x, alpha, gamma, beta)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.5.dev20250322021112
3
+ Version: 0.5.5.dev20250324181221
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -17,6 +17,7 @@ liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmM
17
17
  liger_kernel/chunked_loss/simpo_loss.py,sha256=fy2w8KbhMrBv7b1jdIeH3bBFxY52bPQPZb3KwBvmurM,5385
18
18
  liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
19
19
  liger_kernel/ops/cross_entropy.py,sha256=yKKhN63I7r9NxJye4wTLBvvKAyrXQt6jf4nBo3lJyVg,18860
20
+ liger_kernel/ops/dyt.py,sha256=YD1-buHz9VmIX838VKzLc-lm5CeUQ4LAskGDWBUMQHA,6187
20
21
  liger_kernel/ops/fused_linear_cross_entropy.py,sha256=1Y3Uk_TCSjqKgoG2eot1ptnWXJXXQESqGvOmqAW1gsM,10912
21
22
  liger_kernel/ops/fused_linear_jsd.py,sha256=Seshez2qaM6HiTQ8_HEqSwhaeVruNT1SvIM4ZrAPBEU,9602
22
23
  liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
@@ -32,10 +33,11 @@ liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
32
33
  liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
33
34
  liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
34
35
  liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
35
- liger_kernel/transformers/__init__.py,sha256=34zWr2C9sg0H5ok6l-pNB6eZQr5W1w8Xl_hMbQWzEUY,2460
36
+ liger_kernel/transformers/__init__.py,sha256=eGCDpnvIBX7bhE_jGo5RRBipwT62WE_obzlniedNzt8,2525
36
37
  liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
37
38
  liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
38
- liger_kernel/transformers/functional.py,sha256=ShLD3eb--XKNtllznCrOYTbo4f-1KVwzi0KLMICdrn4,4942
39
+ liger_kernel/transformers/dyt.py,sha256=QMqqc14pkE0WhpRZvapfnNAun-6C0C_tHExL2ZJuCUA,648
40
+ liger_kernel/transformers/functional.py,sha256=4h9Pdx_iINBqfv2Zod_c27qOpYXDDwbdVgatQ9_XBmI,5089
39
41
  liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=09Rt7FZzLH42VOcIbQ4dlQd0o3Rlb4vk6fqiOQ7WTD8,1778
40
42
  liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
41
43
  liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
@@ -69,9 +71,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
69
71
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
70
72
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
71
73
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
72
- liger_kernel_nightly-0.5.5.dev20250322021112.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
73
- liger_kernel_nightly-0.5.5.dev20250322021112.dist-info/METADATA,sha256=t43xPpqj832O9qh6D1IfCqO06dNiO8VVJEUMvDzgs6U,22959
74
- liger_kernel_nightly-0.5.5.dev20250322021112.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
75
- liger_kernel_nightly-0.5.5.dev20250322021112.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
76
- liger_kernel_nightly-0.5.5.dev20250322021112.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
77
- liger_kernel_nightly-0.5.5.dev20250322021112.dist-info/RECORD,,
74
+ liger_kernel_nightly-0.5.5.dev20250324181221.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
75
+ liger_kernel_nightly-0.5.5.dev20250324181221.dist-info/METADATA,sha256=NyKmdw6KevABFKKrqEdmIf8agklqARr8azTzS4RRx0k,22959
76
+ liger_kernel_nightly-0.5.5.dev20250324181221.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
77
+ liger_kernel_nightly-0.5.5.dev20250324181221.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
78
+ liger_kernel_nightly-0.5.5.dev20250324181221.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
79
+ liger_kernel_nightly-0.5.5.dev20250324181221.dist-info/RECORD,,