mlx-raclate 0.1.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx_raclate/__init__.py +1 -0
- mlx_raclate/models/__init__.py +0 -0
- mlx_raclate/models/base.py +225 -0
- mlx_raclate/models/gemma3_text.py +913 -0
- mlx_raclate/models/lfm2.py +671 -0
- mlx_raclate/models/modernbert.py +900 -0
- mlx_raclate/models/qwen3.py +582 -0
- mlx_raclate/models/t5gemma_encoder.py +857 -0
- mlx_raclate/py.typed +0 -0
- mlx_raclate/tuner/TUNER.md +305 -0
- mlx_raclate/tuner/__init__.py +0 -0
- mlx_raclate/tuner/collators.py +291 -0
- mlx_raclate/tuner/datasets.py +247 -0
- mlx_raclate/tuner/model_card_utils.py +206 -0
- mlx_raclate/tuner/trainer.py +648 -0
- mlx_raclate/tuner/utils.py +292 -0
- mlx_raclate/utils/__init__.py +0 -0
- mlx_raclate/utils/server.py +390 -0
- mlx_raclate/utils/tokenizer_utils.py +353 -0
- mlx_raclate/utils/train.py +249 -0
- mlx_raclate/utils/utils.py +625 -0
- mlx_raclate-0.1.0b1.dist-info/METADATA +216 -0
- mlx_raclate-0.1.0b1.dist-info/RECORD +25 -0
- mlx_raclate-0.1.0b1.dist-info/WHEEL +4 -0
- mlx_raclate-0.1.0b1.dist-info/licenses/LICENSE +19 -0
@@ -0,0 +1,913 @@
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any, Literal
import re
from functools import partial

import mlx.core as mx
import mlx.nn as nn

from .base import (
    BaseModelArgs,
    mean_pooling,
    last_token_pooling,
    normalize_embeddings,
    compute_similarity_and_loss,
    RaclateBaseModel,
)

@dataclass
class ModelArgs(BaseModelArgs):
    architectures: List[str] = field(default_factory=lambda: ["Gemma3TextModel"])
    attention_bias: Optional[bool] = False
    attention_dropout: Optional[float] = 0.0
    attn_logit_softcapping: Optional[float] = None  # Not supported with sdpa
    bos_token_id: Optional[int] = None
    eos_token_id: Optional[int] = None
    final_logit_softcapping: Optional[float] = None  # Not supported with sdpa
    hidden_activation: Optional[str] = "gelu_pytorch_tanh"
    hidden_size: int = 1152
    intermediate_size: int = 6912
    initializer_range: Optional[float] = (
        0.02  # Only needed in case of initializing weights
    )
    head_dim: int = 256
    layer_types: List[str] = field(default_factory=list)
    max_position_embeddings: int = 2048
    model_type: str = "gemma3_text"
    num_attention_heads: int = 4
    num_hidden_layers: int = 26
    num_key_value_heads: int = 1
    query_pre_attn_scalar: float = 256
    rms_norm_eps: float = 1.0e-6
    rope_global_base_freq: Optional[float] = None
    rope_local_base_freq: float = 10000.0
    rope_theta: Optional[float] = None
    rope_traditional: bool = False
    sliding_window: int = 512
    _sliding_window_pattern: Optional[int] = None
    sliding_window_pattern: Optional[int] = None
    use_bidirectional_attn: bool = False
    use_bidirectional_attention: bool = False
    vocab_size: int = 262144

    ### Defaults
    default_sliding_pattern: int = 6
    default_global_rope_freq: float = 1000000.0

    ### pipeline args
    decoder_bias = True
    classifier_dropout = 0.0
    classifier_bias = False
    sparse_prediction = True  ### True seems a more appropriate value for MLM
    sparse_pred_ignore_index = -100
    is_regression: Optional[bool] = None
    label2id: Optional[Dict[str, int]] = None
    id2label: Optional[Dict[int, str]] = None
    pipeline_config: Optional[Dict[str, Any]] = None  # for Sequence Classification
    use_late_interaction: bool = False

    @property
    def sliding_pattern(self) -> int:
        if self.sliding_window_pattern is not None:
            return self.sliding_window_pattern
        if self._sliding_window_pattern is not None:
            return self._sliding_window_pattern
        return self.default_sliding_pattern

    @property
    def rope_global_freq(self) -> float:
        if self.rope_global_base_freq is not None:
            return self.rope_global_base_freq
        if self.rope_theta is not None:
            return self.rope_theta
        return self.default_global_rope_freq

    @property
    def is_causal(self) -> bool:
        return not self.use_bidirectional_attn and not self.use_bidirectional_attention

    @property
    def num_labels(self) -> int:
        """
        Number of labels is determined by:
        - For zero-shot classification: length of label_candidates
        - For regression or binary with sigmoid: 1
        - For classification: length of id2label mapping
        """

        if self.is_regression:
            return 1

        if self.pipeline_config and self.pipeline_config.get("binary_sigmoid", False):
            return 1

        if self.id2label is None:
            raise ValueError(
                "id2label mapping must be provided for categorical classification. "
                "For regression or binary classification with sigmoid output, "
                "set is_regression=True or binary_sigmoid=True in pipeline_config."
            )

        return len(self.id2label)
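
# Illustrative sketch (assumes only the defaults defined in this dataclass) of
# how the three properties above resolve missing config values:
#
#     args = ModelArgs(rope_theta=1_000_000.0, use_bidirectional_attn=True)
#     args.sliding_pattern   # -> 6, from default_sliding_pattern
#     args.rope_global_freq  # -> 1_000_000.0, falls back to rope_theta
#     args.is_causal         # -> False, because a bidirectional flag is set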

def _sanitize_backbone(weights: Dict[str, Any]) -> Dict[str, Any]:
    """
    Standardizes keys for the Gemma3 Backbone.
    Prefixes generic keys with 'model.' and handles basic mapping.
    """
    sanitized = {}
    for k, v in weights.items():
        # Skip unrelated heads that might be in the checkpoint
        if any(x in k for x in ["lm_head", "classifier"]):
            # We don't automatically map these; specific models handle them if needed
            continue

        # Map generic 'layers' to 'model.layers' if not already present
        if "Dense.linear" not in k and \
           not k.startswith("model.") and \
           not k.startswith("dense.") and \
           not k.startswith("score.") and \
           not k.startswith("head.") and \
           not k.startswith("decoder."):

            new_key = f"model.{k}"

            sanitized[new_key] = v
        else:
            sanitized[k] = v

    return sanitized
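
# Illustrative sketch of the key mapping above (hypothetical key names):
#
#     _sanitize_backbone({
#         "layers.0.self_attn.q_proj.weight": w1,  # -> "model.layers.0.self_attn.q_proj.weight"
#         "model.norm.weight": w2,                 # kept as-is, already prefixed
#         "lm_head.weight": w3,                    # dropped, unrelated head
#     })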

class Attention(nn.Module):
    def __init__(self, args: ModelArgs, layer_idx: int):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
        self.repeats = n_heads // n_kv_heads
        self.head_dim = head_dim = args.head_dim
        self.layer_idx = layer_idx

        self.scale = args.query_pre_attn_scalar**-0.5

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

        self.q_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
        self.k_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps)

        layer_type = args.layer_types[layer_idx] if args.layer_types else None
        if not layer_type:
            if (layer_idx + 1) % args.sliding_pattern == 0:
                layer_type = "full_attention"
            else:
                layer_type = "sliding_window"
        self.is_sliding = layer_type == "sliding_window"

        base = (
            args.rope_local_base_freq
            if self.is_sliding
            else args.rope_global_freq
        )

        self.rope = nn.RoPE(
            head_dim,
            traditional=args.rope_traditional,
            base=base,
        )

        # Add softcapping support
        self.attn_logit_softcapping = args.attn_logit_softcapping

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None
    ) -> mx.array:
        B, L, _ = x.shape
        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)

        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        queries = self.q_norm(queries)
        keys = self.k_norm(keys)

        queries = self.rope(queries)
        keys = self.rope(keys)

        if self.attn_logit_softcapping is None:
            output = mx.fast.scaled_dot_product_attention(
                queries, keys, values, scale=self.scale, mask=mask
            )
        else:
            raise NotImplementedError("Softcapping attention not supported with sdpa.")
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)
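
# Note on the layer-type fallback in __init__ above: when `layer_types` is
# absent from the config, every `sliding_pattern`-th layer (1-indexed) becomes
# a global "full_attention" layer and uses `rope_global_freq`; all other layers
# use sliding-window attention with `rope_local_base_freq`. With the default
# pattern of 6 and 26 layers, indices 5, 11, 17 and 23 are global.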

class RMSNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def __call__(self, x):
        return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)

class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))

@partial(mx.compile, shapeless=True)
def clip_residual(x, y):
    if x.dtype != mx.float16:
        return x + y
    bound = mx.finfo(mx.float16).max
    return mx.clip(x.astype(mx.float32) + y.astype(mx.float32), -bound, bound).astype(
        mx.float16
    )
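
# Note: float16 has a maximum finite value of 65504, so summing two large
# residual activations can overflow to inf. `clip_residual` therefore
# accumulates in float32 and clips back into the representable float16 range;
# all other dtypes take the plain `x + y` path.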

class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs, layer_idx: int):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args, layer_idx)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.pre_feedforward_layernorm = RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )
        self.post_feedforward_layernorm = RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask)
        h = clip_residual(x, self.post_attention_layernorm(r))
        r = self.mlp(self.pre_feedforward_layernorm(h))
        out = clip_residual(h, self.post_feedforward_layernorm(r))
        return (out,)


class Gemma3Model(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.num_hidden_layers = config.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = [
            TransformerBlock(config, layer_idx=i)
            for i in range(config.num_hidden_layers)
        ]
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self) -> nn.Embedding:
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _update_attention_mask(self, attention_mask: Optional[mx.array] = None, dtype=None):
        """
        Creates the global and sliding-window attention masks (causal or
        bidirectional, depending on the config) and combines them with the
        padding mask.
        """

        B, L = attention_mask.shape
        window_size = self.config.sliding_window
        indices = mx.arange(L)
        row = indices[:, None]
        col = row.T

        if not self.config.is_causal:
            mask_base = mx.zeros((L, L), dtype=mx.bool_)  # All False (visible)

            # Sliding Window Logic for Bidirectional:
            # Valid if abs(row - col) < window
            # Mask if distance >= window
            dist = mx.abs(row - col)
            mask_window_violator = dist >= window_size

        else:
            # Causal: Standard triangular mask
            mask_future = col > row
            mask_base = mask_future

            # Sliding Window Logic for Causal:
            # Valid if row - col < window (and not future)
            # Mask if (row - col) >= window
            mask_past = (row - col) >= window_size
            mask_window_violator = mask_past

        global_mask = mx.where(mask_base, -1e9, 0.0).astype(dtype)
        sliding_mask_bool = mask_base | mask_window_violator
        sliding_mask = mx.where(sliding_mask_bool, -1e9, 0.0).astype(dtype)

        # Padding Mask
        if attention_mask is not None:
            # Reshape padding mask from (B, L) to (B, 1, 1, L) to be broadcastable
            padding_mask = attention_mask[:, None, None, :]
            additive_padding = mx.where(padding_mask == 0, -1e9, 0.0).astype(dtype)

            global_mask = global_mask + additive_padding
            sliding_mask = sliding_mask + additive_padding

        return global_mask, sliding_mask

    def __call__(
        self,
        input_ids: mx.array,
        attention_mask: Optional[mx.array] = None,
        output_hidden_states: Optional[bool] = False,
        position_ids: Optional[mx.array] = None,
        return_dict: Optional[bool] = True
    ):
        hidden_states = self.embed_tokens(input_ids)
        model_dtype = hidden_states.dtype

        # normalizer
        hidden_states *= mx.array(self.config.hidden_size**0.5, model_dtype)

        global_mask, sliding_window_mask = self._update_attention_mask(
            attention_mask,
            dtype=model_dtype
        )

        for i, layer in enumerate(self.layers):
            if self.config.layer_types:
                is_global = self.config.layer_types[i] == "full_attention"
            else:
                # Fallback to pattern
                is_global = (i + 1) % self.config.sliding_pattern == 0
            layer_mask = global_mask if is_global else sliding_window_mask
            layer_outputs = layer(hidden_states, layer_mask)
            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)

        return {
            "last_hidden_state": hidden_states,
        }
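
# Minimal forward-pass sketch (illustrative only; a randomly initialised
# backbone with made-up token ids, so the outputs are not meaningful):
#
#     args = ModelArgs()
#     backbone = Gemma3Model(args)
#     ids = mx.array([[2, 10, 42, 7]])                  # (batch=1, seq_len=4)
#     out = backbone(ids, attention_mask=mx.ones(ids.shape))
#     out["last_hidden_state"].shape                    # -> (1, 4, args.hidden_size)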


class Model(RaclateBaseModel):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.model_type = config.model_type
        self.model = Gemma3Model(config)
        self.dense = [
            nn.Linear(config.hidden_size, config.hidden_size * 4, bias=False),
            nn.Linear(config.hidden_size * 4, config.hidden_size, bias=False),
        ]

    def __call__(
        self,
        input_ids: mx.array,
        position_ids: Optional[mx.array] = None,
        attention_mask: Optional[mx.array] = None,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ):

        if attention_mask is None:
            batch_size, seq_len = input_ids.shape
            attention_mask = mx.ones(
                (batch_size, seq_len),
                dtype=self.model.embed_tokens.weight.dtype,
            )

        out = self.model(input_ids, attention_mask)
        last_hidden_state = (
            out["last_hidden_state"] if isinstance(out, dict) else out[0]
        )

        # normalized features
        if not self.config.is_causal:
            text_embeds = mean_pooling(last_hidden_state, attention_mask)
        else:
            text_embeds = last_token_pooling(last_hidden_state, attention_mask)

        for layer in self.dense:
            text_embeds = layer(text_embeds)

        text_embeds = normalize_embeddings(text_embeds)

        if not return_dict:
            return (text_embeds, last_hidden_state)

        return {
            "embeddings": text_embeds,  # normalized embeddings
            "last_hidden_state": last_hidden_state,
        }

    def sanitize(self, weights):

        sanitized_weights = _sanitize_backbone(weights)

        # Handle SentenceTransformer specific keys
        final_weights = {}
        for k, v in sanitized_weights.items():

            if not k.startswith("model."):
                continue
            final_weights[k] = v

        return final_weights


class ModelForSentenceSimilarity(RaclateBaseModel):
    """
    Computes similarity scores between input sequences and reference sentences.
    """
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.model_type = config.model_type
        self.model = Gemma3Model(config)
        self.dense = [
            nn.Linear(config.hidden_size, config.hidden_size * 4, bias=False),
            nn.Linear(config.hidden_size * 4, config.hidden_size, bias=False),
        ]

    def _call_model(self, input_ids, attention_mask=None, return_dict=True):
        out = self.model(input_ids, attention_mask)
        last_hidden_state = (
            out["last_hidden_state"] if isinstance(out, dict) else out[0]
        )

        # text_embeds = normalize_embeddings(last_hidden_state)
        if self.config.use_late_interaction:
            for dense in self.dense:
                last_hidden_state = dense(last_hidden_state)

            text_embeds = normalize_embeddings(last_hidden_state)
            # Keep unpooled for ColBERT style
            # Mask padding tokens to avoid them affecting MaxSim
            if attention_mask is not None:
                text_embeds = text_embeds * attention_mask[..., None]
        else:
            # Standard dense retrieval: Mean Pooling
            if not self.config.is_causal:
                text_embeds = mean_pooling(last_hidden_state, attention_mask)
            else:
                text_embeds = last_token_pooling(last_hidden_state, attention_mask)

            for layer in self.dense:
                text_embeds = layer(text_embeds)

            text_embeds = normalize_embeddings(text_embeds)

        if not return_dict:
            return (text_embeds, last_hidden_state)

        return {
            "embeddings": text_embeds,  # normalized embeddings
            "last_hidden_state": last_hidden_state,
        }

    def __call__(
        self,
        input_ids,
        reference_input_ids: Optional[mx.array] = None,  # Shape: [num_references, seq_len]
        negative_input_ids: Optional[mx.array] = None,  # Shape: [num_negatives, seq_len]
        attention_mask: Optional[mx.array] = None,
        reference_attention_mask: Optional[mx.array] = None,
        negative_attention_mask: Optional[mx.array] = None,
        similarity_scores: Optional[mx.array] = None,  # Shape: [batch_size, num_references]
        position_ids: Optional[mx.array] = None,
        return_dict: Optional[bool] = True,
    ):
        if attention_mask is None:
            batch_size, seq_len = input_ids.shape
            attention_mask = mx.ones(
                (batch_size, seq_len),
                dtype=self.model.embed_tokens.weight.dtype,
            )

        # Get embeddings for input batch
        batch_outputs = self._call_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        embeddings = batch_outputs["embeddings"]  # [batch_size, hidden_size]

        loss = None
        similarities = None

        if reference_input_ids is not None:

            # Get embeddings for reference sentences
            ref_outputs = self._call_model(
                input_ids=reference_input_ids,
                attention_mask=reference_attention_mask,
                return_dict=True
            )
            reference_embeddings = ref_outputs["embeddings"]  # [num_references, hidden_size]

            # Compute similarities and loss
            similarities, loss = compute_similarity_and_loss(
                self.config,
                input_ids=input_ids,
                embeddings=embeddings,
                reference_embeddings=reference_embeddings,
                call_model=self._call_model,
                similarity_scores=similarity_scores,
                negative_input_ids=negative_input_ids,
                negative_attention_mask=negative_attention_mask
            )

        if not return_dict:
            return (loss, similarities, embeddings)

        return {
            "loss": loss,
            "similarities": similarities,  # [batch_size, num_references]
            "embeddings": embeddings,  # [batch_size, hidden_size]
        }

    def sanitize(self, weights):

        sanitized_weights = _sanitize_backbone(weights)

        # Handle SentenceTransformer specific keys
        final_weights = {}
        for k, v in sanitized_weights.items():

            if k.startswith("dense.") or k.startswith("model."):
                final_weights[k] = v
            elif re.search(r"\d+_Dense", k):
                key_id = "0" if v.shape[0] > v.shape[1] else "1"
                new_key = re.sub(r"\d+_Dense\.linear", f"dense.{key_id}", k)
                final_weights[new_key] = v
            else:
                continue

        return final_weights
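
# Illustrative note on the `sanitize` mapping above (hypothetical key names):
# sentence-transformers checkpoints store the projection as e.g.
# "2_Dense.linear.weight". The weight whose first dimension is the larger one
# is the up-projection and becomes "dense.0.weight"; the other becomes
# "dense.1.weight", matching the two nn.Linear layers defined in __init__.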

class ModelForSentenceTransformers(ModelForSentenceSimilarity):
    """
    Extends ModelForSentenceSimilarity to provide embeddings for input sequences.
    This class sanitizes typical sentence transformers weights to align with the Gemma3 model.
    """
    def __init__(self, config: ModelArgs):
        super().__init__(config)



class Gemma3PredictionHead(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.dense = nn.Linear(
            config.hidden_size, config.hidden_size, config.classifier_bias
        )
        self.act = nn.GELU(approx="precise")
        self.norm = nn.RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.soft_cap = config.final_logit_softcapping

    def __call__(self, hidden_states: mx.array) -> mx.array:
        logits = self.norm(self.act(self.dense(hidden_states)))
        if self.soft_cap is not None:
            logits = mx.tanh(logits / self.soft_cap) * self.soft_cap
        return logits
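
# Note: when `final_logit_softcapping` is set, the head squashes its output to
# the range (-soft_cap, soft_cap) via soft_cap * tanh(x / soft_cap); with the
# default of None the transform is skipped.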

class ModelForSequenceClassification(RaclateBaseModel):
    """
    Computes sequence classification probabilities for input sequences.
    Sanitization aligns typical classification checkpoint weights with HF's
    Gemma3ForSequenceClassification architecture.

    NOTE: regression and binary classification not tested.
    """
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.num_labels = config.num_labels
        self.is_regression = config.is_regression

        self.model = Gemma3Model(config)

        ### The HF architecture Gemma3ForSequenceClassification
        ### does not have head and drop
        ### and uses 'score' as the final layer name
        # self.head = Gemma3PredictionHead(config)
        # self.drop = nn.Dropout(p=config.classifier_dropout)

        self.score = nn.Linear(
            config.hidden_size,
            config.num_labels,
            bias=False
        )

        self.hf_transformers_arch = "Gemma3ForSequenceClassification"

    def _process_outputs(self, logits: mx.array) -> mx.array:
        """Apply the appropriate activation function to the logits."""
        if self.is_regression:
            return logits  # No activation for regression
        elif self.num_labels == 1:
            return mx.sigmoid(logits)  # Binary classification
        else:
            # Using softmax for multi-class classification
            return mx.softmax(logits, axis=-1)

    def _compute_loss(self, logits: mx.array, labels: mx.array) -> mx.array:
        """Compute the appropriate loss based on label characteristics."""
        if self.is_regression:
            return nn.losses.mse_loss(logits.squeeze(), labels.squeeze())
        elif self.num_labels == 1:
            # binary_cross_entropy expects raw logits (with_logits=True is the default)
            return nn.losses.binary_cross_entropy(logits, labels)
        else:
            return nn.losses.cross_entropy(
                logits.reshape(-1, self.num_labels),
                labels.reshape(-1)
            )

    def __call__(
        self,
        input_ids,
        attention_mask: Optional[mx.array] = None,
        position_ids: Optional[mx.array] = None,  ### need this?
        labels: Optional[mx.array] = None,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Dict:
        if attention_mask is None:
            batch_size, seq_len = input_ids.shape
            attention_mask = mx.ones(
                (batch_size, seq_len),
                dtype=self.model.embed_tokens.weight.dtype,
            )

        outputs = self.model(
            input_ids,
            attention_mask,
            position_ids=position_ids,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

        last_hidden_state = (
            outputs["last_hidden_state"] if isinstance(outputs, dict) else outputs[0]
        )

        # normalized features
        if not self.config.is_causal:
            text_embeds = mean_pooling(last_hidden_state, attention_mask)
        else:
            text_embeds = last_token_pooling(last_hidden_state, attention_mask)

        ### The HF architecture Gemma3ForSequenceClassification
        logits = self.score(text_embeds)

        processed_logits = self._process_outputs(logits)

        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels)

        if not return_dict:
            return [loss, processed_logits, outputs[1:]]

        return {
            "loss": loss,
            "probabilities": processed_logits,
            "hidden_states": outputs.get("hidden_states", None),
        }

    def sanitize(self, weights):

        sanitized_weights = _sanitize_backbone(weights)

        # Filter out keys from 'embeddingsgemma3' that we don't want (dense projections)
        final_weights = {}
        for k, v in sanitized_weights.items():
            if not k.startswith("model.") and not k.startswith("score."):
                continue
            final_weights[k] = v

        return final_weights
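
# Illustrative configuration sketch for this head (hypothetical labels): the
# classifier width comes from ModelArgs.num_labels, which requires an id2label
# mapping unless regression or binary-sigmoid mode is requested.
#
#     args = ModelArgs(id2label={0: "negative", 1: "neutral", 2: "positive"})
#     args.num_labels                   # -> 3
#     clf = ModelForSequenceClassification(args)
#     clf.score                         # nn.Linear(hidden_size, 3, bias=False)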

class ModelForMaskedLM(RaclateBaseModel):
    """
    Computes masked language modeling (MLM) loss for input sequences.
    """
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        if config.is_causal:
            raise ValueError("ModelForMaskedLM requires bidirectional attention.")
        self.model = Gemma3Model(config)
        self.head = Gemma3PredictionHead(config)
        self.decoder = nn.Linear(
            config.hidden_size, config.vocab_size, bias=config.decoder_bias
        )

        # transformers has no MaskedLM class for Gemma3

        # We explicitly call tie_weights to ensure logic is set up,
        # though standard loading overwrites this unless sanitized correctly.
        self.tie_weights()

    def tie_weights(self):
        self.decoder.weight = self.model.embed_tokens.weight

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def get_output_embeddings(self):
        return self.decoder

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)
        self.tie_weights()  # Re-tie weights after setting new embeddings

    def set_output_embeddings(self, new_embeddings):
        self.decoder = new_embeddings
        self.tie_weights()  # Re-tie weights after setting new decoder

    def __call__(
        self,
        input_ids,
        attention_mask: Optional[mx.array] = None,
        labels: Optional[mx.array] = None,
        position_ids: Optional[mx.array] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
    ) -> Dict:

        if attention_mask is None:
            batch_size, seq_len = input_ids.shape
            attention_mask = mx.ones((batch_size, seq_len))  ### updated via _update_attention_mask() in the model

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = outputs["last_hidden_state"] if return_dict else outputs[0]
        logits = self.head(last_hidden_state)
        logits = self.decoder(logits)

        loss = None
        if self.training and labels is not None:
            if getattr(self.config, "sparse_prediction", False):
                # Flatten labels and predictions
                flat_labels = labels.reshape(-1)
                flat_predictions = logits.reshape(-1, logits.shape[-1])

                # Filter out non-masked tokens
                ignore_index = getattr(self.config, "sparse_pred_ignore_index", -100)
                mask_tokens = flat_labels != ignore_index

                # Only compute loss on masked tokens
                masked_predictions = flat_predictions[mask_tokens]
                masked_labels = flat_labels[mask_tokens]

                loss = nn.losses.cross_entropy(
                    masked_predictions,
                    masked_labels,
                    reduction='mean'
                )
            else:
                # Standard loss computation on all tokens
                loss = nn.losses.cross_entropy(
                    logits.reshape(-1, logits.shape[-1]),
                    labels.reshape(-1),
                    reduction='mean'
                )

        if not return_dict:
            return [loss, logits, outputs[1:]]

        return {
            "loss": loss,
            "logits": logits,
            "hidden_states": outputs.get("hidden_states", None),
        }

    def sanitize(self, weights):

        sanitized_weights = _sanitize_backbone(weights)

        # Specific adjustments for MLM
        final_weights = {}
        for k, v in sanitized_weights.items():
            # Filter unwanted Dense layers from embedding checkpoints
            # and/or score layers from classification checkpoints

            if not k.startswith("model.") and \
               not k.startswith("head.") and \
               not k.startswith("decoder."):
                continue

            # Handle Weight Tying for loading:
            if k == "model.embed_tokens.weight" and "decoder.weight" not in weights:
                final_weights["decoder.weight"] = v

            final_weights[k] = v

        return final_weights
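
# Note on weight tying above: `tie_weights` points `decoder.weight` at the
# token embedding matrix, and `sanitize` mirrors this for checkpoints that ship
# only "model.embed_tokens.weight" by copying it into "decoder.weight", so the
# tied output projection is populated even when no decoder weight is present.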


class ModelForTokenClassification(RaclateBaseModel):
    """
    Computes token classification probabilities for input sequences.

    NOTE: untested for now
    """
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        if config.is_causal:
            raise ValueError("ModelForTokenClassification requires bidirectional attention.")
        self.num_labels = config.num_labels

        self.model = Gemma3Model(config)
        self.drop = nn.Dropout(p=config.classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

        # transformers does not have TokenClassification class for Gemma3

    def __call__(
        self,
        input_ids,
        attention_mask: Optional[mx.array] = None,
        position_ids: Optional[mx.array] = None,
        labels: Optional[mx.array] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
    ) -> Dict:
        if attention_mask is None:
            batch_size, seq_len = input_ids.shape
            attention_mask = mx.ones((batch_size, seq_len))

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = outputs["last_hidden_state"] if return_dict else outputs[0]

        # Apply the classification layer to each token
        logits = self.score(last_hidden_state)

        # Process logits for inference
        processed_logits = mx.softmax(logits, axis=-1)

        loss = None
        if labels is not None:
            # Compute token classification loss
            loss = nn.losses.cross_entropy(
                logits.reshape(-1, self.num_labels),
                labels.reshape(-1)
            )

        if not return_dict:
            return [loss, processed_logits, outputs[1:]]

        return {
            "loss": loss,
            "probabilities": processed_logits,
            "hidden_states": outputs.get("hidden_states", None),
        }

    def sanitize(self, weights):
        sanitized_weights = _sanitize_backbone(weights)

        final_weights = {}
        for k, v in sanitized_weights.items():
            if not k.startswith("model.") and not k.startswith("score."):
                continue
            final_weights[k] = v

        return final_weights