mlx-raclate 0.1.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx_raclate/__init__.py +1 -0
- mlx_raclate/models/__init__.py +0 -0
- mlx_raclate/models/base.py +225 -0
- mlx_raclate/models/gemma3_text.py +913 -0
- mlx_raclate/models/lfm2.py +671 -0
- mlx_raclate/models/modernbert.py +900 -0
- mlx_raclate/models/qwen3.py +582 -0
- mlx_raclate/models/t5gemma_encoder.py +857 -0
- mlx_raclate/py.typed +0 -0
- mlx_raclate/tuner/TUNER.md +305 -0
- mlx_raclate/tuner/__init__.py +0 -0
- mlx_raclate/tuner/collators.py +291 -0
- mlx_raclate/tuner/datasets.py +247 -0
- mlx_raclate/tuner/model_card_utils.py +206 -0
- mlx_raclate/tuner/trainer.py +648 -0
- mlx_raclate/tuner/utils.py +292 -0
- mlx_raclate/utils/__init__.py +0 -0
- mlx_raclate/utils/server.py +390 -0
- mlx_raclate/utils/tokenizer_utils.py +353 -0
- mlx_raclate/utils/train.py +249 -0
- mlx_raclate/utils/utils.py +625 -0
- mlx_raclate-0.1.0b1.dist-info/METADATA +216 -0
- mlx_raclate-0.1.0b1.dist-info/RECORD +25 -0
- mlx_raclate-0.1.0b1.dist-info/WHEEL +4 -0
- mlx_raclate-0.1.0b1.dist-info/licenses/LICENSE +19 -0
|
@@ -0,0 +1,900 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from typing import Optional, Dict, Literal, Any, List
|
|
3
|
+
|
|
4
|
+
import mlx.core as mx
|
|
5
|
+
import mlx.nn as nn
|
|
6
|
+
|
|
7
|
+
from .base import (
|
|
8
|
+
BaseModelArgs,
|
|
9
|
+
RaclateBaseModel,
|
|
10
|
+
compute_similarity_and_loss,
|
|
11
|
+
mean_pooling,
|
|
12
|
+
normalize_embeddings
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
""" NOTE : This implementation of ModernBERT excludes all features related to Flash Attention 2, padded/unpadded handling"""
|
|
16
|
+
|
|
17
|
+
@dataclass
class ModelArgs(BaseModelArgs):
    """Hyper-parameters for the ModernBERT encoder and the task models in this module.

    The first group configures the encoder itself. The "pipeline args" group is
    read by the task-specific classes (masked LM, sequence classification,
    sentence similarity).
    """

    architectures: List[str] = field(default_factory=lambda: ["ModernBertModel"])
    attention_bias: bool = False
    attention_dropout: float = 0.0
    bos_token_id: int = 50281
    cls_token_id: int = 50281
    embedding_dropout: float = 0.0
    eos_token_id: int = 50282
    # Layers whose index is a multiple of this value use global attention.
    global_attn_every_n_layers: int = 3
    global_rope_theta: float = 160000.0
    hidden_size: int = 768
    initializer_range: float = 0.02
    initializer_cutoff_factor: float = 2.0  # relevant for MLX?
    intermediate_size: int = 1152
    # Total sliding-window width; each side of a query sees local_attention // 2 tokens.
    local_attention: int = 128
    local_rope_theta: float = 10000
    max_position_embeddings: int = 8192
    mlp_bias: bool = False
    mlp_dropout: float = 0.0
    model_type: str = "modernbert"
    norm_bias: bool = False
    norm_eps: float = 1e-05
    num_attention_heads: int = 12
    num_hidden_layers: int = 22
    output_hidden_states: bool = False
    pad_token_id: int = 50283
    sep_token_id: int = 50282
    vocab_size: int = 50368

    ### pipeline args
    # Annotated so they are real dataclass fields (un-annotated names are ignored
    # by @dataclass); note no trailing commas, which would turn a value into a tuple.
    decoder_bias: bool = True
    classifier_pooling: Literal["cls", "mean"] = "cls"
    classifier_dropout: float = 0.0
    classifier_bias: bool = False
    # NOTE(review): intentionally left un-annotated, i.e. a plain class attribute
    # rather than a dataclass field. True seems a more appropriate value for MLM, and
    # HF configs presumably ship `sparse_prediction: false`; whether config loading
    # can override this depends on BaseModelArgs — confirm before making it a field.
    sparse_prediction = True
    sparse_pred_ignore_index: int = -100
    is_regression: Optional[bool] = None
    label2id: Optional[Dict[str, int]] = None
    id2label: Optional[Dict[int, str]] = None
    pipeline_config: Optional[Dict[str, Any]] = None  # for Sequence Classification
    use_late_interaction: bool = False

    @property
    def num_labels(self) -> int:
        """Number of output units for the classification head.

        - Regression (``is_regression`` truthy): 1.
        - Binary classification with a sigmoid output
          (``pipeline_config["binary_sigmoid"]`` truthy): 1.
        - Otherwise: the number of entries in ``id2label``.

        Raises:
            ValueError: if neither of the single-output cases applies and
                ``id2label`` is None.
        """
        if self.is_regression:
            return 1

        if self.pipeline_config and self.pipeline_config.get("binary_sigmoid", False):
            return 1

        if self.id2label is None:
            raise ValueError(
                "id2label mapping must be provided for categorical classification. "
                "For regression or binary classification with sigmoid output, "
                "set is_regression=True or binary_sigmoid=True in pipeline_config."
            )

        return len(self.id2label)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class ModernBertEmbeddings(nn.Module):
    """Token-embedding lookup followed by LayerNorm and dropout.

    No position embeddings are added here; in this file positional information
    is injected by the RoPE module inside each attention layer.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        width = config.hidden_size
        self.tok_embeddings = nn.Embedding(config.vocab_size, width)
        self.norm = nn.LayerNorm(width, eps=config.norm_eps, bias=config.norm_bias)
        self.drop = nn.Dropout(p=config.embedding_dropout)

    def __call__(self, input_ids):
        """Embed token ids, then normalize and apply dropout."""
        return self.drop(self.norm(self.tok_embeddings(input_ids)))
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class ModernBertMLP(nn.Module):
    """Gated feed-forward block applied at the end of each ModernBERT layer.

    Compared to the default BERT architecture, this block replaces the
    ``BertIntermediate`` and ``SelfOutput`` classes with a single module:
    ``Wi`` projects to twice the intermediate size. That output is split in half
    along the last axis into an input part and a gate
    (see https://arxiv.org/pdf/2002.05202v1). ``act(input) * gate`` then goes
    through dropout, and ``Wo`` projects back to ``hidden_size``.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.Wi = nn.Linear(config.hidden_size, config.intermediate_size * 2, bias=config.mlp_bias)
        self.act = nn.GELU()
        self.drop = nn.Dropout(p=config.mlp_dropout)
        self.Wo = nn.Linear(int(config.intermediate_size), config.hidden_size, bias=config.mlp_bias)

    def __call__(self, hidden_states):
        """Apply the gated MLP.

        Args:
            hidden_states: array whose last axis has size ``hidden_size``. Any
                number of leading axes is accepted.

        Returns:
            Array with the same leading axes and a last axis of size ``hidden_size``.
        """
        projected = self.Wi(hidden_states)

        # Slice on the last axis with an ellipsis so inputs of any rank work.
        split_dim = projected.shape[-1] // 2
        inputs, gate = projected[..., :split_dim], projected[..., split_dim:]
        return self.Wo(self.drop(self.act(inputs) * gate))
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class ModernBertAttention(nn.Module):
    """Multi-head self-attention computed with MLX's fused SDPA kernel.

    Layers whose ``layer_id`` is a multiple of ``global_attn_every_n_layers``
    attend globally with ``global_rope_theta``. All other layers use the
    sliding-window mask passed in by the caller, and use ``local_rope_theta``
    when it is set.
    """

    def __init__(self, config: ModelArgs, layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_id = layer_id

        heads = config.num_attention_heads
        if config.hidden_size % heads:
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be divisible by num_attention_heads ({config.num_attention_heads})"
            )

        self.attention_dropout = config.attention_dropout
        self.num_heads = heads
        self.head_dim = config.hidden_size // heads
        self.all_head_size = self.num_heads * self.head_dim
        # One fused projection producing queries, keys and values.
        self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias)

        # (-1, -1) marks a global-attention layer; otherwise (left, right) half-window sizes.
        if layer_id % config.global_attn_every_n_layers == 0:
            self.local_attention = (-1, -1)
        else:
            half_window = config.local_attention // 2
            self.local_attention = (half_window, half_window)

        uses_window = self.local_attention != (-1, -1)
        rope_theta = (
            config.local_rope_theta
            if uses_window and config.local_rope_theta is not None
            else config.global_rope_theta
        )
        self.rotary_emb = nn.RoPE(dims=self.head_dim, base=rope_theta)

        self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
        self.out_drop = nn.Dropout(p=config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
        self.pruned_heads = set()

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        sliding_window_mask=None,
        **kwargs,
    ):
        """Attend over ``hidden_states`` and return a 1-tuple with the projected output.

        ``attention_mask`` is used on global layers and ``sliding_window_mask``
        on local ones. Extra keyword arguments (e.g. ``position_ids``) are accepted
        and ignored.
        """
        batch_size = hidden_states.shape[0]

        fused = mx.reshape(
            self.Wqkv(hidden_states),
            (batch_size, -1, 3, self.num_heads, self.head_dim),
        )
        # (batch, seq, 3, heads, head_dim) -> (3, batch, heads, seq, head_dim)
        fused = mx.transpose(fused, (2, 0, 3, 1, 4))
        query, key, value = fused[0], fused[1], fused[2]

        query = self.rotary_emb(query)
        key = self.rotary_emb(key)

        mask = sliding_window_mask if self.local_attention != (-1, -1) else attention_mask

        context = mx.fast.scaled_dot_product_attention(
            query,
            key,
            value,
            scale=self.head_dim ** -0.5,
            mask=mask,
        )

        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        context = mx.reshape(
            mx.transpose(context, (0, 2, 1, 3)),
            (batch_size, -1, self.all_head_size),
        )
        return (self.out_drop(self.Wo(context)),)
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
class ModernBertEncoderLayer(nn.Module):
    """One pre-norm transformer block: attention, then the gated MLP, each with a residual.

    Layer 0 uses ``nn.Identity`` instead of a LayerNorm before attention,
    presumably because ``ModernBertEmbeddings`` already normalizes its output.
    """

    def __init__(self, config: ModelArgs, layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        self.attn_norm = (
            nn.Identity()
            if layer_id == 0
            else nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        )
        self.attn = ModernBertAttention(config=config, layer_id=layer_id)
        self.mlp = ModernBertMLP(config)
        self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        sliding_window_mask=None,
        position_ids=None,
    ):
        """Run attention and MLP sub-blocks; returns a 1-tuple with the new hidden states."""
        residual = hidden_states
        attn_result = self.attn(
            self.attn_norm(residual),
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
        )
        after_attn = residual + attn_result[0]
        after_mlp = after_attn + self.mlp(self.mlp_norm(after_attn))
        return (after_mlp,)
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
class ModernBertModel(nn.Module):
    """ModernBERT encoder: embeddings, a stack of encoder layers, and a final LayerNorm."""

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.embeddings = ModernBertEmbeddings(config)
        self.layers = [
            ModernBertEncoderLayer(config, i) for i in range(config.num_hidden_layers)
        ]
        self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        self.gradient_checkpointing = False

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token embedding table (``embeddings.tok_embeddings``)."""
        return self.embeddings.tok_embeddings

    def set_input_embeddings(self, value):
        """Replace the token embedding table."""
        self.embeddings.tok_embeddings = value

    def _update_attention_mask(self, attention_mask, model_dtype):  # TODO: move to base.py ??
        """Build additive attention masks from a (batch_size, seq_len) padding mask.

        Args:
            attention_mask: array of shape (batch_size, seq_len). Positions equal
                to 1 are visible; every other value is masked.
            model_dtype: dtype the returned masks are cast to.

        Returns:
            ``(global_attention_mask, sliding_window_mask)``, both of shape
            (batch_size, 1, seq_len, seq_len), holding 0.0 for visible keys and
            -1e4 for masked ones. The sliding-window mask additionally hides keys
            more than ``local_attention // 2`` positions away from the query.
        """
        batch_size, seq_len = attention_mask.shape
        neg_inf = -1e4

        # Padding mask (not causal): each query may see every non-padded key.
        additive_mask = mx.where(attention_mask == 1, 0.0, neg_inf)
        global_attention_mask = mx.broadcast_to(
            additive_mask[:, None, None, :],  # (batch_size, 1, 1, seq_len)
            (batch_size, 1, seq_len, seq_len),
        )

        # |query position - key position| for every pair, shape (seq_len, seq_len).
        positions = mx.arange(seq_len)
        distance = mx.abs(positions[None, :] - positions[:, None])
        in_window = distance <= (self.config.local_attention // 2)

        # Outside the window -> neg_inf; inside -> fall back to the padding mask.
        sliding_window_mask = mx.where(
            in_window[None, None, :, :], global_attention_mask, neg_inf
        )

        # Convert to model_dtype for scaled_dot_product_attention
        global_attention_mask = global_attention_mask.astype(model_dtype)
        sliding_window_mask = sliding_window_mask.astype(model_dtype)
        return global_attention_mask, sliding_window_mask

    def __call__(
        self,
        input_ids,
        attention_mask=None,  # (batch_size, seq_len) padding mask, 1 = visible
        sliding_window_mask=None,
        position_ids=None,
        output_hidden_states=False,
        return_dict=True,
    ):
        """Encode ``input_ids``.

        Args:
            input_ids: token ids; the first two axes are (batch_size, seq_len).
            attention_mask: optional (batch_size, seq_len) padding mask; defaults
                to all ones. It is converted into additive masks internally.
            sliding_window_mask: ignored; the sliding-window mask is always
                rebuilt from ``attention_mask``.
            position_ids: forwarded to each layer; the attention module ignores it.
            output_hidden_states: when True, also return the embedding output plus
                every encoder layer's output (before ``final_norm``). When None,
                falls back to ``config.output_hidden_states``.
            return_dict: return a dict instead of a tuple.

        Returns:
            A dict ``{"last_hidden_state", "hidden_states"}``. When ``return_dict``
            is False, a tuple of the non-None values in that order.
        """
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        batch_size, seq_len = input_ids.shape[:2]

        if attention_mask is None:
            attention_mask = mx.ones((batch_size, seq_len))  ### updated with _update_attention_mask() below

        hidden_states = self.embeddings(input_ids)

        attention_mask, sliding_window_mask = self._update_attention_mask(
            attention_mask=attention_mask,
            model_dtype=hidden_states.dtype,
        )

        all_hidden_states = () if output_hidden_states else None

        for encoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                sliding_window_mask=sliding_window_mask,
                position_ids=position_ids,
            )
            hidden_states = layer_outputs[0]

        if output_hidden_states:
            # Also record the last layer's output (pre final_norm), so the tuple holds
            # num_hidden_layers + 1 entries as in the reference transformers model.
            all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.final_norm(hidden_states)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
        return {
            "last_hidden_state": hidden_states,
            "hidden_states": all_hidden_states,
        }
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
### below are the classes for specific pipelines
|
|
355
|
+
class Model(RaclateBaseModel):
    """
    Computes normalized sentence embeddings for input sequences using a ModernBERT model.

    Note: sanitization is a hack that lets this model load weights downloaded
    with the masked-LM config from HF (the original ModernBERT model), by
    dropping the MLM head weights.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.model = ModernBertModel(config)

        # no transformer architecture for embedding model

    def __call__(
        self,
        input_ids: mx.array,
        attention_mask: Optional[mx.array] = None,
        position_ids: Optional[mx.array] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
    ):
        """Encode ``input_ids`` and pool them into one normalized embedding per sequence.

        Pooling follows ``config.classifier_pooling``: ``"cls"`` takes the first
        token and ``"mean"`` averages over ``attention_mask``.

        Returns:
            ``{"embeddings", "last_hidden_state"}``, or the tuple
            ``(embeddings, last_hidden_state)`` when ``return_dict`` is False.

        Raises:
            ValueError: if ``config.classifier_pooling`` is neither "cls" nor "mean".
        """
        if attention_mask is None:
            batch_size, seq_len = input_ids.shape
            attention_mask = mx.ones(
                (batch_size, seq_len),
                dtype=self.model.embeddings.tok_embeddings.weight.dtype,
            )  ### updated via _update_attention_mask() in the model

        encoder_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = (
            encoder_outputs["last_hidden_state"]
            if isinstance(encoder_outputs, dict)
            else encoder_outputs[0]
        )

        pooling = self.config.classifier_pooling
        if pooling == "cls":
            pooled = last_hidden_state[:, 0]
        elif pooling == "mean":
            pooled = mean_pooling(last_hidden_state, attention_mask)
        else:
            # Previously this fell through and failed with an UnboundLocalError.
            raise ValueError(
                f"Unsupported classifier_pooling {pooling!r}; expected 'cls' or 'mean'."
            )

        text_embeds = normalize_embeddings(pooled)

        if not return_dict:
            return (text_embeds, last_hidden_state)

        return {
            "embeddings": text_embeds,  # normalized embeddings
            "last_hidden_state": last_hidden_state,
        }

    def sanitize(self, weights):
        """Drop weights this embedding model has no parameters for.

        Removes ``position_ids`` buffers and every MLM head weight (keys starting
        with ``head.`` or ``decoder.``); this class defines no such submodules.
        """
        return {
            k: v
            for k, v in weights.items()
            if "position_ids" not in k and not k.startswith(("head.", "decoder."))
        }
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
class ModelForSentenceSimilarity(RaclateBaseModel):
    """
    Handles:
    1. Inference: Generates embeddings and similarity scores (cosine similarity or MaxSim if late interaction is used).
    2. Training (Standard): (Sentence1, Sentence2, Score) -> MSE/Cosine Loss.
    3. Training (Triplets): (Anchor, Positive, Negative) -> MNRL with Hard Negatives (Cross-entropy Loss).
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.model = ModernBertModel(config)

    def _call_model(
        self,
        input_ids: mx.array,
        position_ids: Optional[mx.array] = None,
        attention_mask: Optional[mx.array] = None,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ):
        """Encode one batch into normalized embeddings.

        With ``config.use_late_interaction`` the output keeps one vector per token
        (ColBERT style), with padding tokens zeroed. Otherwise the sequence is
        pooled per ``config.classifier_pooling``.

        ``output_hidden_states`` is accepted for signature compatibility but unused.

        Returns:
            ``{"embeddings", "last_hidden_state"}``, or the tuple
            ``(embeddings, last_hidden_state)`` when ``return_dict`` is False.

        Raises:
            ValueError: if ``config.classifier_pooling`` is neither "cls" nor "mean".
        """
        if attention_mask is None:
            # Same default as __call__ (attend everywhere). The mask is needed below
            # for mean pooling and late-interaction padding removal, and callers
            # may omit it (e.g. reference or negative batches).
            attention_mask = mx.ones(
                input_ids.shape,
                dtype=self.model.embeddings.tok_embeddings.weight.dtype,
            )

        out = self.model(input_ids, attention_mask=attention_mask, position_ids=position_ids)
        last_hidden_state = out["last_hidden_state"] if isinstance(out, dict) else out[0]

        if self.config.use_late_interaction:
            # Keep unpooled for ColBERT style; zero padding tokens so they do not affect MaxSim.
            text_embeds = normalize_embeddings(last_hidden_state) * attention_mask[..., None]
        else:
            pooling = self.config.classifier_pooling
            if pooling == "cls":
                pooled = last_hidden_state[:, 0]
            elif pooling == "mean":
                pooled = mean_pooling(last_hidden_state, attention_mask)
            else:
                raise ValueError(
                    f"Unsupported classifier_pooling {pooling!r}; expected 'cls' or 'mean'."
                )
            text_embeds = normalize_embeddings(pooled)

        if not return_dict:
            return (text_embeds, last_hidden_state)

        return {
            "embeddings": text_embeds,  # normalized embeddings
            "last_hidden_state": last_hidden_state,
        }

    def __call__(
        self,
        input_ids,
        reference_input_ids: Optional[mx.array] = None,  # Shape: [num_references, seq_len]
        negative_input_ids: Optional[mx.array] = None,  # Shape: [num_negatives, seq_len]
        attention_mask: Optional[mx.array] = None,
        reference_attention_mask: Optional[mx.array] = None,
        negative_attention_mask: Optional[mx.array] = None,
        similarity_scores: Optional[mx.array] = None,  # Shape: [batch_size, num_references]
        position_ids: Optional[mx.array] = None,
        return_dict: Optional[bool] = True,
    ):
        """Embed ``input_ids``; when references are given, also compute similarities and loss.

        Similarity and loss come from ``compute_similarity_and_loss``. It also
        receives the optional negatives and target ``similarity_scores``.

        Returns:
            ``{"loss", "similarities", "embeddings"}``, or the tuple
            ``(loss, similarities, embeddings)`` when ``return_dict`` is False.
            ``loss`` and ``similarities`` are None when no references are given.
        """
        if attention_mask is None:
            batch_size, seq_len = input_ids.shape
            attention_mask = mx.ones(
                (batch_size, seq_len),
                dtype=self.model.embeddings.tok_embeddings.weight.dtype,
            )  ### updated via _update_attention_mask() in the model

        batch_outputs = self._call_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        )
        embeddings = batch_outputs["embeddings"]  # [batch_size, hidden_size]

        loss = None
        similarities = None
        if reference_input_ids is not None:
            ref_outputs = self._call_model(
                input_ids=reference_input_ids,
                attention_mask=reference_attention_mask,
                position_ids=position_ids,  ### ? (ignored by the attention layers)
                return_dict=True,
            )
            reference_embeddings = ref_outputs["embeddings"]  # [num_references, hidden_size]

            similarities, loss = compute_similarity_and_loss(
                self.config,
                input_ids,
                embeddings,
                reference_embeddings,
                self._call_model,
                similarity_scores,
                negative_input_ids,
                negative_attention_mask,
            )

        if not return_dict:
            return (loss, similarities, embeddings)

        return {
            "loss": loss,
            "similarities": similarities,  # [batch_size, num_references]
            "embeddings": embeddings,  # [batch_size, hidden_size]
        }

    def sanitize(self, weights):
        """Keep only encoder weights (``model.*``), dropping ``position_ids`` buffers."""
        sanitized_weights = {}
        for k, v in weights.items():
            if "position_ids" in k:
                # Remove unused position_ids
                continue
            if not k.startswith("model."):
                continue
            sanitized_weights[k] = v
        return sanitized_weights
|
|
543
|
+
|
|
544
|
+
class ModelForSentenceTransformers(ModelForSentenceSimilarity):
    """
    Extends ModelForSentenceSimilarity.
    Handles:
    1. Inference: Generates embeddings and similarity scores (cosine similarity or MaxSim if late interaction is used).
    2. Training (Standard): (Sentence1, Sentence2, Score) -> MSE/Cosine Loss.
    3. Training (Triplets): (Anchor, Positive, Negative) -> MNRL with Hard Negatives (Cross-entropy Loss).
    Only weight loading differs from the parent: keys from a typical
    sentence-transformers checkpoint get the ``model.`` prefix expected here.
    """

    def __init__(self, config: ModelArgs):
        super().__init__(config)

    def sanitize(self, weights):
        """Convert sentence transformer weights to ModernBERT format."""
        # Drop position_ids buffers; prefix every other key with "model." unless
        # it already has that prefix.
        return {
            (name if name.startswith("model.") else "model." + name): tensor
            for name, tensor in weights.items()
            if "position_ids" not in name
        }
|
|
570
|
+
|
|
571
|
+
|
|
572
|
+
class ModernBertPredictionHead(nn.Module):
    """Dense -> GELU -> LayerNorm transform applied before the MLM decoder or the classifier."""

    def __init__(self, config: ModelArgs):
        super().__init__()
        size = config.hidden_size
        # current HF checkpoint does not have bias for the dense layer
        self.dense = nn.Linear(size, size, bias=False)
        self.act = nn.GELU()
        self.norm = nn.LayerNorm(size, eps=config.norm_eps, bias=config.norm_bias)

    def __call__(self, hidden_states):
        """Project, activate, then normalize ``hidden_states``."""
        projected = self.dense(hidden_states)
        activated = self.act(projected)
        return self.norm(activated)
|
|
585
|
+
|
|
586
|
+
|
|
587
|
+
class ModelForMaskedLM(RaclateBaseModel):
    """
    Computes masked language modeling (MLM) logits and, during training, the MLM loss.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.model = ModernBertModel(config)
        self.head = ModernBertPredictionHead(config)  ## no bias for this in the current HF checkpoint
        self.decoder = nn.Linear(
            config.hidden_size, config.vocab_size, bias=config.decoder_bias
        )

        # transformer architecture name for compatibility
        self.hf_transformers_arch = "ModernBertForMaskedLM"

        # Tie weights ### does not seem to work (sanitizing the weights to enforce weight tying)
        self.tie_weights()

    def tie_weights(self):
        """Point the decoder weight at the token embedding matrix."""
        embedding_layer = self.model.get_input_embeddings()
        self.decoder.weight = embedding_layer.weight

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def get_output_embeddings(self):
        return self.decoder

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)
        self.tie_weights()  # Re-tie weights after setting new embeddings

    def set_output_embeddings(self, new_embeddings):
        self.decoder = new_embeddings
        self.tie_weights()  # Re-tie weights after setting new decoder

    def _masked_lm_loss(self, logits: mx.array, labels: mx.array) -> mx.array:
        """Mean cross-entropy over positions whose label is not the ignore index.

        Positions labelled ``config.sparse_pred_ignore_index`` (default -100) are
        excluded. A multiplicative mask is used instead of boolean indexing, so
        shapes stay static. Returns 0 when every position is ignored, instead of
        a NaN mean over an empty set.
        """
        ignore_index = getattr(self.config, "sparse_pred_ignore_index", -100)
        flat_logits = logits.reshape(-1, logits.shape[-1])
        flat_labels = labels.reshape(-1)
        valid = flat_labels != ignore_index

        # Ignored labels (e.g. -100) would be out-of-range class ids for the gather
        # inside cross_entropy; replace them with 0 and mask their loss out below.
        safe_labels = mx.where(valid, flat_labels, 0)
        per_token = nn.losses.cross_entropy(flat_logits, safe_labels, reduction="none")

        weights = valid.astype(per_token.dtype)
        return (per_token * weights).sum() / mx.maximum(weights.sum(), 1)

    def __call__(
        self,
        input_ids,
        attention_mask: Optional[mx.array] = None,
        labels: Optional[mx.array] = None,
        position_ids: Optional[mx.array] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
    ) -> Dict:
        """Compute vocabulary logits and, in training mode with ``labels``, the MLM loss.

        Both ``sparse_prediction`` settings give the same loss: positions labelled
        with the ignore index are excluded, and the loss is averaged over the rest.

        Returns:
            ``{"loss", "logits", "hidden_states"}``, or the list
            ``[loss, logits, outputs[1:]]`` when ``return_dict`` is False.
            ``loss`` is None outside training or without labels.
        """
        if attention_mask is None:
            batch_size, seq_len = input_ids.shape
            attention_mask = mx.ones((batch_size, seq_len))  ### updated via _update_attention_mask() in the model

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = outputs["last_hidden_state"] if return_dict else outputs[0]
        logits = self.decoder(self.head(last_hidden_state))

        loss = None
        if self.training and labels is not None:
            loss = self._masked_lm_loss(logits, labels)

        if not return_dict:
            return [loss, logits, outputs[1:]]

        return {
            "loss": loss,
            "logits": logits,
            "hidden_states": outputs.get("hidden_states", None),
        }

    def sanitize(self, weights):
        """Drop ``position_ids`` buffers and copy the embedding matrix into ``decoder.weight``."""
        sanitized_weights = {}
        for k, v in weights.items():
            if "position_ids" in k:
                # Remove unused position_ids
                continue
            if k == "model.embeddings.tok_embeddings.weight":
                ### going around the weight tying issue. TODO : improve this
                sanitized_weights["decoder.weight"] = v
            sanitized_weights[k] = v
        return sanitized_weights
