liger-kernel 0.5.10__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
  3. liger_kernel/chunked_loss/functional.py +2 -0
  4. liger_kernel/ops/dyt.py +0 -2
  5. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  6. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  7. liger_kernel/ops/geglu.py +1 -1
  8. liger_kernel/ops/layer_norm.py +126 -89
  9. liger_kernel/ops/multi_token_attention.py +207 -0
  10. liger_kernel/ops/rms_norm.py +267 -56
  11. liger_kernel/ops/rope.py +1 -1
  12. liger_kernel/ops/softmax.py +201 -0
  13. liger_kernel/ops/sparsemax.py +62 -50
  14. liger_kernel/ops/swiglu.py +1 -1
  15. liger_kernel/transformers/__init__.py +8 -0
  16. liger_kernel/transformers/functional.py +67 -0
  17. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  18. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  19. liger_kernel/transformers/model/gemma.py +25 -8
  20. liger_kernel/transformers/model/gemma2.py +27 -8
  21. liger_kernel/transformers/model/gemma3.py +63 -99
  22. liger_kernel/transformers/model/glm4.py +16 -7
  23. liger_kernel/transformers/model/llama.py +25 -7
  24. liger_kernel/transformers/model/llama4.py +108 -0
  25. liger_kernel/transformers/model/llava.py +95 -124
  26. liger_kernel/transformers/model/mistral.py +13 -8
  27. liger_kernel/transformers/model/mixtral.py +16 -7
  28. liger_kernel/transformers/model/mllama.py +16 -7
  29. liger_kernel/transformers/model/olmo2.py +16 -7
  30. liger_kernel/transformers/model/paligemma.py +8 -1
  31. liger_kernel/transformers/model/phi3.py +25 -8
  32. liger_kernel/transformers/model/qwen2.py +24 -7
  33. liger_kernel/transformers/model/qwen2_5_vl.py +41 -91
  34. liger_kernel/transformers/model/qwen2_vl.py +38 -100
  35. liger_kernel/transformers/model/qwen3.py +11 -3
  36. liger_kernel/transformers/model/qwen3_moe.py +10 -6
  37. liger_kernel/transformers/model/smollm3.py +189 -0
  38. liger_kernel/transformers/monkey_patch.py +389 -82
  39. liger_kernel/transformers/multi_token_attention.py +64 -0
  40. liger_kernel/transformers/rms_norm.py +40 -4
  41. liger_kernel/transformers/softmax.py +12 -0
  42. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/METADATA +18 -14
  43. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/RECORD +47 -37
  44. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/WHEEL +1 -1
  45. liger_kernel/transformers/gema3_rms.py +0 -8
  46. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/licenses/LICENSE +0 -0
  47. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/licenses/NOTICE +0 -0
  48. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,5 @@
+from typing import Tuple
+
 import torch
 import triton
 import triton.language as tl
@@ -105,63 +107,73 @@ def _sparsemax_backward_kernel(
         tl.store(gi_row + offs_iter, gi_val.to(gi_row.dtype.element_ty), mask=mask_iter, cache_modifier=".wb")


+def _sparsemax_forward(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    if dim < 0:
+        dim += x.dim()
+    x_sw = x.transpose(dim, -1).contiguous()
+    n_cols = x_sw.size(-1)
+    n_rows = x_sw.numel() // n_cols
+    x_flat = x_sw.view(n_rows, n_cols)
+    x_sorted_flat = torch.sort(x_flat.float(), dim=-1, descending=True).values
+
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+    out_flat = torch.empty_like(x_flat)
+    grid = (n_rows,)
+    _sparsemax_forward_kernel[grid](
+        x_flat,
+        x_flat.stride(0),
+        x_sorted_flat,
+        x_sorted_flat.stride(0),
+        out_flat,
+        out_flat.stride(0),
+        n_cols,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+
+    y = out_flat.view_as(x_sw).transpose(dim, -1)
+    return y, out_flat
+
+
+def _sparsemax_backward(
+    grad_out: torch.Tensor,
+    out_flat: torch.Tensor,
+    dim: int,
+) -> torch.Tensor:
+    grad_sw = grad_out.transpose(dim, -1).contiguous()
+    n_cols = grad_sw.size(-1)
+    n_rows = grad_sw.numel() // n_cols
+    go_flat = grad_sw.view(n_rows, n_cols)
+
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+    dx_flat = torch.empty_like(go_flat)
+    grid = (n_rows,)
+    _sparsemax_backward_kernel[grid](
+        out_flat,
+        go_flat,
+        dx_flat,
+        out_flat.stride(0),
+        n_cols,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+
+    dx = dx_flat.view_as(grad_sw).transpose(dim, -1)
+    return dx
+
+
 class LigerSparsemaxFunction(torch.autograd.Function):
     @staticmethod
     @ensure_contiguous
     def forward(ctx, x: torch.Tensor, dim: int):
-        if dim < 0:
-            dim += x.dim()
-        ctx.dim = dim
-
-        x_sw = x.transpose(dim, -1).contiguous()
-        n_cols = x_sw.size(-1)
-        n_rows = x_sw.numel() // n_cols
-        x_flat = x_sw.view(n_rows, n_cols)
-
-        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-        out_flat = torch.empty_like(x_flat)
-        grid = (n_rows,)
-
-        x_sorted_flat = torch.sort(x_flat.float(), dim=-1, descending=True).values
-
-        _sparsemax_forward_kernel[grid](
-            x_flat,
-            x_flat.stride(0),
-            x_sorted_flat,
-            x_sorted_flat.stride(0),
-            out_flat,
-            out_flat.stride(0),
-            n_cols,
-            BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=num_warps,
-        )
-
+        y, out_flat = _sparsemax_forward(x, dim)
         ctx.save_for_backward(out_flat)
-        return out_flat.view_as(x_sw).transpose(dim, -1)
+        ctx.dim = dim
+        return y

     @staticmethod
     @ensure_contiguous
     def backward(ctx, grad_out: torch.Tensor):
         (out_flat,) = ctx.saved_tensors
-        dim = ctx.dim
-
-        go_sw = grad_out.transpose(dim, -1).contiguous()
-        n_cols = go_sw.size(-1)
-        n_rows = go_sw.numel() // n_cols
-        go_flat = go_sw.view(n_rows, n_cols)
-
-        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-        gi_flat = torch.empty_like(go_flat)
-        grid = (n_rows,)
-
-        _sparsemax_backward_kernel[grid](
-            out_flat,
-            go_flat,
-            gi_flat,
-            out_flat.stride(0),
-            n_cols,
-            BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=num_warps,
-        )
-
-        return gi_flat.view_as(go_sw).transpose(dim, -1), None
+        dx = _sparsemax_backward(grad_out, out_flat, ctx.dim)
+        return dx, None
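
Note: the sparsemax refactor above moves the Triton kernel launches into standalone _sparsemax_forward and _sparsemax_backward helpers, so LigerSparsemaxFunction now only saves out_flat and dim for autograd. A minimal usage sketch of the public wrapper (CUDA device and tensor shapes are illustrative assumptions):

import torch

from liger_kernel.transformers.functional import liger_sparsemax

x = torch.randn(4, 128, device="cuda", requires_grad=True)
probs = liger_sparsemax(x, -1)   # rows sum to 1 with many exact zeros
probs.sum().backward()           # gradient path runs _sparsemax_backward_kernel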
@@ -26,7 +26,7 @@ def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BL
     # sigmoid requires type float32
     a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
     b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
-    c_row = silu(a_row) * b_row
+    c_row = silu(a_row).cast(b_row.dtype) * b_row
     tl.store(c_ptr + col_offsets, c_row, mask=mask)


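
Note: the one-line SwiGLU change keeps the SiLU computation in float32 but casts the result back to b_row's dtype before the multiply, so the stored output stays in the activation dtype instead of being promoted. A rough PyTorch equivalent, as a sketch only (the real kernel operates on Triton blocks):

import torch
import torch.nn.functional as F

def swiglu_rowwise_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    silu_a = F.silu(a.float())     # sigmoid/silu still computed in float32
    return silu_a.to(b.dtype) * b  # cast back before the elementwise product, as in the new kernel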
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
 # Always-safe imports (independent of 'transformers')
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss  # noqa: F401
 from liger_kernel.transformers.dyt import LigerDyT  # noqa: F401
+from liger_kernel.transformers.fused_add_rms_norm import LigerFusedAddRMSNorm  # noqa: F401
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss  # noqa: F401
 from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD  # noqa: F401
 from liger_kernel.transformers.geglu import LigerGEGLUMLP  # noqa: F401
@@ -30,6 +31,7 @@ if TYPE_CHECKING:
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral  # noqa: F401
@@ -42,6 +44,7 @@ if TYPE_CHECKING:
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smollm3  # noqa: F401


 # Check if 'transformers' is installed
@@ -87,6 +90,7 @@ def __getattr__(name: str):
         "apply_liger_kernel_to_granite",
         "apply_liger_kernel_to_llama",
         "apply_liger_kernel_to_llava",
+        "apply_liger_kernel_to_llama4",
         "apply_liger_kernel_to_mistral",
         "apply_liger_kernel_to_mixtral",
         "apply_liger_kernel_to_mllama",
@@ -98,6 +102,7 @@ def __getattr__(name: str):
         "apply_liger_kernel_to_qwen2_vl",
         "apply_liger_kernel_to_qwen3",
         "apply_liger_kernel_to_qwen3_moe",
+        "apply_liger_kernel_to_smollm3",
     }

     if name in monkey_patch_symbols:
@@ -117,6 +122,7 @@ __all__ = [
     "LigerGEGLUMLP",
     "LigerJSD",
     "LigerLayerNorm",
+    "LigerFusedAddRMSNorm",
     "LigerRMSNorm",
     "liger_rotary_pos_emb",
     "LigerBlockSparseTop2MLP",
@@ -141,6 +147,7 @@ if _TRANSFORMERS_AVAILABLE:
             "apply_liger_kernel_to_granite",
             "apply_liger_kernel_to_llama",
             "apply_liger_kernel_to_llava",
+            "apply_liger_kernel_to_llama4",
             "apply_liger_kernel_to_mistral",
             "apply_liger_kernel_to_mixtral",
             "apply_liger_kernel_to_mllama",
@@ -152,5 +159,6 @@ if _TRANSFORMERS_AVAILABLE:
             "apply_liger_kernel_to_qwen2_vl",
             "apply_liger_kernel_to_qwen3",
             "apply_liger_kernel_to_qwen3_moe",
+            "apply_liger_kernel_to_smollm3",
         ]
     )
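
Note: these __init__.py hunks register the new entry points (LigerFusedAddRMSNorm, apply_liger_kernel_to_llama4, apply_liger_kernel_to_smollm3) so they resolve through the lazy __getattr__ path when transformers is installed. A hedged sketch of how the new SmolLM3 patch would be pulled in; its keyword arguments are not shown in this diff, so defaults are assumed:

from liger_kernel.transformers import apply_liger_kernel_to_smollm3

# Patch the SmolLM3 modeling code with Liger kernels before the model is instantiated.
apply_liger_kernel_to_smollm3()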
@@ -2,16 +2,20 @@ from typing import Optional

 from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
 from liger_kernel.ops.dyt import LigerDyTFunction
+from liger_kernel.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction
 from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
 from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
+from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction
 from liger_kernel.ops.geglu import LigerGELUMulFunction
 from liger_kernel.ops.group_norm import LigerGroupNormFunction
 from liger_kernel.ops.jsd import LigerJSDFunction
 from liger_kernel.ops.kl_div import LigerKLDivLossFunction
 from liger_kernel.ops.layer_norm import LigerLayerNormFunction
+from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction
 from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
 from liger_kernel.ops.rms_norm import LigerRMSNormFunction
 from liger_kernel.ops.rope import LigerRopeFunction
+from liger_kernel.ops.softmax import LigerSoftmaxFunction
 from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
 from liger_kernel.ops.swiglu import LigerSiLUMulFunction
 from liger_kernel.ops.tvd import LigerTVDLossFunction
@@ -167,6 +171,61 @@ def liger_sparsemax(
     return LigerSparsemaxFunction.apply(input, dim)


+def liger_multi_token_attention(
+    scores,
+    weight,
+    bias=None,
+    stride: int = 1,
+    padding: int = 0,
+    dilation: int = 1,
+    groups: int = 1,
+    sparse: bool = False,
+):
+    """
+    Functional interface for multi-token attention.
+
+    Args:
+        scores: Input tensor of shape (B, C_in, L, L)
+        weight: Convolution weight tensor of shape (C_out, C_in // groups, K, K)
+        bias: Optional bias tensor of shape (C_out,)
+        stride: Stride for the convolution (default: 1)
+        padding: Padding for the convolution (default: 0)
+        dilation: Dilation factor for the convolution (default: 1)
+        groups: Number of groups for the convolution (default: 1)
+        sparse: Specifies if input tensors are expected to be sparse (default: False)
+    Returns:
+        Output tensor after applying multi-token attention.
+    """
+    return LigerMultiTokenAttentionFunction.apply(scores, weight, bias, stride, padding, dilation, groups, sparse)
+
+
+def liger_fused_neighborhood_attention(
+    query,
+    key,
+    value,
+    kernel_size: int = 7,
+    dilation: int = 1,
+    scale: float = None,
+):
+    """
+    Liger fused neighborhood attention.
+
+    paper: https://arxiv.org/pdf/2504.16922
+
+    Args:
+        query: Query tensor of shape [batch_size, num_heads, seq_len, head_dim]
+        key: Key tensor of shape [batch_size, num_heads, seq_len, head_dim]
+        value: Value tensor of shape [batch_size, num_heads, seq_len, head_dim]
+        kernel_size: Size of the neighborhood window (default: 7)
+        dilation: Dilation factor for the neighborhood (default: 1)
+        scale: Scaling factor for attention scores (default: rsqrt(head_dim))
+
+    Returns:
+        Output tensor of shape [batch_size, num_heads, seq_len, head_dim]
+    """
+    return LigerFusedNeighborhoodAttentionFunction.apply(query, key, value, kernel_size, dilation, scale)
+
+
 def liger_tvd(
     input,
     target,
@@ -195,6 +254,10 @@ def liger_rms_norm(X, W, eps, offset: float = 0.0, casting_mode: str = "llama",
     return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)


+def liger_fused_add_rms_norm(X, R, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True):
+    return LigerFusedAddRMSNormFunction.apply(X, R, W, eps, offset, casting_mode, in_place)
+
+
 def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)

@@ -203,5 +266,9 @@ def liger_swiglu(a, b):
     return LigerSiLUMulFunction.apply(a, b)


+def liger_softmax(x):
+    return LigerSoftmaxFunction.apply(x)
+
+
 def liger_dyt(x, alpha, gamma, beta):
     return LigerDyTFunction.apply(x, alpha, gamma, beta)
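
Note: a short sketch exercising two of the new functional wrappers. The shapes follow the docstrings above; CUDA tensors are assumed because the wrappers dispatch to Triton kernels, and the only keyword used here (padding) is one that appears in this diff:

import torch

from liger_kernel.transformers.functional import liger_multi_token_attention, liger_softmax

scores = torch.randn(2, 8, 64, 64, device="cuda")  # (B, C_in, L, L)
weight = torch.randn(8, 8, 3, 3, device="cuda")    # (C_out, C_in // groups, K, K)
bias = torch.randn(8, device="cuda")               # (C_out,)

mta_out = liger_multi_token_attention(scores, weight, bias, padding=1)  # padding=1 keeps L with K=3
probs = liger_softmax(torch.randn(32, 1024, device="cuda"))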
@@ -0,0 +1,39 @@
+import torch
+import torch.nn as nn
+
+from liger_kernel.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction
+
+
+class LigerFusedAddRMSNorm(nn.Module):
+    def __init__(
+        self,
+        hidden_size,
+        eps=1e-6,
+        offset=0.0,
+        casting_mode="llama",
+        init_fn="ones",
+        in_place=False,
+    ):
+        super().__init__()
+        assert init_fn in [
+            "ones",
+            "zeros",
+        ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
+        self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
+        self.variance_epsilon, self.offset, self.casting_mode, self.in_place = (eps, offset, casting_mode, in_place)
+
+    def forward(self, hidden_states, residual):
+        return LigerFusedAddRMSNormFunction.apply(
+            hidden_states,
+            residual,
+            self.weight,
+            self.variance_epsilon,
+            self.offset,
+            self.casting_mode,
+            self.in_place,
+        )
+
+    def extra_repr(self):
+        return (
+            f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
+        )
@@ -0,0 +1,234 @@
+import math
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction
+
+
+class LigerFusedNeighborhoodAttention(nn.Module):
+    """
+    Liger Fused Neighborhood Attention Module.
+
+    Paper: https://arxiv.org/pdf/2504.16922
+
+    Fused Neighborhood attention restricts the attention mechanism to a local neighborhood
+    around each position, reducing computational complexity from O(n²) to O(n*k)
+    where k is the neighborhood size.
+
+    Args:
+        hidden_size (int): The hidden dimension size
+        num_heads (int): Number of attention heads
+        kernel_size (int): Size of the neighborhood window (default: 7)
+        dilation (int): Dilation factor for the neighborhood (default: 1)
+        bias (bool): Whether to use bias in linear projections (default: True)
+        dropout (float): Dropout probability (default: 0.0)
+        scale (Optional[float]): Scaling factor for attention scores.
+            If None, uses 1/sqrt(head_dim) (default: None)
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        kernel_size: int = 7,
+        dilation: int = 1,
+        bias: bool = True,
+        dropout: float = 0.0,
+        scale: Optional[float] = None,
+    ):
+        super().__init__()
+
+        if hidden_size % num_heads != 0:
+            raise ValueError(f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})")
+
+        if kernel_size <= 0:
+            raise ValueError(f"kernel_size ({kernel_size}) must be positive")
+
+        if kernel_size % 2 == 0:
+            raise ValueError(f"kernel_size ({kernel_size}) must be odd")
+
+        if dilation < 1:
+            raise ValueError(f"dilation ({dilation}) must be positive")
+
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.head_dim = hidden_size // num_heads
+        self.kernel_size = kernel_size
+        self.dilation = dilation
+        self.scale = scale if scale is not None else 1.0 / math.sqrt(self.head_dim)
+        self.dropout_p = dropout
+
+        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+
+        self.out_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+
+        if dropout > 0.0:
+            self.dropout = nn.Dropout(dropout)
+        else:
+            self.dropout = None
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Forward pass of the fused neighborhood attention module.
+
+        Args:
+            hidden_states (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
+            attention_mask (Optional[torch.Tensor]): Attention mask (currently not supported)
+
+        Returns:
+            torch.Tensor: Output tensor of shape [batch_size, seq_len, hidden_size]
+        """
+        if attention_mask is not None:
+            raise NotImplementedError("Attention mask is not yet supported in LigerFusedNeighborhoodAttention")
+
+        batch_size, seq_len, hidden_size = hidden_states.shape
+
+        query = self.q_proj(hidden_states)
+        key = self.k_proj(hidden_states)
+        value = self.v_proj(hidden_states)
+
+        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        attn_output = LigerFusedNeighborhoodAttentionFunction.apply(
+            query, key, value, self.kernel_size, self.dilation, self.scale
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)
+
+        if self.dropout is not None:
+            attn_output = self.dropout(attn_output)
+
+        output = self.out_proj(attn_output)
+
+        return output
+
+    def extra_repr(self) -> str:
+        return (
+            f"hidden_size={self.hidden_size}, num_heads={self.num_heads}, "
+            f"head_dim={self.head_dim}, kernel_size={self.kernel_size}, "
+            f"dilation={self.dilation}, scale={self.scale}, dropout={self.dropout_p}"
+        )
+
+
+class LigerFusedNeighborhoodAttentionLayer(nn.Module):
+    """
+    A complete neighborhood attention layer with layer norm and residual connection.
+
+    Args:
+        hidden_size (int): The hidden dimension size
+        num_heads (int): Number of attention heads
+        kernel_size (int): Size of the neighborhood window (default: 7)
+        dilation (int): Dilation factor for the neighborhood (default: 1)
+        bias (bool): Whether to use bias in linear projections (default: True)
+        dropout (float): Dropout probability (default: 0.0)
+        layer_norm_eps (float): Epsilon for layer normalization (default: 1e-5)
+        scale (Optional[float]): Scaling factor for attention scores (default: None)
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        kernel_size: int = 7,
+        dilation: int = 1,
+        bias: bool = True,
+        dropout: float = 0.0,
+        layer_norm_eps: float = 1e-5,
+        scale: Optional[float] = None,
+    ):
+        super().__init__()
+
+        self.attention = LigerFusedNeighborhoodAttention(
+            hidden_size=hidden_size,
+            num_heads=num_heads,
+            kernel_size=kernel_size,
+            dilation=dilation,
+            bias=bias,
+            dropout=dropout,
+            scale=scale,
+        )
+
+        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+
+        if dropout > 0.0:
+            self.dropout = nn.Dropout(dropout)
+        else:
+            self.dropout = None
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Forward pass with residual connection and layer normalization.
+
+        Args:
+            hidden_states (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
+            attention_mask (Optional[torch.Tensor]): Attention mask (currently not supported)
+
+        Returns:
+            torch.Tensor: Output tensor of shape [batch_size, seq_len, hidden_size]
+        """
+        normed_hidden_states = self.layer_norm(hidden_states)
+
+        attn_output = self.attention(normed_hidden_states, attention_mask)
+
+        if self.dropout is not None:
+            attn_output = self.dropout(attn_output)
+
+        output = hidden_states + attn_output
+
+        return output
+
+
+class LigerFusedNeighborhoodAttentionConfig:
+    """
+    Configuration class for Fused Neighborhood Attention.
+
+    This can be used to easily configure neighborhood attention parameters
+    for different model architectures.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int = 768,
+        num_heads: int = 12,
+        kernel_size: int = 7,
+        dilation: int = 1,
+        bias: bool = True,
+        dropout: float = 0.0,
+        layer_norm_eps: float = 1e-5,
+        scale: Optional[float] = None,
+    ):
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.kernel_size = kernel_size
+        self.dilation = dilation
+        self.bias = bias
+        self.dropout = dropout
+        self.layer_norm_eps = layer_norm_eps
+        self.scale = scale
+
+    def to_dict(self):
+        return {
+            "hidden_size": self.hidden_size,
+            "num_heads": self.num_heads,
+            "kernel_size": self.kernel_size,
+            "dilation": self.dilation,
+            "bias": self.bias,
+            "dropout": self.dropout,
+            "layer_norm_eps": self.layer_norm_eps,
+            "scale": self.scale,
+        }
@@ -27,6 +27,7 @@ def lce_forward_deprecated(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    skip_logits: Optional[bool] = None,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""

@@ -81,7 +82,14 @@ def lce_forward_deprecated(
     loss = None
     logits = None

-    if self.training and (labels is not None):
+    if skip_logits and labels is None:
+        raise ValueError("skip_logits is True, but labels is None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and labels is not None
+
+    if skip_logits:
         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()

@@ -137,7 +145,8 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-    **loss_kwargs,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -189,6 +198,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )

     hidden_states = outputs[0]
@@ -196,27 +206,34 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]

-    shift_labels = loss_kwargs.pop("shift_labels", None)
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None or shift_labels is not None):
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
         )
-    else:  # if in inference mode materialize logits
+    else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
                 vocab_size=self.config.vocab_size,
-                **loss_kwargs,
+                **kwargs,
             )

     if not return_dict: