liger-kernel 0.5.10-py3-none-any.whl → 0.6.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
  3. liger_kernel/chunked_loss/functional.py +2 -0
  4. liger_kernel/ops/dyt.py +0 -2
  5. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  6. liger_kernel/ops/geglu.py +1 -1
  7. liger_kernel/ops/multi_token_attention.py +207 -0
  8. liger_kernel/ops/rms_norm.py +265 -54
  9. liger_kernel/ops/softmax.py +201 -0
  10. liger_kernel/ops/sparsemax.py +62 -50
  11. liger_kernel/ops/swiglu.py +1 -1
  12. liger_kernel/transformers/__init__.py +3 -0
  13. liger_kernel/transformers/functional.py +62 -0
  14. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  15. liger_kernel/transformers/model/gemma.py +25 -8
  16. liger_kernel/transformers/model/gemma2.py +27 -8
  17. liger_kernel/transformers/model/gemma3.py +62 -98
  18. liger_kernel/transformers/model/glm4.py +16 -7
  19. liger_kernel/transformers/model/llama.py +25 -7
  20. liger_kernel/transformers/model/llama4.py +108 -0
  21. liger_kernel/transformers/model/llava.py +95 -124
  22. liger_kernel/transformers/model/mistral.py +13 -8
  23. liger_kernel/transformers/model/mixtral.py +16 -7
  24. liger_kernel/transformers/model/mllama.py +16 -7
  25. liger_kernel/transformers/model/olmo2.py +16 -7
  26. liger_kernel/transformers/model/paligemma.py +8 -1
  27. liger_kernel/transformers/model/phi3.py +25 -8
  28. liger_kernel/transformers/model/qwen2.py +24 -7
  29. liger_kernel/transformers/model/qwen2_5_vl.py +41 -91
  30. liger_kernel/transformers/model/qwen2_vl.py +38 -100
  31. liger_kernel/transformers/model/qwen3.py +11 -3
  32. liger_kernel/transformers/model/qwen3_moe.py +10 -6
  33. liger_kernel/transformers/monkey_patch.py +304 -70
  34. liger_kernel/transformers/multi_token_attention.py +64 -0
  35. liger_kernel/transformers/rms_norm.py +40 -4
  36. liger_kernel/transformers/softmax.py +12 -0
  37. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/METADATA +8 -2
  38. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/RECORD +42 -35
  39. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/WHEEL +1 -1
  40. liger_kernel/transformers/gema3_rms.py +0 -8
  41. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/licenses/LICENSE +0 -0
  42. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/licenses/NOTICE +0 -0
  43. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/top_level.txt +0 -0
liger_kernel/ops/softmax.py (new file)
@@ -0,0 +1,201 @@
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import ensure_contiguous
+
+
+@triton.jit
+def _softmax_single_block_forward_kernel(
+    Y_ptr,
+    Y_row_stride,
+    X_ptr,
+    X_row_stride,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    row_id = tl.program_id(0)
+    offs = tl.arange(0, BLOCK_SIZE)
+    mask = offs < n_cols
+
+    x = tl.load(X_ptr + row_id * X_row_stride + offs, mask=mask, other=-float("inf"), cache_modifier=".ca")
+    m = tl.max(x, axis=0)
+    e = tl.exp(x - m)
+    d = tl.sum(e, axis=0)
+    y = e / d
+    tl.store(Y_ptr + row_id * Y_row_stride + offs, y, mask=mask, cache_modifier=".cs")
+
+
+@triton.jit
+def _softmax_multi_block_forward_kernel(
+    Y_ptr,
+    Y_row_stride,
+    X_ptr,
+    X_row_stride,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    row_id = tl.program_id(0)
+    offs = tl.arange(0, BLOCK_SIZE)
+
+    m = tl.float32(-float("inf"))
+    d = tl.float32(0.0)
+    for start in tl.range(0, n_cols, BLOCK_SIZE):
+        idx = start + offs
+        mask = idx < n_cols
+        xblk = tl.load(X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca")
+        blk_max = tl.max(xblk, axis=0)
+        new_m = tl.max(m, blk_max)
+        d = d * tl.exp(m - new_m) + tl.sum(tl.exp(xblk - new_m), axis=0)
+        m = new_m
+
+    for start in tl.range(0, n_cols, BLOCK_SIZE):
+        idx = start + offs
+        mask = idx < n_cols
+        xblk = tl.load(X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca")
+        yblk = tl.exp(xblk - m) / d
+        tl.store(Y_ptr + row_id * Y_row_stride + idx, yblk, mask=mask, cache_modifier=".cs")
+
+
+@triton.jit
+def _softmax_single_block_backward_kernel(
+    dy_ptr,
+    dy_stride,
+    y_ptr,
+    y_stride,
+    dx_ptr,
+    dx_stride,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    row_id = tl.program_id(0)
+    offs = tl.arange(0, BLOCK_SIZE)
+    mask = offs < n_cols
+
+    dy = tl.load(dy_ptr + row_id * dy_stride + offs, mask=mask, other=0.0)
+    y = tl.load(y_ptr + row_id * y_stride + offs, mask=mask, other=0.0, cache_modifier=".ca")
+    dot = tl.sum(dy * y, axis=0)
+    dx = y * (dy - dot)
+    tl.store(dx_ptr + row_id * dx_stride + offs, dx, mask=mask, cache_modifier=".wb")
+
+
+@triton.jit
+def _softmax_multi_block_backward_kernel(
+    dy_ptr,
+    dy_stride,
+    y_ptr,
+    y_stride,
+    dx_ptr,
+    dx_stride,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    row_id = tl.program_id(0)
+    offs = tl.arange(0, BLOCK_SIZE)
+    acc = tl.float32(0.0)
+
+    for start in tl.range(0, n_cols, BLOCK_SIZE):
+        idx = start + offs
+        mask = idx < n_cols
+        dy_blk = tl.load(dy_ptr + row_id * dy_stride + idx, mask=mask, other=0.0)
+        y_blk = tl.load(y_ptr + row_id * y_stride + idx, mask=mask, other=0.0, cache_modifier=".ca")
+        acc += tl.sum(dy_blk * y_blk, axis=0)
+
+    for start in tl.range(0, n_cols, BLOCK_SIZE):
+        idx = start + offs
+        mask = idx < n_cols
+        dy_blk = tl.load(dy_ptr + row_id * dy_stride + idx, mask=mask, other=0.0)
+        y_blk = tl.load(y_ptr + row_id * y_stride + idx, mask=mask, other=0.0, cache_modifier=".ca")
+        dx_blk = y_blk * (dy_blk - acc)
+        tl.store(dx_ptr + row_id * dx_stride + idx, dx_blk, mask=mask, cache_modifier=".wb")
+
+
+def _softmax_forward(x: torch.Tensor) -> Tuple[torch.Tensor, int, int, bool]:
+    *batch, n_cols = x.shape
+    x2d = x.contiguous().view(-1, n_cols)
+    n_rows = x2d.shape[0]
+
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+    y2d = torch.empty_like(x2d)
+
+    if n_cols <= BLOCK_SIZE:
+        _softmax_single_block_forward_kernel[(n_rows,)](
+            y2d, y2d.stride(0), x2d, x2d.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
+        )
+        multi_block_launch = False
+    else:
+        _softmax_multi_block_forward_kernel[(n_rows,)](
+            y2d, y2d.stride(0), x2d, x2d.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
+        )
+        multi_block_launch = True
+
+    return y2d.view(*batch, n_cols), BLOCK_SIZE, num_warps, multi_block_launch
+
+
+def _softmax_backward(
+    dy: torch.Tensor,
+    y: torch.Tensor,
+    BLOCK_SIZE: int,
+    num_warps: int,
+    multi_block_launch: bool,
+) -> torch.Tensor:
+    *batch, n_cols = dy.shape
+    dy2d = dy.contiguous().view(-1, n_cols)
+    y2d = y.contiguous().view(-1, n_cols)
+    n_rows = dy2d.shape[0]
+    dx2d = torch.empty_like(dy2d)
+
+    if not multi_block_launch and n_cols <= BLOCK_SIZE:
+        _softmax_single_block_backward_kernel[(n_rows,)](
+            dy2d,
+            dy2d.stride(0),
+            y2d,
+            y2d.stride(0),
+            dx2d,
+            dx2d.stride(0),
+            n_cols,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+        )
+    else:
+        _softmax_multi_block_backward_kernel[(n_rows,)](
+            dy2d,
+            dy2d.stride(0),
+            y2d,
+            y2d.stride(0),
+            dx2d,
+            dx2d.stride(0),
+            n_cols,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+        )
+
+    return dx2d.view(*batch, n_cols)
+
+
+class LigerSoftmaxFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, input_: torch.Tensor):
+        y, BLOCK_SIZE, num_warps, multi_block_launch = _softmax_forward(input_)
+        ctx.save_for_backward(y)
+        ctx.BLOCK_SIZE = BLOCK_SIZE
+        ctx.num_warps = num_warps
+        ctx.multi_block_launch = multi_block_launch
+        return y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, grad_output):
+        (y,) = ctx.saved_tensors
+        dx = _softmax_backward(
+            grad_output,
+            y,
+            ctx.BLOCK_SIZE,
+            ctx.num_warps,
+            ctx.multi_block_launch,
+        )
+        return dx
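Note: the multi-block forward kernel above is an online (streaming) softmax. The first loop keeps a running row maximum m and a rescaled denominator d; the second loop normalizes each block with the final m and d. The plain-PyTorch sketch below mirrors that two-pass scheme for reference only; it is illustrative and not part of the package.

import torch

def streaming_softmax_reference(x: torch.Tensor, block: int = 4) -> torch.Tensor:
    # Pass 1: maintain running max m and rescaled denominator d per row.
    *batch, n = x.shape
    x2d = x.reshape(-1, n).float()
    m = torch.full((x2d.size(0),), float("-inf"), device=x2d.device)
    d = torch.zeros(x2d.size(0), device=x2d.device)
    for start in range(0, n, block):
        blk = x2d[:, start:start + block]
        new_m = torch.maximum(m, blk.max(dim=-1).values)
        d = d * torch.exp(m - new_m) + torch.exp(blk - new_m[:, None]).sum(dim=-1)
        m = new_m
    # Pass 2: normalize each block with the final statistics.
    out = torch.empty_like(x2d)
    for start in range(0, n, block):
        blk = x2d[:, start:start + block]
        out[:, start:start + block] = torch.exp(blk - m[:, None]) / d[:, None]
    return out.reshape(*batch, n)

# Sanity check against the eager implementation.
x = torch.randn(2, 10)
assert torch.allclose(streaming_softmax_reference(x), torch.softmax(x, dim=-1), atol=1e-6)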
liger_kernel/ops/sparsemax.py
@@ -1,3 +1,5 @@
+from typing import Tuple
+
 import torch
 import triton
 import triton.language as tl
@@ -105,63 +107,73 @@ def _sparsemax_backward_kernel(
         tl.store(gi_row + offs_iter, gi_val.to(gi_row.dtype.element_ty), mask=mask_iter, cache_modifier=".wb")


+def _sparsemax_forward(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    if dim < 0:
+        dim += x.dim()
+    x_sw = x.transpose(dim, -1).contiguous()
+    n_cols = x_sw.size(-1)
+    n_rows = x_sw.numel() // n_cols
+    x_flat = x_sw.view(n_rows, n_cols)
+    x_sorted_flat = torch.sort(x_flat.float(), dim=-1, descending=True).values
+
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+    out_flat = torch.empty_like(x_flat)
+    grid = (n_rows,)
+    _sparsemax_forward_kernel[grid](
+        x_flat,
+        x_flat.stride(0),
+        x_sorted_flat,
+        x_sorted_flat.stride(0),
+        out_flat,
+        out_flat.stride(0),
+        n_cols,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+
+    y = out_flat.view_as(x_sw).transpose(dim, -1)
+    return y, out_flat
+
+
+def _sparsemax_backward(
+    grad_out: torch.Tensor,
+    out_flat: torch.Tensor,
+    dim: int,
+) -> torch.Tensor:
+    grad_sw = grad_out.transpose(dim, -1).contiguous()
+    n_cols = grad_sw.size(-1)
+    n_rows = grad_sw.numel() // n_cols
+    go_flat = grad_sw.view(n_rows, n_cols)
+
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+    dx_flat = torch.empty_like(go_flat)
+    grid = (n_rows,)
+    _sparsemax_backward_kernel[grid](
+        out_flat,
+        go_flat,
+        dx_flat,
+        out_flat.stride(0),
+        n_cols,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+
+    dx = dx_flat.view_as(grad_sw).transpose(dim, -1)
+    return dx
+
+
 class LigerSparsemaxFunction(torch.autograd.Function):
     @staticmethod
     @ensure_contiguous
     def forward(ctx, x: torch.Tensor, dim: int):
-        if dim < 0:
-            dim += x.dim()
-        ctx.dim = dim
-
-        x_sw = x.transpose(dim, -1).contiguous()
-        n_cols = x_sw.size(-1)
-        n_rows = x_sw.numel() // n_cols
-        x_flat = x_sw.view(n_rows, n_cols)
-
-        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-        out_flat = torch.empty_like(x_flat)
-        grid = (n_rows,)
-
-        x_sorted_flat = torch.sort(x_flat.float(), dim=-1, descending=True).values
-
-        _sparsemax_forward_kernel[grid](
-            x_flat,
-            x_flat.stride(0),
-            x_sorted_flat,
-            x_sorted_flat.stride(0),
-            out_flat,
-            out_flat.stride(0),
-            n_cols,
-            BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=num_warps,
-        )
-
+        y, out_flat = _sparsemax_forward(x, dim)
         ctx.save_for_backward(out_flat)
-        return out_flat.view_as(x_sw).transpose(dim, -1)
+        ctx.dim = dim
+        return y

     @staticmethod
     @ensure_contiguous
     def backward(ctx, grad_out: torch.Tensor):
         (out_flat,) = ctx.saved_tensors
-        dim = ctx.dim
-
-        go_sw = grad_out.transpose(dim, -1).contiguous()
-        n_cols = go_sw.size(-1)
-        n_rows = go_sw.numel() // n_cols
-        go_flat = go_sw.view(n_rows, n_cols)
-
-        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-        gi_flat = torch.empty_like(go_flat)
-        grid = (n_rows,)
-
-        _sparsemax_backward_kernel[grid](
-            out_flat,
-            go_flat,
-            gi_flat,
-            out_flat.stride(0),
-            n_cols,
-            BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=num_warps,
-        )
-
-        return gi_flat.view_as(go_sw).transpose(dim, -1), None
+        dx = _sparsemax_backward(grad_out, out_flat, ctx.dim)
+        return dx, None
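Note: sparsemax is the Euclidean projection of each slice onto the probability simplex (Martins & Astudillo, 2016), which is what `_sparsemax_forward` computes from the sorted input. The eager PyTorch sketch below is a reference implementation of that projection for comparison; it is illustrative only and not part of the package.

import torch

def sparsemax_reference(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Sort descending, find the support size k(z), derive the threshold tau,
    # then clamp: output = max(x - tau, 0).
    x = x.transpose(dim, -1)
    z, _ = torch.sort(x, dim=-1, descending=True)
    k = torch.arange(1, x.size(-1) + 1, device=x.device, dtype=x.dtype)
    cumsum = z.cumsum(dim=-1)
    support = 1 + k * z > cumsum                      # sorted entries that stay nonzero
    k_support = support.sum(dim=-1, keepdim=True)
    tau = (cumsum.gather(-1, k_support - 1) - 1) / k_support.to(x.dtype)
    return torch.clamp(x - tau, min=0).transpose(dim, -1)

p = sparsemax_reference(torch.tensor([[1.0, 2.0, 0.1]]))
# rows sum to 1 and small logits are driven exactly to zero: [[0., 1., 0.]]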
liger_kernel/ops/swiglu.py
@@ -26,7 +26,7 @@ def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BL
     # sigmoid requires type float32
     a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
     b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
-    c_row = silu(a_row) * b_row
+    c_row = silu(a_row).cast(b_row.dtype) * b_row
     tl.store(c_ptr + col_offsets, c_row, mask=mask)


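Note: this one-line change casts the float32 SiLU activation back to the gate tensor's dtype before the elementwise product. An eager equivalent, shown only to illustrate the intended numerics:

import torch
import torch.nn.functional as F

def swiglu_mul_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # SiLU is computed in float32 for accuracy, then cast back to b's dtype
    # before the multiply, matching the kernel change above.
    return F.silu(a.float()).to(b.dtype) * b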
liger_kernel/transformers/__init__.py
@@ -30,6 +30,7 @@ if TYPE_CHECKING:
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral  # noqa: F401
@@ -87,6 +88,7 @@ def __getattr__(name: str):
         "apply_liger_kernel_to_granite",
         "apply_liger_kernel_to_llama",
         "apply_liger_kernel_to_llava",
+        "apply_liger_kernel_to_llama4",
         "apply_liger_kernel_to_mistral",
         "apply_liger_kernel_to_mixtral",
         "apply_liger_kernel_to_mllama",
@@ -141,6 +143,7 @@ if _TRANSFORMERS_AVAILABLE:
         "apply_liger_kernel_to_granite",
         "apply_liger_kernel_to_llama",
         "apply_liger_kernel_to_llava",
+        "apply_liger_kernel_to_llama4",
         "apply_liger_kernel_to_mistral",
         "apply_liger_kernel_to_mixtral",
         "apply_liger_kernel_to_mllama",
liger_kernel/transformers/functional.py
@@ -4,14 +4,17 @@ from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
 from liger_kernel.ops.dyt import LigerDyTFunction
 from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
 from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
+from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction
 from liger_kernel.ops.geglu import LigerGELUMulFunction
 from liger_kernel.ops.group_norm import LigerGroupNormFunction
 from liger_kernel.ops.jsd import LigerJSDFunction
 from liger_kernel.ops.kl_div import LigerKLDivLossFunction
 from liger_kernel.ops.layer_norm import LigerLayerNormFunction
+from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction
 from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
 from liger_kernel.ops.rms_norm import LigerRMSNormFunction
 from liger_kernel.ops.rope import LigerRopeFunction
+from liger_kernel.ops.softmax import LigerSoftmaxFunction
 from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
 from liger_kernel.ops.swiglu import LigerSiLUMulFunction
 from liger_kernel.ops.tvd import LigerTVDLossFunction
@@ -167,6 +170,61 @@ def liger_sparsemax(
     return LigerSparsemaxFunction.apply(input, dim)


+def liger_multi_token_attention(
+    scores,
+    weight,
+    bias=None,
+    stride: int = 1,
+    padding: int = 0,
+    dilation: int = 1,
+    groups: int = 1,
+    sparse: bool = False,
+):
+    """
+    Functional interface for multi-token attention.
+
+    Args:
+        scores: Input tensor of shape (B, C_in, L, L)
+        weight: Convolution weight tensor of shape (C_out, C_in // groups, K, K)
+        bias: Optional bias tensor of shape (C_out,)
+        stride: Stride for the convolution (default: 1)
+        padding: Padding for the convolution (default: 0)
+        dilation: Dilation factor for the convolution (default: 1)
+        groups: Number of groups for the convolution (default: 1)
+        sparse: Specifies if input tensors are expected to be sparse (default: False)
+    Returns:
+        Output tensor after applying multi-token attention.
+    """
+    return LigerMultiTokenAttentionFunction.apply(scores, weight, bias, stride, padding, dilation, groups, sparse)
+
+
+def liger_fused_neighborhood_attention(
+    query,
+    key,
+    value,
+    kernel_size: int = 7,
+    dilation: int = 1,
+    scale: float = None,
+):
+    """
+    Liger fused neighborhood attention.
+
+    paper: https://arxiv.org/pdf/2504.16922
+
+    Args:
+        query: Query tensor of shape [batch_size, num_heads, seq_len, head_dim]
+        key: Key tensor of shape [batch_size, num_heads, seq_len, head_dim]
+        value: Value tensor of shape [batch_size, num_heads, seq_len, head_dim]
+        kernel_size: Size of the neighborhood window (default: 7)
+        dilation: Dilation factor for the neighborhood (default: 1)
+        scale: Scaling factor for attention scores (default: rsqrt(head_dim))
+
+    Returns:
+        Output tensor of shape [batch_size, num_heads, seq_len, head_dim]
+    """
+    return LigerFusedNeighborhoodAttentionFunction.apply(query, key, value, kernel_size, dilation, scale)
+
+
 def liger_tvd(
     input,
     target,
@@ -203,5 +261,9 @@ def liger_swiglu(a, b):
     return LigerSiLUMulFunction.apply(a, b)


+def liger_softmax(x):
+    return LigerSoftmaxFunction.apply(x)
+
+
 def liger_dyt(x, alpha, gamma, beta):
     return LigerDyTFunction.apply(x, alpha, gamma, beta)
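Note: a shape-level sketch of the new functional wrappers, using the tensor shapes from the docstrings above. It is illustrative only; the underlying kernels are Triton, so a CUDA device is assumed.

import torch
from liger_kernel.transformers.functional import (
    liger_fused_neighborhood_attention,
    liger_multi_token_attention,
    liger_softmax,
)

device = "cuda"
x = torch.randn(2, 8, 128, device=device)
probs = liger_softmax(x)                              # softmax over the last dimension

scores = torch.randn(2, 4, 64, 64, device=device)     # (B, C_in, L, L)
weight = torch.randn(4, 4, 3, 3, device=device)       # (C_out, C_in // groups, K, K)
out = liger_multi_token_attention(scores, weight, padding=1)

q = k = v = torch.randn(2, 8, 64, 32, device=device)  # [batch, heads, seq, head_dim]
attn = liger_fused_neighborhood_attention(q, k, v, kernel_size=7)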
liger_kernel/transformers/fused_neighborhood_attention.py (new file)
@@ -0,0 +1,234 @@
+import math
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction
+
+
+class LigerFusedNeighborhoodAttention(nn.Module):
+    """
+    Liger Fused Neighborhood Attention Module.
+
+    Paper: https://arxiv.org/pdf/2504.16922
+
+    Fused Neighborhood attention restricts the attention mechanism to a local neighborhood
+    around each position, reducing computational complexity from O(n²) to O(n*k)
+    where k is the neighborhood size.
+
+    Args:
+        hidden_size (int): The hidden dimension size
+        num_heads (int): Number of attention heads
+        kernel_size (int): Size of the neighborhood window (default: 7)
+        dilation (int): Dilation factor for the neighborhood (default: 1)
+        bias (bool): Whether to use bias in linear projections (default: True)
+        dropout (float): Dropout probability (default: 0.0)
+        scale (Optional[float]): Scaling factor for attention scores.
+            If None, uses 1/sqrt(head_dim) (default: None)
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        kernel_size: int = 7,
+        dilation: int = 1,
+        bias: bool = True,
+        dropout: float = 0.0,
+        scale: Optional[float] = None,
+    ):
+        super().__init__()
+
+        if hidden_size % num_heads != 0:
+            raise ValueError(f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})")
+
+        if kernel_size <= 0:
+            raise ValueError(f"kernel_size ({kernel_size}) must be positive")
+
+        if kernel_size % 2 == 0:
+            raise ValueError(f"kernel_size ({kernel_size}) must be odd")
+
+        if dilation < 1:
+            raise ValueError(f"dilation ({dilation}) must be positive")
+
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.head_dim = hidden_size // num_heads
+        self.kernel_size = kernel_size
+        self.dilation = dilation
+        self.scale = scale if scale is not None else 1.0 / math.sqrt(self.head_dim)
+        self.dropout_p = dropout
+
+        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+
+        self.out_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+
+        if dropout > 0.0:
+            self.dropout = nn.Dropout(dropout)
+        else:
+            self.dropout = None
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Forward pass of the fused neighborhood attention module.
+
+        Args:
+            hidden_states (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
+            attention_mask (Optional[torch.Tensor]): Attention mask (currently not supported)
+
+        Returns:
+            torch.Tensor: Output tensor of shape [batch_size, seq_len, hidden_size]
+        """
+        if attention_mask is not None:
+            raise NotImplementedError("Attention mask is not yet supported in LigerFusedNeighborhoodAttention")
+
+        batch_size, seq_len, hidden_size = hidden_states.shape
+
+        query = self.q_proj(hidden_states)
+        key = self.k_proj(hidden_states)
+        value = self.v_proj(hidden_states)
+
+        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        attn_output = LigerFusedNeighborhoodAttentionFunction.apply(
+            query, key, value, self.kernel_size, self.dilation, self.scale
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)
+
+        if self.dropout is not None:
+            attn_output = self.dropout(attn_output)
+
+        output = self.out_proj(attn_output)
+
+        return output
+
+    def extra_repr(self) -> str:
+        return (
+            f"hidden_size={self.hidden_size}, num_heads={self.num_heads}, "
+            f"head_dim={self.head_dim}, kernel_size={self.kernel_size}, "
+            f"dilation={self.dilation}, scale={self.scale}, dropout={self.dropout_p}"
+        )
+
+
+class LigerFusedNeighborhoodAttentionLayer(nn.Module):
+    """
+    A complete neighborhood attention layer with layer norm and residual connection.
+
+    Args:
+        hidden_size (int): The hidden dimension size
+        num_heads (int): Number of attention heads
+        kernel_size (int): Size of the neighborhood window (default: 7)
+        dilation (int): Dilation factor for the neighborhood (default: 1)
+        bias (bool): Whether to use bias in linear projections (default: True)
+        dropout (float): Dropout probability (default: 0.0)
+        layer_norm_eps (float): Epsilon for layer normalization (default: 1e-5)
+        scale (Optional[float]): Scaling factor for attention scores (default: None)
+    """

+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        kernel_size: int = 7,
+        dilation: int = 1,
+        bias: bool = True,
+        dropout: float = 0.0,
+        layer_norm_eps: float = 1e-5,
+        scale: Optional[float] = None,
+    ):
+        super().__init__()
+
+        self.attention = LigerFusedNeighborhoodAttention(
+            hidden_size=hidden_size,
+            num_heads=num_heads,
+            kernel_size=kernel_size,
+            dilation=dilation,
+            bias=bias,
+            dropout=dropout,
+            scale=scale,
+        )
+
+        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+
+        if dropout > 0.0:
+            self.dropout = nn.Dropout(dropout)
+        else:
+            self.dropout = None
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Forward pass with residual connection and layer normalization.
+
+        Args:
+            hidden_states (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
+            attention_mask (Optional[torch.Tensor]): Attention mask (currently not supported)
+
+        Returns:
+            torch.Tensor: Output tensor of shape [batch_size, seq_len, hidden_size]
+        """
+        normed_hidden_states = self.layer_norm(hidden_states)
+
+        attn_output = self.attention(normed_hidden_states, attention_mask)
+
+        if self.dropout is not None:
+            attn_output = self.dropout(attn_output)
+
+        output = hidden_states + attn_output
+
+        return output
+
+
+class LigerFusedNeighborhoodAttentionConfig:
+    """
+    Configuration class for Fused Neighborhood Attention.
+
+    This can be used to easily configure neighborhood attention parameters
+    for different model architectures.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int = 768,
+        num_heads: int = 12,
+        kernel_size: int = 7,
+        dilation: int = 1,
+        bias: bool = True,
+        dropout: float = 0.0,
+        layer_norm_eps: float = 1e-5,
+        scale: Optional[float] = None,
+    ):
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.kernel_size = kernel_size
+        self.dilation = dilation
+        self.bias = bias
+        self.dropout = dropout
+        self.layer_norm_eps = layer_norm_eps
+        self.scale = scale
+
+    def to_dict(self):
+        return {
+            "hidden_size": self.hidden_size,
+            "num_heads": self.num_heads,
+            "kernel_size": self.kernel_size,
+            "dilation": self.dilation,
+            "bias": self.bias,
+            "dropout": self.dropout,
+            "layer_norm_eps": self.layer_norm_eps,
+            "scale": self.scale,
+        }
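Note: a minimal usage sketch of the modules defined above. It is illustrative only; a CUDA device is assumed since the underlying kernel is Triton.

import torch
from liger_kernel.transformers.fused_neighborhood_attention import (
    LigerFusedNeighborhoodAttention,
    LigerFusedNeighborhoodAttentionLayer,
)

# Standalone attention module: [batch, seq, hidden] in, [batch, seq, hidden] out.
attn = LigerFusedNeighborhoodAttention(hidden_size=256, num_heads=8, kernel_size=7).cuda()
x = torch.randn(2, 128, 256, device="cuda")
y = attn(x)                      # [2, 128, 256]

# Pre-norm attention block with residual connection.
layer = LigerFusedNeighborhoodAttentionLayer(hidden_size=256, num_heads=8).cuda()
z = layer(x)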