liger-kernel-nightly 0.5.10.dev20250605210201__py3-none-any.whl → 0.5.10.dev20250605224739__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/transformers/functional.py +28 -0
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/model/gemma.py +5 -4
- liger_kernel/transformers/model/gemma2.py +7 -4
- liger_kernel/transformers/model/glm4.py +5 -4
- liger_kernel/transformers/model/llama.py +5 -4
- liger_kernel/transformers/model/mistral.py +5 -4
- liger_kernel/transformers/model/mixtral.py +5 -4
- liger_kernel/transformers/model/mllama.py +5 -4
- liger_kernel/transformers/model/olmo2.py +5 -4
- liger_kernel/transformers/model/phi3.py +5 -4
- liger_kernel/transformers/model/qwen2.py +5 -4
- liger_kernel/transformers/model/qwen2_5_vl.py +4 -3
- liger_kernel/transformers/model/qwen2_vl.py +4 -3
- liger_kernel/transformers/model/qwen3_moe.py +5 -4
- {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605224739.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605224739.dist-info}/RECORD +22 -20
- {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605224739.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605224739.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605224739.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605224739.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/functional.py

@@ -4,6 +4,7 @@ from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
 from liger_kernel.ops.dyt import LigerDyTFunction
 from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
 from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
+from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction
 from liger_kernel.ops.geglu import LigerGELUMulFunction
 from liger_kernel.ops.group_norm import LigerGroupNormFunction
 from liger_kernel.ops.jsd import LigerJSDFunction
@@ -197,6 +198,33 @@ def liger_multi_token_attention(
     return LigerMultiTokenAttentionFunction.apply(scores, weight, bias, stride, padding, dilation, groups, sparse)


+def liger_fused_neighborhood_attention(
+    query,
+    key,
+    value,
+    kernel_size: int = 7,
+    dilation: int = 1,
+    scale: float = None,
+):
+    """
+    Liger fused neighborhood attention.
+
+    paper: https://arxiv.org/pdf/2504.16922
+
+    Args:
+        query: Query tensor of shape [batch_size, num_heads, seq_len, head_dim]
+        key: Key tensor of shape [batch_size, num_heads, seq_len, head_dim]
+        value: Value tensor of shape [batch_size, num_heads, seq_len, head_dim]
+        kernel_size: Size of the neighborhood window (default: 7)
+        dilation: Dilation factor for the neighborhood (default: 1)
+        scale: Scaling factor for attention scores (default: rsqrt(head_dim))
+
+    Returns:
+        Output tensor of shape [batch_size, num_heads, seq_len, head_dim]
+    """
+    return LigerFusedNeighborhoodAttentionFunction.apply(query, key, value, kernel_size, dilation, scale)
+
+
 def liger_tvd(
     input,
     target,
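Illustrative note (not part of the diff): the sketch below shows how the new functional entry point could be exercised, alongside a dense PyTorch reference that masks ordinary attention down to a centered, dilated window. The reference helper and its exact boundary semantics are assumptions for illustration and may not match the fused Triton kernel bit for bit; the fused call itself needs a CUDA device.

import torch

from liger_kernel.transformers.functional import liger_fused_neighborhood_attention


def naive_neighborhood_attention(q, k, v, kernel_size=7, dilation=1, scale=None):
    # Dense reference (assumption): softmax attention restricted to a centered
    # window of `kernel_size` positions with stride `dilation` around each query.
    # Edge handling of the fused kernel may differ, so treat this as an
    # illustration of the O(n*k) access pattern, not a bit-exact oracle.
    b, h, n, d = q.shape
    scale = scale if scale is not None else d**-0.5
    idx = torch.arange(n, device=q.device)
    diff = idx[None, :] - idx[:, None]                      # [n, n] relative offsets
    radius = (kernel_size // 2) * dilation
    mask = (diff.abs() <= radius) & (diff % dilation == 0)  # inside the dilated window
    scores = (q @ k.transpose(-1, -2)) * scale
    scores = scores.masked_fill(~mask, float("-inf"))
    return scores.softmax(dim=-1) @ v


if __name__ == "__main__":
    # Shapes follow the docstring above: [batch_size, num_heads, seq_len, head_dim].
    q, k, v = (torch.randn(2, 8, 128, 64, device="cuda") for _ in range(3))
    out_fused = liger_fused_neighborhood_attention(q, k, v, kernel_size=7, dilation=1)
    out_naive = naive_neighborhood_attention(q, k, v, kernel_size=7, dilation=1)
    print(out_fused.shape, (out_fused - out_naive).abs().max())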
liger_kernel/transformers/fused_neighborhood_attention.py

@@ -0,0 +1,234 @@
+import math
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction
+
+
+class LigerFusedNeighborhoodAttention(nn.Module):
+    """
+    Liger Fused Neighborhood Attention Module.
+
+    Paper: https://arxiv.org/pdf/2504.16922
+
+    Fused Neighborhood attention restricts the attention mechanism to a local neighborhood
+    around each position, reducing computational complexity from O(n²) to O(n*k)
+    where k is the neighborhood size.
+
+    Args:
+        hidden_size (int): The hidden dimension size
+        num_heads (int): Number of attention heads
+        kernel_size (int): Size of the neighborhood window (default: 7)
+        dilation (int): Dilation factor for the neighborhood (default: 1)
+        bias (bool): Whether to use bias in linear projections (default: True)
+        dropout (float): Dropout probability (default: 0.0)
+        scale (Optional[float]): Scaling factor for attention scores.
+            If None, uses 1/sqrt(head_dim) (default: None)
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        kernel_size: int = 7,
+        dilation: int = 1,
+        bias: bool = True,
+        dropout: float = 0.0,
+        scale: Optional[float] = None,
+    ):
+        super().__init__()
+
+        if hidden_size % num_heads != 0:
+            raise ValueError(f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})")
+
+        if kernel_size <= 0:
+            raise ValueError(f"kernel_size ({kernel_size}) must be positive")
+
+        if kernel_size % 2 == 0:
+            raise ValueError(f"kernel_size ({kernel_size}) must be odd")
+
+        if dilation < 1:
+            raise ValueError(f"dilation ({dilation}) must be positive")
+
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.head_dim = hidden_size // num_heads
+        self.kernel_size = kernel_size
+        self.dilation = dilation
+        self.scale = scale if scale is not None else 1.0 / math.sqrt(self.head_dim)
+        self.dropout_p = dropout
+
+        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+
+        self.out_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
+
+        if dropout > 0.0:
+            self.dropout = nn.Dropout(dropout)
+        else:
+            self.dropout = None
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Forward pass of the fused neighborhood attention module.
+
+        Args:
+            hidden_states (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
+            attention_mask (Optional[torch.Tensor]): Attention mask (currently not supported)
+
+        Returns:
+            torch.Tensor: Output tensor of shape [batch_size, seq_len, hidden_size]
+        """
+        if attention_mask is not None:
+            raise NotImplementedError("Attention mask is not yet supported in LigerFusedNeighborhoodAttention")
+
+        batch_size, seq_len, hidden_size = hidden_states.shape
+
+        query = self.q_proj(hidden_states)
+        key = self.k_proj(hidden_states)
+        value = self.v_proj(hidden_states)
+
+        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        attn_output = LigerFusedNeighborhoodAttentionFunction.apply(
+            query, key, value, self.kernel_size, self.dilation, self.scale
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)
+
+        if self.dropout is not None:
+            attn_output = self.dropout(attn_output)
+
+        output = self.out_proj(attn_output)
+
+        return output
+
+    def extra_repr(self) -> str:
+        return (
+            f"hidden_size={self.hidden_size}, num_heads={self.num_heads}, "
+            f"head_dim={self.head_dim}, kernel_size={self.kernel_size}, "
+            f"dilation={self.dilation}, scale={self.scale}, dropout={self.dropout_p}"
+        )
+
+
+class LigerFusedNeighborhoodAttentionLayer(nn.Module):
+    """
+    A complete neighborhood attention layer with layer norm and residual connection.
+
+    Args:
+        hidden_size (int): The hidden dimension size
+        num_heads (int): Number of attention heads
+        kernel_size (int): Size of the neighborhood window (default: 7)
+        dilation (int): Dilation factor for the neighborhood (default: 1)
+        bias (bool): Whether to use bias in linear projections (default: True)
+        dropout (float): Dropout probability (default: 0.0)
+        layer_norm_eps (float): Epsilon for layer normalization (default: 1e-5)
+        scale (Optional[float]): Scaling factor for attention scores (default: None)
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        kernel_size: int = 7,
+        dilation: int = 1,
+        bias: bool = True,
+        dropout: float = 0.0,
+        layer_norm_eps: float = 1e-5,
+        scale: Optional[float] = None,
+    ):
+        super().__init__()
+
+        self.attention = LigerFusedNeighborhoodAttention(
+            hidden_size=hidden_size,
+            num_heads=num_heads,
+            kernel_size=kernel_size,
+            dilation=dilation,
+            bias=bias,
+            dropout=dropout,
+            scale=scale,
+        )
+
+        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+
+        if dropout > 0.0:
+            self.dropout = nn.Dropout(dropout)
+        else:
+            self.dropout = None
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Forward pass with residual connection and layer normalization.
+
+        Args:
+            hidden_states (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
+            attention_mask (Optional[torch.Tensor]): Attention mask (currently not supported)
+
+        Returns:
+            torch.Tensor: Output tensor of shape [batch_size, seq_len, hidden_size]
+        """
+        normed_hidden_states = self.layer_norm(hidden_states)
+
+        attn_output = self.attention(normed_hidden_states, attention_mask)
+
+        if self.dropout is not None:
+            attn_output = self.dropout(attn_output)
+
+        output = hidden_states + attn_output
+
+        return output
+
+
+class LigerFusedNeighborhoodAttentionConfig:
+    """
+    Configuration class for Fused Neighborhood Attention.
+
+    This can be used to easily configure neighborhood attention parameters
+    for different model architectures.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int = 768,
+        num_heads: int = 12,
+        kernel_size: int = 7,
+        dilation: int = 1,
+        bias: bool = True,
+        dropout: float = 0.0,
+        layer_norm_eps: float = 1e-5,
+        scale: Optional[float] = None,
+    ):
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.kernel_size = kernel_size
+        self.dilation = dilation
+        self.bias = bias
+        self.dropout = dropout
+        self.layer_norm_eps = layer_norm_eps
+        self.scale = scale
+
+    def to_dict(self):
+        return {
+            "hidden_size": self.hidden_size,
+            "num_heads": self.num_heads,
+            "kernel_size": self.kernel_size,
+            "dilation": self.dilation,
+            "bias": self.bias,
+            "dropout": self.dropout,
+            "layer_norm_eps": self.layer_norm_eps,
+            "scale": self.scale,
+        }
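Illustrative note (not part of the diff): a minimal usage sketch of the three classes added above, assuming the import path shown in the file listing and a CUDA device for the underlying Triton op.

import torch

from liger_kernel.transformers.fused_neighborhood_attention import (
    LigerFusedNeighborhoodAttention,
    LigerFusedNeighborhoodAttentionConfig,
    LigerFusedNeighborhoodAttentionLayer,
)

# The config is a plain container; its to_dict() output can be splatted into the layer.
config = LigerFusedNeighborhoodAttentionConfig(hidden_size=768, num_heads=12, kernel_size=7, dilation=1)
layer = LigerFusedNeighborhoodAttentionLayer(**config.to_dict()).to("cuda")

x = torch.randn(2, 256, 768, device="cuda")
y = layer(x)   # pre-norm attention block with residual: x + Attn(LN(x))
print(y.shape)  # torch.Size([2, 256, 768])

# The bare attention module can also be dropped into a custom block.
attn = LigerFusedNeighborhoodAttention(hidden_size=768, num_heads=12, kernel_size=7).to("cuda")
print(attn(x).shape)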
liger_kernel/transformers/model/gemma.py

@@ -138,7 +138,7 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
-    **
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -190,6 +190,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )

     hidden_states = outputs[0]
@@ -197,7 +198,7 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]

-    shift_labels =
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None

@@ -215,7 +216,7 @@ def lce_forward(
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **
+            **kwargs,
         )
     else:
         logits = self.lm_head(kept_hidden_states)
@@ -224,7 +225,7 @@ def lce_forward(
             logits=logits,
             labels=labels,
             vocab_size=self.config.vocab_size,
-            **
+            **kwargs,
         )

     if not return_dict:
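Illustrative note (not part of the diff): the same **kwargs plumbing recurs in the remaining model files below. As a rough sketch of what it enables, the toy model here is a hypothetical stand-in (not the actual modeling code): extra keyword arguments such as a precomputed shift_labels tensor travel through forward and are popped before the loss is computed.

import torch
import torch.nn.functional as F


class ToyCausalLM(torch.nn.Module):
    # Hypothetical stand-in mimicking the pattern in the hunks: **kwargs reaches
    # forward, and `shift_labels` is popped out instead of shifting `labels` inside.
    def __init__(self, vocab_size=128, hidden_size=32):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden_size)
        self.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
        self.vocab_size = vocab_size

    def forward(self, input_ids, labels=None, **kwargs):
        shift_labels = kwargs.pop("shift_labels", None)  # same pattern as above
        logits = self.lm_head(self.embed(input_ids))
        loss = None
        if shift_labels is None and labels is not None:
            shift_labels = F.pad(labels, (0, 1), value=-100)[..., 1:]
        if shift_labels is not None:
            loss = F.cross_entropy(
                logits.view(-1, self.vocab_size), shift_labels.reshape(-1), ignore_index=-100
            )
        return loss, logits


model = ToyCausalLM()
ids = torch.randint(0, 128, (2, 16))
loss_a, _ = model(ids, labels=ids)                                            # shifted internally
loss_b, _ = model(ids, shift_labels=F.pad(ids, (0, 1), value=-100)[..., 1:])  # pre-shifted via kwargs
print(float(loss_a), float(loss_b))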
liger_kernel/transformers/model/gemma2.py

@@ -30,6 +30,7 @@ def lce_forward_deprecated(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -76,6 +77,7 @@ def lce_forward_deprecated(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )

     hidden_states = outputs[0]
@@ -147,7 +149,7 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
-    **
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -204,6 +206,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )

     hidden_states = outputs[0]
@@ -211,7 +214,7 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]

-    shift_labels =
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None

@@ -230,7 +233,7 @@ def lce_forward(
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             final_logit_softcapping=self.config.final_logit_softcapping,
-            **
+            **kwargs,
         )

     else:
@@ -242,7 +245,7 @@ def lce_forward(

     loss = None
     if labels is not None:
-        loss = self.loss_function(logits, labels, self.vocab_size, **
+        loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

     if not return_dict:
         output = (logits,) + outputs[1:]
liger_kernel/transformers/model/glm4.py

@@ -27,7 +27,7 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
-    **
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -80,6 +80,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )

     hidden_states = outputs[0]
@@ -87,7 +88,7 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]

-    shift_labels =
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None

@@ -105,7 +106,7 @@ def lce_forward(
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **
+            **kwargs,
         )

     else:
@@ -115,7 +116,7 @@ def lce_forward(
             logits=logits,
             labels=labels,
             vocab_size=self.config.vocab_size,
-            **
+            **kwargs,
         )

     return CausalLMOutputWithPast(
liger_kernel/transformers/model/llama.py

@@ -152,7 +152,7 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
-    **
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -205,6 +205,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )

     hidden_states = outputs[0]
@@ -215,7 +216,7 @@ def lce_forward(
     if self.config.pretraining_tp > 1:
         raise Exception("Liger Kernel does not support pretraining_tp!!")

-    shift_labels =
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
@@ -233,7 +234,7 @@ def lce_forward(
             hidden_size=self.config.hidden_size,
             labels=labels,
             shift_labels=shift_labels,
-            **
+            **kwargs,
         )

     else:
@@ -243,7 +244,7 @@ def lce_forward(
             logits=logits,
             labels=labels,
             vocab_size=self.config.vocab_size,
-            **
+            **kwargs,
         )

     if not return_dict:
liger_kernel/transformers/model/mistral.py

@@ -28,7 +28,7 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
-    **
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Copy paste Mistral's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -83,6 +83,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )

     hidden_states = outputs[0]
@@ -90,7 +91,7 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]

-    shift_labels =
+    shift_labels = kwargs.pop("shift_labels", None)
     loss = None
     logits = None

@@ -107,7 +108,7 @@ def lce_forward(
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **
+            **kwargs,
         )

     else:
@@ -119,7 +120,7 @@ def lce_forward(
             logits=logits,
             labels=labels,
             vocab_size=self.config.vocab_size,
-            **
+            **kwargs,
         )
     if not return_dict:
         output = (logits,) + outputs[1:]
liger_kernel/transformers/model/mixtral.py

@@ -157,7 +157,7 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
-    **
+    **kwargs,
 ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
     r"""
     Args:
@@ -215,6 +215,7 @@ def lce_forward(
         output_router_logits=output_router_logits,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )

     hidden_states = outputs[0]
@@ -222,7 +223,7 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]

-    shift_labels =
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None

@@ -240,7 +241,7 @@ def lce_forward(
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **
+            **kwargs,
         )

     else:
@@ -248,7 +249,7 @@ def lce_forward(

     loss = None
     if labels is not None:
-        loss = self.loss_function(logits, labels, self.vocab_size, **
+        loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
     aux_loss = None
     if output_router_logits:
         aux_loss = load_balancing_loss_func(
liger_kernel/transformers/model/mllama.py

@@ -148,7 +148,7 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
-    **
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -206,6 +206,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )

     hidden_states = outputs[0]
@@ -213,7 +214,7 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]

-    shift_labels =
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None

@@ -231,7 +232,7 @@ def lce_forward(
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **
+            **kwargs,
         )

     else:
@@ -241,7 +242,7 @@ def lce_forward(
             logits=logits,
             labels=labels,
             vocab_size=self.config.vocab_size,
-            **
+            **kwargs,
         )

     if not return_dict:
liger_kernel/transformers/model/olmo2.py

@@ -27,7 +27,7 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
-    **
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -80,6 +80,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )

     hidden_states = outputs[0]
@@ -87,7 +88,7 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]

-    shift_labels =
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None

@@ -105,7 +106,7 @@ def lce_forward(
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **
+            **kwargs,
         )

     else:
@@ -115,7 +116,7 @@ def lce_forward(
             logits=logits,
             labels=labels,
             vocab_size=self.config.vocab_size,
-            **
+            **kwargs,
         )

     return CausalLMOutputWithPast(
liger_kernel/transformers/model/phi3.py

@@ -137,7 +137,7 @@ def lce_forward(
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
-    **
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -203,6 +203,7 @@ def lce_forward(
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
+        **kwargs,
     )

     hidden_states = outputs[0]
@@ -210,7 +211,7 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]

-    shift_labels =
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None

@@ -228,7 +229,7 @@ def lce_forward(
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **
+            **kwargs,
         )

     else:
@@ -238,7 +239,7 @@ def lce_forward(
             logits=logits,
             labels=labels,
             vocab_size=self.config.vocab_size,
-            **
+            **kwargs,
         )

     if not return_dict: