nexaai 1.0.19rc16-cp310-cp310-macosx_14_0_universal2.whl → 1.0.19rc17-cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of nexaai might be problematic.
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
- nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
- nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexaml/libnexaproc.dylib +0 -0
- nexaai/binds/nexaml/libomp.dylib +0 -0
- nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
- nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl.py +162 -65
- nexaai/mlx_backend/vlm/interface.py +81 -29
- nexaai/mlx_backend/vlm/main.py +58 -13
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +317 -276
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +3 -2
- {nexaai-1.0.19rc16.dist-info → nexaai-1.0.19rc17.dist-info}/METADATA +1 -1
- {nexaai-1.0.19rc16.dist-info → nexaai-1.0.19rc17.dist-info}/RECORD +19 -19
- {nexaai-1.0.19rc16.dist-info → nexaai-1.0.19rc17.dist-info}/WHEEL +0 -0
- {nexaai-1.0.19rc16.dist-info → nexaai-1.0.19rc17.dist-info}/top_level.txt +0 -0
@@ -120,28 +120,24 @@ class VisionPatchEmbed(nn.Module):
 
         kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
         self.proj = nn.Conv3d(
-            self.in_channels,
-            self.embed_dim,
-            kernel_size=kernel_size,
-            stride=kernel_size,
-            bias=True
+            self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True
         )
 
     def __call__(self, hidden_states: mx.array) -> mx.array:
         target_dtype = self.proj.weight.dtype
-
+
         # Reshape to 5D: [batch, channels, temporal, height, width] (PyTorch format)
         # This matches the PyTorch ground truth exactly
         hidden_states = hidden_states.reshape(
            -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
         )
-
+
         # Convert to MLX format: [batch, temporal, height, width, channels]
         hidden_states = hidden_states.transpose(0, 2, 3, 4, 1)
-
+
         # Apply conv3d with target dtype and reshape to match PyTorch output
         hidden_states = self.proj(hidden_states.astype(target_dtype)).reshape(-1, self.embed_dim)
-
+
         return hidden_states
 
 
@@ -163,20 +159,20 @@ class VisionRotaryEmbedding(nn.Module):
 class VisionPatchMerger(nn.Module):
     def __init__(self, config: VisionConfig, use_postshuffle_norm=False):
         super().__init__()
-        self.hidden_size = config.hidden_size * (config.spatial_merge_size ** 2)
+        self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
         self.use_postshuffle_norm = use_postshuffle_norm
-
+
         norm_size = self.hidden_size if use_postshuffle_norm else config.hidden_size
-        self.…
+        self.norm = nn.LayerNorm(norm_size, eps=1e-6)
         self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
         self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)
 
     def __call__(self, x: mx.array) -> mx.array:
         if self.use_postshuffle_norm:
-            x = self.…
+            x = self.norm(x.reshape(-1, self.hidden_size)).reshape(-1, self.hidden_size)
         else:
-            x = self.…
-
+            x = self.norm(x).reshape(-1, self.hidden_size)
+
         x = self.linear_fc2(nn.gelu(self.linear_fc1(x)))
         return x
 
@@ -187,8 +183,8 @@ class VisionAttention(nn.Module):
         self.dim = config.hidden_size
         self.num_heads = config.num_heads
         self.head_dim = self.dim // self.num_heads
-        self.scaling = self.head_dim ** -0.5
-
+        self.scaling = self.head_dim**-0.5
+
         self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
         self.proj = nn.Linear(self.dim, self.dim)
 
@@ -204,51 +200,48 @@ class VisionAttention(nn.Module):
         qkv = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1)
         qkv = qkv.transpose(1, 0, 2, 3)
         query_states, key_states, value_states = qkv[0], qkv[1], qkv[2]
-
+
         cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb_vision(
-            query_states, key_states, cos, sin
-        )
+        query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
 
         query_states = query_states.transpose(1, 0, 2)
         key_states = key_states.transpose(1, 0, 2)
         value_states = value_states.transpose(1, 0, 2)
-
+
         query_states = mx.expand_dims(query_states, axis=0)
         key_states = mx.expand_dims(key_states, axis=0)
         value_states = mx.expand_dims(value_states, axis=0)
-
+
         lengths = cu_seqlens[1:] - cu_seqlens[:-1]
-
+
         split_indices = []
         cumsum = 0
         for length in lengths[:-1]:
             cumsum += int(length)
             split_indices.append(cumsum)
-
+
         if split_indices:
             q_splits = mx.split(query_states, split_indices, axis=1)
             k_splits = mx.split(key_states, split_indices, axis=1)
             v_splits = mx.split(value_states, split_indices, axis=1)
         else:
             q_splits = [query_states]
-            k_splits = [key_states]
+            k_splits = [key_states]
             v_splits = [value_states]
-
+
         attn_outputs = []
         for q, k, v in zip(q_splits, k_splits, v_splits):
             attn_out = scaled_dot_product_attention(
-                q, k, v,
-                scale=self.scaling, mask=None, cache=None
+                q, k, v, scale=self.scaling, mask=None, cache=None
             )
             attn_outputs.append(attn_out)
-
+
         attn_output = mx.concatenate(attn_outputs, axis=1)
-
+
         attn_output = attn_output[0].transpose(1, 0, 2)
         attn_output = attn_output.reshape(seq_length, -1)
         attn_output = self.proj(attn_output)
-
+
         return attn_output
 
 
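
Note on the two VisionAttention hunks above: the new `self.scaling = self.head_dim**-0.5` is the standard 1/sqrt(head_dim) softmax temperature, and the `cu_seqlens` logic turns packed per-image patch counts into split points so attention never crosses image boundaries. A minimal sketch of that split pattern (illustrative shapes only, not the package's test code):

    import mlx.core as mx

    # Packed sequence of 6 patches from two images (4 + 2); cu_seqlens
    # boundaries become split indices for mx.split, as in the diff above.
    cu_seqlens = mx.array([0, 4, 6])
    lengths = cu_seqlens[1:] - cu_seqlens[:-1]      # [4, 2]

    split_indices, cumsum = [], 0
    for length in lengths[:-1]:
        cumsum += int(length)
        split_indices.append(cumsum)                # [4]

    x = mx.arange(6).reshape(1, 6, 1)               # (batch, seq, feature)
    chunks = mx.split(x, split_indices, axis=1)     # one chunk per image
    print([c.shape for c in chunks])                # [(1, 4, 1), (1, 2, 1)]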
@@ -284,7 +277,7 @@ class VisionModel(nn.Module):
 
         self.patch_embed = VisionPatchEmbed(config)
         self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
-        self.num_grid_per_side = int(config.num_position_embeddings ** 0.5)
+        self.num_grid_per_side = int(config.num_position_embeddings**0.5)
 
         head_dim = config.hidden_size // config.num_heads
         self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
@@ -310,7 +303,7 @@ class VisionModel(nn.Module):
             num_frames = int(grid_thw[i, 0].item())
             height = int(grid_thw[i, 1].item())
             width = int(grid_thw[i, 2].item())
-
+
             merged_h, merged_w = height // merge_size, width // merge_size
 
             block_rows = mx.arange(merged_h)  # block row indices
@@ -322,8 +315,12 @@ class VisionModel(nn.Module):
             row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
             col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
 
-            row_idx = mx.broadcast_to(row_idx, (merged_h, merged_w, merge_size, merge_size)).reshape(-1)
-            col_idx = mx.broadcast_to(col_idx, (merged_h, merged_w, merge_size, merge_size)).reshape(-1)
+            row_idx = mx.broadcast_to(
+                row_idx, (merged_h, merged_w, merge_size, merge_size)
+            ).reshape(-1)
+            col_idx = mx.broadcast_to(
+                col_idx, (merged_h, merged_w, merge_size, merge_size)
+            ).reshape(-1)
 
             coords = mx.stack([row_idx, col_idx], axis=-1)
 
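
Note: the reflowed `mx.broadcast_to` calls above build patch coordinates in spatial-merge order, so the patches of one merge window sit next to each other in the flattened sequence. A small self-contained check (a 4x4 grid with 2x2 merge windows is a hypothetical example):

    import mlx.core as mx

    merge_size, merged_h, merged_w = 2, 2, 2
    block_rows = mx.arange(merged_h)
    block_cols = mx.arange(merged_w)
    intra_row = mx.arange(merge_size)
    intra_col = mx.arange(merge_size)

    row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
    col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]

    row_idx = mx.broadcast_to(row_idx, (merged_h, merged_w, merge_size, merge_size)).reshape(-1)
    col_idx = mx.broadcast_to(col_idx, (merged_h, merged_w, merge_size, merge_size)).reshape(-1)

    # First merge window: rows [0, 0, 1, 1], cols [0, 1, 0, 1]
    print(row_idx[:4], col_idx[:4])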
@@ -334,19 +331,19 @@ class VisionModel(nn.Module):
 
         # Concatenate all coordinate parts
         pos_ids = mx.concatenate(pos_ids_parts, axis=0)
-
+
         embeddings = freq_table[pos_ids]  # lookup rotary embeddings
         embeddings = embeddings.reshape(embeddings.shape[0], -1)
         return embeddings
 
     def fast_pos_embed_interpolate(self, grid_thw: mx.array):
         patch_pos_embeds = []
-
+
         for i in range(grid_thw.shape[0]):
             t = int(grid_thw[i, 0].item())
             h = int(grid_thw[i, 1].item())
             w = int(grid_thw[i, 2].item())
-
+
             # Simple position embedding interpolation
             h_idxs = mx.linspace(0, self.num_grid_per_side - 1, h)
             w_idxs = mx.linspace(0, self.num_grid_per_side - 1, w)
@@ -383,37 +380,41 @@ class VisionModel(nn.Module):
 
             # Repeat for temporal dimension and apply spatial merging
             pos_embed = mx.tile(pos_embed, (t, 1))
-
+
             # Apply spatial merging pattern
             merge_size = self.config.spatial_merge_size
-            pos_embed = pos_embed.reshape(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
+            pos_embed = pos_embed.reshape(
+                t, h // merge_size, merge_size, w // merge_size, merge_size, -1
+            )
             pos_embed = mx.transpose(pos_embed, (0, 1, 3, 2, 4, 5))
             pos_embed = pos_embed.reshape(-1, pos_embed.shape[-1])
-
+
             patch_pos_embeds.append(pos_embed)
-
+
         return mx.concatenate(patch_pos_embeds, axis=0)
 
-    def __call__(self, hidden_states: mx.array, grid_thw: mx.array) -> Tuple[mx.array, List[mx.array]]:
+    def __call__(
+        self, hidden_states: mx.array, grid_thw: mx.array
+    ) -> Tuple[mx.array, List[mx.array]]:
         hidden_states = self.patch_embed(hidden_states)
-
+
         pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
         hidden_states = hidden_states + pos_embeds
 
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
         seq_len = hidden_states.shape[0]
-
+
         emb = mx.concatenate([rotary_pos_emb, rotary_pos_emb], axis=-1)
         position_embeddings = (mx.cos(emb), mx.sin(emb))
 
-
+        # Create cumulative sequence lengths (following HuggingFace implementation)
         # torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
         seq_lens_per_image = grid_thw[:, 1] * grid_thw[:, 2]  # h * w for each image
         seq_lens = []
         for i, (seq_len, repeats) in enumerate(zip(seq_lens_per_image, grid_thw[:, 0])):
             seq_lens.extend([seq_len] * int(repeats))
         seq_lens = mx.array(seq_lens)
-
+
         # Then compute cumulative sum
         cu_seqlens = mx.cumsum(seq_lens)
         # Pad with 0 at the beginning
@@ -441,7 +442,7 @@ class TextRotaryEmbedding(nn.Module):
         self.config = config
         self.max_seq_len_cached = config.max_position_embeddings
         self.original_max_seq_len = config.max_position_embeddings
-
+
         # MRoPE configuration
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
             self.rope_type = config.rope_scaling.get("rope_type", "default")
@@ -449,17 +450,19 @@ class TextRotaryEmbedding(nn.Module):
         else:
             self.rope_type = "default"
             self.mrope_section = [24, 20, 20]
-
+
         # Store parameters for computing inv_freq on the fly
         self.head_dim = config.head_dim
         self.theta = config.rope_theta
-
+
         # Attention scaling (simplified - may need adjustment based on actual config)
         self.attention_scaling = 1.0
 
     def _get_inv_freq(self):
         """Compute inverse frequencies on the fly"""
-        inv_freq = 1.0 / (self.theta ** (mx.arange(0, self.head_dim, 2).astype(mx.float32) / self.head_dim))
+        inv_freq = 1.0 / (
+            self.theta ** (mx.arange(0, self.head_dim, 2).astype(mx.float32) / self.head_dim)
+        )
         # Expand for 3 dimensions (T, H, W)
         return mx.broadcast_to(inv_freq[None, :], (3, len(inv_freq)))
 
@@ -485,36 +488,38 @@ class TextRotaryEmbedding(nn.Module):
         Args:
             x: Input tensor for dtype reference
             position_ids: Position indices, shape (3, batch_size, seq_len) for MRoPE
-
+
         Returns:
             cos, sin: Cosine and sine embeddings
         """
         # Handle 2D position_ids by expanding to 3D for MRoPE
         if position_ids.ndim == 2:
-            position_ids = mx.broadcast_to(position_ids[None, ...],
-                                           (3, position_ids.shape[0], position_ids.shape[1]))
+            position_ids = mx.broadcast_to(
+                position_ids[None, ...], (3, position_ids.shape[0], position_ids.shape[1])
+            )
+
         batch_size, seq_len = position_ids.shape[1], position_ids.shape[2]
-
+
         # Expand inverse frequencies: (3, 1, 1, dim//2) -> (3, batch_size, 1, dim//2)
         inv_freq_expanded = mx.broadcast_to(
-            self._get_inv_freq()[:, None, None, :],
-            (3, batch_size, 1, self._get_inv_freq().shape[-1])
+            self._get_inv_freq()[:, None, None, :],
+            (3, batch_size, 1, self._get_inv_freq().shape[-1]),
         )
-
+
         # Expand position ids: (3, batch_size, seq_len) -> (3, batch_size, seq_len, 1)
         position_ids_expanded = position_ids[..., None].astype(mx.float32)
-
+
         # Compute frequencies: (3, batch_size, seq_len, dim//2)
         freqs = inv_freq_expanded * position_ids_expanded
-
+
         # Apply interleaved MRoPE
         freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
-
+
         # Create embeddings
         emb = mx.concatenate([freqs, freqs], axis=-1)  # (batch_size, seq_len, head_dim)
         cos = mx.cos(emb) * self.attention_scaling
         sin = mx.sin(emb) * self.attention_scaling
-
+
         return cos.astype(x.dtype), sin.astype(x.dtype)
 
 
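
Note: `_get_inv_freq` above is the usual RoPE schedule theta**(-2i/d), duplicated across the three MRoPE axes (T, H, W), and `__call__` multiplies it against 3-axis position ids. A sketch of the shape bookkeeping only (toy sizes; `apply_interleaved_mrope` is omitted):

    import mlx.core as mx

    head_dim, theta = 8, 10000.0
    inv_freq = 1.0 / (theta ** (mx.arange(0, head_dim, 2).astype(mx.float32) / head_dim))
    inv_freq = mx.broadcast_to(inv_freq[None, :], (3, head_dim // 2))   # (3, dim//2)

    batch, seq = 2, 5
    position_ids = mx.broadcast_to(mx.arange(seq)[None, None, :], (3, batch, seq))

    # (3, 1, 1, dim//2) * (3, batch, seq, 1) -> (3, batch, seq, dim//2)
    freqs = inv_freq[:, None, None, :] * position_ids[..., None].astype(mx.float32)
    print(freqs.shape)   # (3, 2, 5, 4): one frequency grid per MRoPE axis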
@@ -523,12 +528,12 @@ class TextAttention(nn.Module):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
-
+
         dim = config.hidden_size
         self.n_heads = config.num_attention_heads
         self.n_kv_heads = config.num_key_value_heads
         self.head_dim = config.head_dim
-        self.scale = self.head_dim ** -0.5
+        self.scale = self.head_dim**-0.5
 
         self.q_proj = nn.Linear(dim, self.n_heads * self.head_dim, bias=config.attention_bias)
         self.k_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=config.attention_bias)
@@ -537,7 +542,7 @@ class TextAttention(nn.Module):
 
         self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
         self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
-
+
         # Initialize rope directly
         self.rope = initialize_rope(
             config.head_dim,
@@ -573,8 +578,9 @@ class TextAttention(nn.Module):
                 keys, values = cache.update_and_fetch(keys, values)
         else:
             if cache is not None:
-                …
-                …
+                offset_delta = rope_deltas.item() if rope_deltas is not None and rope_deltas.size == 1 else (rope_deltas.reshape(-1)[0].item() if rope_deltas is not None else 0)
+                queries = self.rope(queries, offset=cache.offset + offset_delta)
+                keys = self.rope(keys, offset=cache.offset + offset_delta)
                 keys, values = cache.update_and_fetch(keys, values)
             else:
                 queries = self.rope(queries)
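
Note: the new decode branch above offsets RoPE by `rope_deltas` (produced by `get_rope_index` further down), so generated tokens continue from the multimodal position ids instead of the raw KV-cache length. A hedged sketch of just the scalar extraction on the new line 581 (assuming `rope_deltas` is None or a small array):

    import mlx.core as mx

    def extract_offset_delta(rope_deltas):
        # None -> 0; 1-element array -> its scalar; otherwise its first element.
        if rope_deltas is None:
            return 0
        if rope_deltas.size == 1:
            return rope_deltas.item()
        return rope_deltas.reshape(-1)[0].item()

    print(extract_offset_delta(None))             # 0
    print(extract_offset_delta(mx.array([[3]])))  # 3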
@@ -618,7 +624,7 @@ class TextDecoderLayer(nn.Module):
     ) -> mx.array:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
-
+
         hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
@@ -640,11 +646,10 @@ class TextModel(nn.Module):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size
-
+
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
         self.layers = [
-            TextDecoderLayer(config, layer_idx)
-            for layer_idx in range(config.num_hidden_layers)
+            TextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
         ]
         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.rotary_emb = TextRotaryEmbedding(config)
@@ -701,7 +706,9 @@ class TextModel(nn.Module):
                 rope_deltas=rope_deltas,
             )
             if deepstack_visual_embeds is not None and layer_idx < len(deepstack_visual_embeds):
-                hidden_states = self._deepstack_process(hidden_states, visual_pos_masks, deepstack_visual_embeds[layer_idx])
+                hidden_states = self._deepstack_process(
+                    hidden_states, visual_pos_masks, deepstack_visual_embeds[layer_idx]
+                )
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
@@ -712,17 +719,17 @@ class VEGModel(nn.Module):
         super().__init__()
         self.config = vision_config
         self.visual = VisionModel(vision_config)
-
+
     def __call__(self, pixel_values: mx.array, image_grid_thw: mx.array):
         return self.visual(pixel_values, image_grid_thw)
-
+
     def sanitize(self, weights):
         sanitized = {}
         for k, v in weights.items():
-            if 'visual.' in k:
+            if "visual." in k:
                 # Remove prefixes to match our model structure
-                clean_key = k.replace('model.visual.', '').replace('visual.', '')
-                sanitized[f'visual.{clean_key}'] = v
+                clean_key = k.replace("model.visual.", "").replace("visual.", "")
+                sanitized[f"visual.{clean_key}"] = v
         return sanitized
 
 
@@ -735,140 +742,164 @@ class LLMModel(nn.Module):
         self.language_model = TextModel(text_config)
         if not text_config.tie_word_embeddings:
             self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)
-
+
     def get_rope_index(
-        # … previous get_rope_index parameters and body (old lines 740-851) were rewritten; the old text was not captured in this diff view …
+        self,
+        input_ids: Optional[mx.array] = None,
+        image_grid_thw: Optional[mx.array] = None,
+        attention_mask: Optional[mx.array] = None,
+    ) -> Tuple[mx.array, mx.array]:
+        """Simplified version for images only (no video support)."""
+
+        spatial_merge_size = 2
+        image_token_id = 151655
+        vision_start_token_id = 151652
+        mrope_position_deltas = []
+
+        if input_ids is not None and image_grid_thw is not None:
+            total_input_ids = input_ids
+            if attention_mask is None:
+                attention_mask = mx.ones_like(total_input_ids)
+
+            batch_size, seq_len = input_ids.shape
+            position_ids_list = []
+            image_index = 0
+
+            for i in range(batch_size):
+                input_ids_seq = total_input_ids[i]
+                mask_seq = attention_mask[i]
+
+                # Use mask to get valid length
+                valid_length = int(mx.sum(mask_seq).item())
+                input_ids_seq = input_ids_seq[:valid_length]
+
+                image_nums = 0
+                # Find vision start tokens by iterating through the sequence
+                vision_start_positions = []
+                for pos in range(input_ids_seq.shape[0]):
+                    if input_ids_seq[pos].item() == vision_start_token_id:
+                        vision_start_positions.append(pos)
+
+                if len(vision_start_positions) > 0:
+                    for pos in vision_start_positions:
+                        if pos + 1 < input_ids_seq.shape[0]:
+                            if input_ids_seq[pos + 1].item() == image_token_id:
+                                image_nums += 1
+
+                input_tokens = input_ids_seq.tolist()
+                llm_pos_ids_list = []
+                st = 0
+                remain_images = image_nums
+
+                for _ in range(image_nums):
+                    ed_image = input_tokens.index(image_token_id, st)
+
+                    t = image_grid_thw[image_index, 0].item()
+                    h = image_grid_thw[image_index, 1].item()
+                    w = image_grid_thw[image_index, 2].item()
+                    image_index += 1
+                    remain_images -= 1
+                    ed = ed_image
+
+                    llm_grid_t = int(t)
+                    llm_grid_h = int(h) // spatial_merge_size
+                    llm_grid_w = int(w) // spatial_merge_size
+                    text_len = ed - st
+
+                    st_idx = (
+                        llm_pos_ids_list[-1].max().item() + 1 if len(llm_pos_ids_list) > 0 else 0
+                    )
+                    text_pos = mx.arange(text_len).reshape(1, -1)
+                    text_pos = mx.broadcast_to(text_pos, (3, text_len)) + st_idx
+                    llm_pos_ids_list.append(text_pos)
+
+                    # t_index is always 0 because llm_grid_t is always 1 for images
+                    t_index = mx.arange(llm_grid_t).reshape(-1, 1)
+                    t_index = mx.broadcast_to(
+                        t_index, (llm_grid_t, llm_grid_h * llm_grid_w)
+                    ).reshape(-1)
+
+                    h_index = mx.arange(llm_grid_h).reshape(1, -1, 1)
+                    h_index = mx.broadcast_to(
+                        h_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+                    ).reshape(-1)
+
+                    w_index = mx.arange(llm_grid_w).reshape(1, 1, -1)
+                    w_index = mx.broadcast_to(
+                        w_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+                    ).reshape(-1)
+
+                    vision_pos = mx.stack([t_index, h_index, w_index]) + text_len + st_idx
+                    llm_pos_ids_list.append(vision_pos)
+                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+                if st < len(input_tokens):
+                    st_idx = (
+                        llm_pos_ids_list[-1].max().item() + 1 if len(llm_pos_ids_list) > 0 else 0
+                    )
+                    text_len = len(input_tokens) - st
+                    text_pos = mx.arange(text_len).reshape(1, -1)
+                    text_pos = mx.broadcast_to(text_pos, (3, text_len)) + st_idx
+                    llm_pos_ids_list.append(text_pos)
+
+                llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
+
+                # Create position_ids for this batch item, pad to seq_len
+                batch_position_ids = mx.ones((3, seq_len), dtype=input_ids.dtype)
+                valid_length = min(seq_len, llm_positions.shape[1])
+
+                # Create new arrays for each dimension
+                pos_dim0 = mx.concatenate(
+                    [
+                        llm_positions[0, :valid_length],
+                        mx.ones(seq_len - valid_length, dtype=input_ids.dtype),
+                    ]
+                )
+                pos_dim1 = mx.concatenate(
+                    [
+                        llm_positions[1, :valid_length],
+                        mx.ones(seq_len - valid_length, dtype=input_ids.dtype),
+                    ]
+                )
+                pos_dim2 = mx.concatenate(
+                    [
+                        llm_positions[2, :valid_length],
+                        mx.ones(seq_len - valid_length, dtype=input_ids.dtype),
+                    ]
+                )
+
+                batch_position_ids = mx.stack([pos_dim0, pos_dim1, pos_dim2])
+                position_ids_list.append(batch_position_ids)
+
+                mrope_position_deltas.append(
+                    llm_positions.max().item() + 1 - len(total_input_ids[i])
+                )
+
+            # Stack all batch position_ids
+            position_ids = mx.stack(position_ids_list, axis=1)  # Shape: (3, batch_size, seq_len)
+            mrope_position_deltas = mx.array(mrope_position_deltas).reshape(-1, 1)
+            return position_ids, mrope_position_deltas
+        else:
+            if attention_mask is not None:
+                position_ids = mx.cumsum(attention_mask.astype(mx.int32), axis=-1) - 1
+                position_ids = mx.where(attention_mask == 0, 1, position_ids)
+                position_ids = mx.expand_dims(position_ids, axis=0)
+                position_ids = mx.broadcast_to(
+                    position_ids, (3, position_ids.shape[1], position_ids.shape[2])
+                )
+                max_position_ids = mx.max(
+                    mx.max(position_ids, axis=0, keepdims=False), axis=-1, keepdims=True
+                )
+                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
             else:
-                # … (old lines 853-860 not captured in this diff view) …
-                mrope_position_deltas = mx.reshape(mrope_position_deltas, (-1,))
-        else:
-            seq_len = input_ids.shape[1]
-            batch_size = input_ids.shape[0]
-            position_ids = mx.arange(seq_len).reshape(1, 1, -1)
-            position_ids = mx.broadcast_to(position_ids, (3, batch_size, seq_len))
-            # 1D zeros for rope deltas
-            mrope_position_deltas = mx.zeros((batch_size,), dtype=input_ids.dtype)
-
-        return position_ids, mrope_position_deltas
-
+                seq_len = input_ids.shape[1]
+                batch_size = input_ids.shape[0]
+                position_ids = mx.arange(seq_len).reshape(1, 1, -1)
+                position_ids = mx.broadcast_to(position_ids, (3, batch_size, seq_len))
+                mrope_position_deltas = mx.zeros((batch_size, 1), dtype=input_ids.dtype)
+
+        return position_ids, mrope_position_deltas
+
     def __call__(
         self,
         inputs: mx.array = None,
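
Note: the rewritten `get_rope_index` gives every text token the same index on all three MRoPE axes, while image patches get (t, h, w) grid coordinates shifted past the preceding text; `rope_deltas` records how far the last position runs beyond the sequence length. A toy illustration of the vision-block coordinates (hypothetical 4x4 patch grid, merge size 2, three text tokens in front):

    import mlx.core as mx

    llm_grid_t, llm_grid_h, llm_grid_w = 1, 2, 2   # 4x4 patches merged 2x2
    text_len, st_idx = 3, 0

    t_index = mx.broadcast_to(mx.arange(llm_grid_t).reshape(-1, 1),
                              (llm_grid_t, llm_grid_h * llm_grid_w)).reshape(-1)
    h_index = mx.broadcast_to(mx.arange(llm_grid_h).reshape(1, -1, 1),
                              (llm_grid_t, llm_grid_h, llm_grid_w)).reshape(-1)
    w_index = mx.broadcast_to(mx.arange(llm_grid_w).reshape(1, 1, -1),
                              (llm_grid_t, llm_grid_h, llm_grid_w)).reshape(-1)

    vision_pos = mx.stack([t_index, h_index, w_index]) + text_len + st_idx
    # Rows are the (T, H, W) axes: T = [3,3,3,3], H = [3,3,4,4], W = [3,4,3,4]
    print(vision_pos)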
@@ -896,35 +927,41 @@ class LLMModel(nn.Module):
             return self.language_model.embed_tokens.as_linear(out)
         else:
             return self.lm_head(out)
-
+
     def sanitize(self, weights):
         sanitized = {}
         for k, v in weights.items():
-            if not ('visual.' in k):
+            if not ("visual." in k):
                 # Handle key mapping from combined model to LLM-only model
                 clean_key = k
-
+
                 # Remove model. prefix if present
-                if clean_key.startswith('model.'):
+                if clean_key.startswith("model."):
                     clean_key = clean_key[6:]  # Remove 'model.'
-
+
                 # Map language_ prefixed keys to language_model structure
-                if clean_key.startswith('language_'):
-                    if clean_key.startswith('language_layers.'):
-                        clean_key = …
-                        …
-                        …
-                    elif clean_key.startswith(…
-                        clean_key = …
-                        …
+                if clean_key.startswith("language_"):
+                    if clean_key.startswith("language_layers."):
+                        clean_key = (
+                            "language_model.layers." + clean_key[16:]
+                        )  # Map to language_model.layers.
+                    elif clean_key.startswith("language_embed_tokens."):
+                        clean_key = (
+                            "language_model.embed_tokens." + clean_key[22:]
+                        )  # Map to language_model.embed_tokens.
+                    elif clean_key.startswith("language_norm."):
+                        clean_key = (
+                            "language_model.norm." + clean_key[14:]
+                        )  # Map to language_model.norm.
+
                 sanitized[clean_key] = v
-
+
         # Handle tied embeddings - remove lm_head if using tied embeddings
         if self.args.tie_word_embeddings:
             sanitized.pop("lm_head.weight", None)
-
+
         return sanitized
-
+
     @property
     def layers(self):
         return self.language_model.layers
@@ -938,39 +975,36 @@ class Qwen3VLModel(nn.Module):
         self.config = args
         self.visual = VisionModel(args.vision_config)
         self.language_model = TextModel(args.text_config)
-
+
     def sanitize(self, weights):
         # Map weights to match the combined model structure
         sanitized = {}
         for k, v in weights.items():
             # Remove 'model.' prefix if present to match our structure
-            clean_key = k.replace('model.', '') if k.startswith('model.') else k
+            clean_key = k.replace("model.", "") if k.startswith("model.") else k
             sanitized[clean_key] = v
         return sanitized
 
-    def get_image_features(
-        self,
-        pixel_values: mx.array,
-        image_grid_thw: Optional[mx.array] = None
-    ):
+    def get_image_features(self, pixel_values: mx.array, image_grid_thw: Optional[mx.array] = None):
         image_embeds, deepstack_visual_embeds = self.visual(pixel_values, image_grid_thw)
         # Split based on grid dimensions
         if image_grid_thw is not None:
-            split_sizes = (mx.prod(image_grid_thw, axis=-1) // (self.visual.spatial_merge_size ** 2)).tolist()
+            split_sizes = (
+                mx.prod(image_grid_thw, axis=-1) // (self.visual.spatial_merge_size**2)
+            ).tolist()
             # Convert sizes to indices for mx.split (cumulative sum, excluding the last)
             split_indices = []
             cumsum = 0
             for size in split_sizes[:-1]:  # Exclude last element
                 cumsum += size
                 split_indices.append(cumsum)
-
+
             if split_indices:  # Only split if we have indices
                 image_embeds = mx.split(image_embeds, split_indices)
             else:
                 image_embeds = [image_embeds]  # Single image case
         return image_embeds, deepstack_visual_embeds
 
-
     def __call__(
         self,
         input_ids: mx.array = None,
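
Note: `split_sizes` above maps each image's grid to its merged-token count, t * h * w // spatial_merge_size**2; cumulative sums of all but the last size become the split points `mx.split` needs. A worked example with hypothetical grids:

    import mlx.core as mx

    spatial_merge_size = 2
    image_grid_thw = mx.array([[1, 4, 4], [1, 2, 4]])   # 16 and 8 raw patches

    split_sizes = (mx.prod(image_grid_thw, axis=-1) // (spatial_merge_size**2)).tolist()
    print(split_sizes)    # [4, 2] merged tokens per image

    split_indices, cumsum = [], 0
    for size in split_sizes[:-1]:                       # last size is implied
        cumsum += size
        split_indices.append(cumsum)
    print(split_indices)  # [4]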
@@ -989,26 +1023,25 @@ class Qwen3VLModel(nn.Module):
         inputs_embeds = self.language_model.embed_tokens(input_ids)
 
         # Process images
-
+
         if pixel_values is not None:
             image_embeds, deepstack_visual_embeds = self.get_image_features(
                 pixel_values, image_grid_thw
             )
-
+
             # Create masks and embed visual features
             if isinstance(image_embeds, list):
                 image_embeds = mx.concatenate(image_embeds, axis=0)
-
+
             # Find image token positions and replace with visual embeddings
-            image_mask = …
+            image_mask = input_ids == self.args.image_token_id
             visual_pos_masks = image_mask
-
+
             # Replace image tokens with visual embeddings
             inputs_embeds = inputs_embeds.at[image_mask].set(
                 image_embeds.astype(inputs_embeds.dtype)
             )
 
-
         outputs = self.language_model(
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
@@ -1026,28 +1059,28 @@ def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, i
 def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, image_grid_thw):
     """
     Handle the processing of multimodal embeddings including image features and position encoding.
-
+
     This function processes vision and text inputs to create unified embeddings that can be fed
     into the language model. It handles:
     - Vision feature extraction from pixel values
     - Deepstack visual embedding collection
     - Image token replacement in text embeddings
     - Position encoding setup for MRoPE (Multi-dimensional RoPE)
-
+
     Args:
         vision_model: The vision encoder model (VEGModel instance)
-        llm_model: The language model (LLMModel instance)
+        llm_model: The language model (LLMModel instance)
         input_ids: Tokenized text input with image token placeholders [batch_size, seq_len]
         pixel_values: Preprocessed image pixel data [num_patches, feature_dim]
         image_grid_thw: Grid dimensions for each image [num_images, 3] (time, height, width)
-
+
     Returns:
         tuple: (inputs_embeds, deepstack_visual_embeds, visual_pos_masks, cos, sin, rope_deltas)
         - inputs_embeds: Combined text and image embeddings [batch_size, seq_len, hidden_size]
        - deepstack_visual_embeds: Multi-layer visual features for deepstack processing
        - visual_pos_masks: Boolean mask indicating image token positions
        - cos: Cosine values for rotary position encoding
-        - sin: Sine values for rotary position encoding
+        - sin: Sine values for rotary position encoding
        - rope_deltas: Position offset deltas for rope computation
     """
     inputs_embeds = llm_model.language_model.embed_tokens(input_ids.squeeze(0))
@@ -1056,74 +1089,80 @@ def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, i
     cos = None
     sin = None
     rope_deltas = 0
-
+
     if pixel_values is not None:
         if pixel_values.ndim == 4:
             pixel_values = mx.expand_dims(pixel_values, axis=2)
-
+
         # Process each image individually to prevent feature mixing
         image_embeds_list = []
         all_deepstack_embeds = []
-
+
         # Calculate cumulative indices for each image
         cumulative_patches = 0
-
+
         for i in range(image_grid_thw.shape[0]):
             # Calculate number of patches for current image
             current_patches = int(image_grid_thw[i, 1] * image_grid_thw[i, 2])
             start_idx = cumulative_patches
             end_idx = cumulative_patches + current_patches
             cumulative_patches += current_patches
-
+
             single_pixel_values = pixel_values[start_idx:end_idx]
-            single_grid_thw = image_grid_thw[i:i+1]
-
+            single_grid_thw = image_grid_thw[i : i + 1]
+
             # Use vision model directly
             single_embeds, single_deepstack = vision_model(single_pixel_values, single_grid_thw)
-
+
             # Split based on grid dimensions
             if single_grid_thw is not None:
-                split_sizes = (mx.prod(single_grid_thw, axis=-1) // (vision_model.visual.spatial_merge_size ** 2)).tolist()
+                split_sizes = (
+                    mx.prod(single_grid_thw, axis=-1) // (vision_model.visual.spatial_merge_size**2)
+                ).tolist()
                 split_indices = []
                 cumsum = 0
                 for size in split_sizes[:-1]:
                     cumsum += size
                     split_indices.append(cumsum)
-
+
                 if split_indices:
                     single_embeds = mx.split(single_embeds, split_indices)
                 else:
                     single_embeds = [single_embeds]
-
+
             image_embeds_list.extend(single_embeds)
-
+
             # Collect deepstack embeddings
             if i == 0:
                 all_deepstack_embeds = single_deepstack
             else:
                 # Concatenate deepstack embeddings from different images
                 for j in range(len(all_deepstack_embeds)):
-                    all_deepstack_embeds[j] = mx.concatenate(
-                        [all_deepstack_embeds[j], single_deepstack[j]], axis=0)
+                    all_deepstack_embeds[j] = mx.concatenate(
+                        [all_deepstack_embeds[j], single_deepstack[j]], axis=0
+                    )
+
         deepstack_visual_embeds = all_deepstack_embeds
-
+
         # Concatenate all image embeddings for processing
         image_embeds = mx.concatenate(image_embeds_list, axis=0)
-
+
         # Find all image token positions
         image_token_id = 151655  # Default image token ID
-        image_mask = …
+        image_mask = input_ids.squeeze(0) == image_token_id
         image_mask_np = np.array(image_mask)
         image_token_positions = np.where(image_mask_np)[0]
-
+
         # Verify we have the correct number of image tokens
         expected_total_tokens = sum(embed.shape[0] for embed in image_embeds_list)
-        assert len(image_token_positions) == expected_total_tokens, f"Expected {expected_total_tokens} image tokens, got {len(image_token_positions)}"
-
+        assert (
+            len(image_token_positions) == expected_total_tokens
+        ), f"Expected {expected_total_tokens} image tokens, got {len(image_token_positions)}"
+
         # Replace image tokens with image embeddings
         seq_len = inputs_embeds.shape[0]
         result = inputs_embeds
-
+
         # Replace image tokens with image embeddings sequentially
         embed_idx = 0
         for img_embed in image_embeds_list:
@@ -1133,7 +1172,7 @@ def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, i
             result = mx.where(
                 mx.expand_dims(pos_mask, axis=-1),
                 mx.expand_dims(img_embed[patch_idx], axis=0).astype(inputs_embeds.dtype),
-                result
+                result,
             )
             embed_idx += 1
 
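
Note: the loop above writes each image embedding into its token slot with `mx.where`, one position at a time. A minimal standalone version of that masking pattern (toy sizes, not the package's API):

    import mlx.core as mx

    seq_len, hidden = 4, 3
    embeds = mx.zeros((seq_len, hidden))    # stand-in for inputs_embeds
    new_vec = mx.ones((hidden,))            # stand-in for one patch embedding

    pos_mask = mx.arange(seq_len) == 2      # boolean mask for one token slot
    result = mx.where(
        mx.expand_dims(pos_mask, axis=-1),  # (seq, 1)
        mx.expand_dims(new_vec, axis=0),    # (1, hidden)
        embeds,                             # (seq, hidden)
    )
    print(result)                           # row 2 is all ones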
@@ -1142,10 +1181,10 @@ def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, i
     cos, sin = llm_model.language_model.rotary_emb(inputs_embeds, position_ids)
     if inputs_embeds.ndim == 2:
         inputs_embeds = mx.expand_dims(inputs_embeds, axis=0)
-
+
     if image_mask is not None:
         visual_pos_masks = image_mask
-
+
     return inputs_embeds, deepstack_visual_embeds, visual_pos_masks, cos, sin, rope_deltas
 
 
@@ -1156,7 +1195,9 @@ class Model(nn.Module):
         self.args = args
         self.model = Qwen3VLModel(args)
         if not args.text_config.tie_word_embeddings:
-            self.lm_head = nn.Linear(args.text_config.hidden_size, args.text_config.vocab_size, bias=False)
+            self.lm_head = nn.Linear(
+                args.text_config.hidden_size, args.text_config.vocab_size, bias=False
+            )
 
     def __call__(
         self,
@@ -1164,7 +1205,7 @@ class Model(nn.Module):
         mask: mx.array = None,
         cache=None,
         inputs_embeds: Optional[mx.array] = None,
-        pixel_values: Optional[mx.array] = None,
+        pixel_values: Optional[mx.array] = None,
         image_grid_thw: Optional[mx.array] = None,
         visual_pos_masks: Optional[mx.array] = None,
         deepstack_visual_embeds: Optional[List[mx.array]] = None,
@@ -1195,13 +1236,13 @@ class Model(nn.Module):
         sanitized = {}
         for k, v in weights.items():
             sanitized[k] = v
-
+
         # Handle tied embeddings - remove lm_head if using tied embeddings
         if self.args.text_config.tie_word_embeddings:
             sanitized.pop("lm_head.weight", None)
-
+
         return sanitized
 
     @property
     def layers(self):
-        return self.model.language_model.layers
+        return self.model.language_model.layers