|
|
700
|
+
|
|
701
|
+
|
|
702
|
+
class ModelForSequenceClassification(RaclateBaseModel):
    """
    Computes sequence classification probabilities for input sequences.

    The encoder's last hidden state is pooled (first token or masked mean,
    selected by ``config.classifier_pooling``), then passed through the
    prediction head, dropout and a linear classifier that produces
    ``config.num_labels`` logits.

    Sanitization aligns typical BERT weights with the ModernBERT model.

    NOTE : binary classification not tested.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.num_labels = config.num_labels
        self.is_regression = config.is_regression

        self.model = ModernBertModel(config)
        self.head = ModernBertPredictionHead(config)
        self.drop = nn.Dropout(p=config.classifier_dropout)
        self.classifier = nn.Linear(
            config.hidden_size,
            config.num_labels,
        )

        # transformer architecture name for compatibility
        self.hf_transformers_arch = "ModernBertForSequenceClassification"

    def _process_outputs(self, logits: mx.array) -> mx.array:
        """Apply the appropriate activation function to the logits.

        Regression returns the raw logits unchanged, a single-output head
        returns sigmoid probabilities, and a multi-class head returns a
        softmax over the last axis.
        """
        if self.is_regression:
            return logits  # No activation for regression
        if self.num_labels == 1:
            return mx.sigmoid(logits)  # Binary classification
        # Using softmax for multi-class classification
        return mx.softmax(logits, axis=-1)

    def _compute_loss(self, logits: mx.array, labels: mx.array) -> mx.array:
        """Compute the loss matching the head type.

        - regression: mean squared error on the squeezed logits / labels;
        - ``num_labels == 1``: binary cross-entropy computed from the raw
          logits;
        - otherwise: cross-entropy over ``num_labels`` classes with MLX's
          default reduction.
        """
        if self.is_regression:
            return nn.losses.mse_loss(logits.squeeze(), labels.squeeze())
        if self.num_labels == 1:
            # binary_cross_entropy receives raw logits and applies the sigmoid
            # itself (with_logits=True); passing sigmoid(logits) here would
            # apply it twice. Both operands are flattened so that (batch, 1)
            # logits line up with (batch,) or (batch, 1) labels.
            return nn.losses.binary_cross_entropy(
                logits.reshape(-1),
                labels.reshape(-1).astype(logits.dtype),
                with_logits=True,
            )
        return nn.losses.cross_entropy(
            logits.reshape(-1, self.num_labels),
            labels.reshape(-1),
        )

    def __call__(
        self,
        input_ids,
        attention_mask: Optional[mx.array] = None,
        position_ids: Optional[mx.array] = None,
        labels: Optional[mx.array] = None,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Dict:
        """Run the encoder, pool, and classify.

        Args:
            input_ids: token ids; unpacked as ``(batch_size, seq_len)`` when
                no attention mask is supplied.
            attention_mask: optional mask; defaults to all ones with the
                shape of ``input_ids``.
            position_ids: forwarded unchanged to the encoder.
            labels: optional targets; when given, a loss is computed.
            output_hidden_states: forwarded to the encoder.
            return_dict: if True, return a dict; otherwise return the list
                ``[loss, probabilities, outputs[1:]]``.

        Returns:
            A dict with ``"loss"`` (None when no labels), ``"probabilities"``
            (raw logits for regression heads) and ``"hidden_states"``.

        Raises:
            ValueError: if ``config.classifier_pooling`` is neither ``"cls"``
                nor ``"mean"``.
        """
        if attention_mask is None:
            batch_size, seq_len = input_ids.shape
            attention_mask = mx.ones((batch_size, seq_len))

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = outputs["last_hidden_state"] if return_dict else outputs[0]

        # Pooling strategy
        pooling = self.config.classifier_pooling
        if pooling == "cls":
            pooled = last_hidden_state[:, 0]
        elif pooling == "mean":
            pooled = mean_pooling(last_hidden_state, attention_mask)
        else:
            # Without this branch `pooled` would be unbound below.
            raise ValueError(
                f"Unsupported classifier_pooling {pooling!r}; "
                "expected 'cls' or 'mean'."
            )

        # Apply head, dropout and classifier
        pooled = self.head(pooled)
        pooled = self.drop(pooled)
        logits = self.classifier(pooled)

        # Process logits for inference
        processed_logits = self._process_outputs(logits)

        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels)

        if not return_dict:
            return [loss, processed_logits, outputs[1:]]

        return {
            "loss": loss,
            "probabilities": processed_logits,
            "hidden_states": outputs.get("hidden_states", None),
        }

    def sanitize(self, weights):
        """Map checkpoint keys onto this model's parameter names.

        - ``position_ids`` entries are dropped (unused);
        - ``decoder.*`` entries (masked-LM output layer) are dropped, since
          this model defines no ``decoder`` submodule;
        - legacy ``bert.``-prefixed keys are renamed to the ``model.`` prefix;
        - every other key is kept unchanged.
        """
        sanitized_weights = {}
        for k, v in weights.items():
            if "position_ids" in k:
                # Remove unused position_ids
                continue
            if k.startswith("decoder."):
                # MLM decoder weights have no counterpart in a classifier.
                continue
            if k.startswith("bert."):
                # Handle legacy BERT naming if needed
                sanitized_weights["model." + k[len("bert."):]] = v
            else:
                sanitized_weights[k] = v
        return sanitized_weights
|
|
815
|
+
|
|
816
|
+
class ModelForTokenClassification(RaclateBaseModel):
    """
    Computes token classification probabilities for input sequences.

    Every token's last hidden state goes through the prediction head,
    dropout and a linear classifier producing ``config.num_labels`` logits;
    probabilities are a softmax over the last axis.

    NOTE: untested for now
    TODO : https://huggingface.co/disham993/electrical-ner-ModernBERT-base
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.num_labels = config.num_labels

        self.model = ModernBertModel(config)
        self.head = ModernBertPredictionHead(config)
        self.drop = nn.Dropout(p=config.classifier_dropout)
        self.classifier = nn.Linear(
            config.hidden_size,
            config.num_labels,
        )

        # transformer architecture name for compatibility
        self.hf_transformers_arch = "ModernBertForTokenClassification"

    def _compute_loss(self, logits: mx.array, labels: mx.array) -> mx.array:
        """Per-token cross-entropy that ignores negative labels.

        Negative labels are not valid class indices (the ``-100`` ignore
        index commonly used in Hugging Face token-classification data is one
        such value). They are replaced by class 0 for the lookup and their
        loss is zeroed. Returns one loss value per flattened token.
        """
        flat_logits = logits.reshape(-1, self.num_labels)
        flat_labels = labels.reshape(-1)
        valid = flat_labels >= 0
        # Avoid gathering with an out-of-range (negative) class index.
        safe_labels = mx.where(valid, flat_labels, 0)
        per_token = nn.losses.cross_entropy(
            flat_logits, safe_labels, reduction="none"
        )
        return mx.where(valid, per_token, 0.0)

    def __call__(
        self,
        input_ids,
        attention_mask: Optional[mx.array] = None,
        position_ids: Optional[mx.array] = None,
        labels: Optional[mx.array] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
    ) -> Dict:
        """Run the encoder and classify every token.

        Args:
            input_ids: token ids; unpacked as ``(batch_size, seq_len)`` when
                no attention mask is supplied.
            attention_mask: optional mask; defaults to all ones with the
                shape of ``input_ids``.
            position_ids: forwarded unchanged to the encoder.
            labels: optional per-token class ids; negative ids are ignored
                in the loss.
            output_hidden_states: forwarded to the encoder.
            return_dict: if True, return a dict; otherwise return the list
                ``[loss, probabilities, outputs[1:]]``.

        Returns:
            A dict with ``"loss"`` (None when no labels), ``"probabilities"``
            (softmax over labels for each token) and ``"hidden_states"``.
        """
        if attention_mask is None:
            batch_size, seq_len = input_ids.shape
            attention_mask = mx.ones((batch_size, seq_len))

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = outputs["last_hidden_state"] if return_dict else outputs[0]

        # Apply prediction head, dropout, and classification layer to each token
        sequence_output = self.head(last_hidden_state)
        sequence_output = self.drop(sequence_output)
        logits = self.classifier(sequence_output)

        # Process logits for inference
        processed_logits = mx.softmax(logits, axis=-1)

        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels)

        if not return_dict:
            return [loss, processed_logits, outputs[1:]]

        return {
            "loss": loss,
            "probabilities": processed_logits,
            "hidden_states": outputs.get("hidden_states", None),
        }

    def sanitize(self, weights):
        """Map checkpoint keys onto this model's parameter names.

        - ``position_ids`` entries are dropped (unused);
        - ``decoder.*`` entries (masked-LM output layer) are dropped, since
          this model defines no ``decoder`` submodule;
        - legacy ``bert.``-prefixed keys are renamed to the ``model.`` prefix;
        - every other key is kept unchanged.
        """
        sanitized_weights = {}
        for k, v in weights.items():
            if "position_ids" in k:
                # Remove unused position_ids
                continue
            if k.startswith("decoder."):
                # MLM decoder weights have no counterpart in a classifier.
                continue
            if k.startswith("bert."):
                # Handle legacy BERT naming if needed
                sanitized_weights["model." + k[len("bert."):]] = v
            else:
                sanitized_weights[k] = v
        return sanitized_weights
|