nexaai 1.0.19rc16__cp310-cp310-macosx_14_0_universal2.whl → 1.0.19rc18__cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of nexaai might be problematic.
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
- nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
- nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexaml/libnexaproc.dylib +0 -0
- nexaai/binds/nexaml/libomp.dylib +0 -0
- nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
- nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl.py +162 -65
- nexaai/mlx_backend/vlm/interface.py +81 -29
- nexaai/mlx_backend/vlm/main.py +58 -13
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +331 -276
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +17 -2
- {nexaai-1.0.19rc16.dist-info → nexaai-1.0.19rc18.dist-info}/METADATA +1 -1
- {nexaai-1.0.19rc16.dist-info → nexaai-1.0.19rc18.dist-info}/RECORD +19 -19
- {nexaai-1.0.19rc16.dist-info → nexaai-1.0.19rc18.dist-info}/WHEEL +0 -0
- {nexaai-1.0.19rc16.dist-info → nexaai-1.0.19rc18.dist-info}/top_level.txt +0 -0
nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py

@@ -120,28 +120,24 @@ class VisionPatchEmbed(nn.Module):
 
         kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
         self.proj = nn.Conv3d(
-            self.in_channels,
-            self.embed_dim,
-            kernel_size=kernel_size,
-            stride=kernel_size,
-            bias=True
+            self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True
         )
 
     def __call__(self, hidden_states: mx.array) -> mx.array:
         target_dtype = self.proj.weight.dtype
-
+
         # Reshape to 5D: [batch, channels, temporal, height, width] (PyTorch format)
         # This matches the PyTorch ground truth exactly
         hidden_states = hidden_states.reshape(
             -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
         )
-
+
         # Convert to MLX format: [batch, temporal, height, width, channels]
         hidden_states = hidden_states.transpose(0, 2, 3, 4, 1)
-
+
         # Apply conv3d with target dtype and reshape to match PyTorch output
         hidden_states = self.proj(hidden_states.astype(target_dtype)).reshape(-1, self.embed_dim)
-
+
         return hidden_states
 
 
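The patch-embed change above collapses the Conv3d argument list and keeps the reshape/transpose sequence that converts flattened patches from the PyTorch channels-first layout into MLX's channels-last layout. A minimal sketch of that layout conversion, with illustrative sizes (3 channels, temporal_patch_size 2, patch_size 16 are assumptions, not values read from this package):

```python
import mlx.core as mx

# Illustrative sizes (assumed for the example, not taken from the package config)
in_channels, temporal_patch_size, patch_size = 3, 2, 16
num_patches = 4

# Flattened patches: [N, C * T * P * P], as the processor would hand them over
x = mx.zeros((num_patches, in_channels * temporal_patch_size * patch_size * patch_size))

# PyTorch-style 5D layout: [N, C, T, H, W]
x = x.reshape(-1, in_channels, temporal_patch_size, patch_size, patch_size)

# MLX Conv3d expects channels last: [N, T, H, W, C]
x = x.transpose(0, 2, 3, 4, 1)
print(x.shape)  # (4, 2, 16, 16, 3)
```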
@@ -163,20 +159,20 @@ class VisionRotaryEmbedding(nn.Module):
 class VisionPatchMerger(nn.Module):
     def __init__(self, config: VisionConfig, use_postshuffle_norm=False):
         super().__init__()
-        self.hidden_size = config.hidden_size * (config.spatial_merge_size …
+        self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
         self.use_postshuffle_norm = use_postshuffle_norm
-
+
         norm_size = self.hidden_size if use_postshuffle_norm else config.hidden_size
-        self. …
+        self.norm = nn.LayerNorm(norm_size, eps=1e-6)
         self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
         self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)
 
     def __call__(self, x: mx.array) -> mx.array:
         if self.use_postshuffle_norm:
-            x = self. …
+            x = self.norm(x.reshape(-1, self.hidden_size)).reshape(-1, self.hidden_size)
         else:
-            x = self. …
-
+            x = self.norm(x).reshape(-1, self.hidden_size)
+
         x = self.linear_fc2(nn.gelu(self.linear_fc1(x)))
         return x
 
@@ -187,8 +183,8 @@ class VisionAttention(nn.Module):
         self.dim = config.hidden_size
         self.num_heads = config.num_heads
         self.head_dim = self.dim // self.num_heads
-        self.scaling = self.head_dim …
-
+        self.scaling = self.head_dim**-0.5
+
         self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
         self.proj = nn.Linear(self.dim, self.dim)
 
@@ -204,51 +200,48 @@ class VisionAttention(nn.Module):
         qkv = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1)
         qkv = qkv.transpose(1, 0, 2, 3)
         query_states, key_states, value_states = qkv[0], qkv[1], qkv[2]
-
+
         cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb_vision(
-            query_states, key_states, cos, sin
-        )
+        query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
 
         query_states = query_states.transpose(1, 0, 2)
         key_states = key_states.transpose(1, 0, 2)
         value_states = value_states.transpose(1, 0, 2)
-
+
         query_states = mx.expand_dims(query_states, axis=0)
         key_states = mx.expand_dims(key_states, axis=0)
         value_states = mx.expand_dims(value_states, axis=0)
-
+
         lengths = cu_seqlens[1:] - cu_seqlens[:-1]
-
+
         split_indices = []
         cumsum = 0
         for length in lengths[:-1]:
             cumsum += int(length)
             split_indices.append(cumsum)
-
+
         if split_indices:
             q_splits = mx.split(query_states, split_indices, axis=1)
             k_splits = mx.split(key_states, split_indices, axis=1)
             v_splits = mx.split(value_states, split_indices, axis=1)
         else:
             q_splits = [query_states]
-            k_splits = [key_states]
+            k_splits = [key_states]
             v_splits = [value_states]
-
+
         attn_outputs = []
         for q, k, v in zip(q_splits, k_splits, v_splits):
             attn_out = scaled_dot_product_attention(
-                q, k, v,
-                scale=self.scaling, mask=None, cache=None
+                q, k, v, scale=self.scaling, mask=None, cache=None
             )
             attn_outputs.append(attn_out)
-
+
         attn_output = mx.concatenate(attn_outputs, axis=1)
-
+
         attn_output = attn_output[0].transpose(1, 0, 2)
         attn_output = attn_output.reshape(seq_length, -1)
         attn_output = self.proj(attn_output)
-
+
         return attn_output
 
 
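The attention hunk above splits the packed token sequence at the boundaries recorded in cu_seqlens so that attention is computed per image rather than across images. A small sketch of how those split points are derived (the lengths are made-up example values):

```python
import mlx.core as mx

# Made-up example: three images contributing 4, 6 and 2 patch tokens
cu_seqlens = mx.array([0, 4, 10, 12])

# Per-image lengths, then cumulative split points (all boundaries except the last)
lengths = cu_seqlens[1:] - cu_seqlens[:-1]  # [4, 6, 2]
split_indices = []
cumsum = 0
for length in lengths[:-1]:
    cumsum += int(length)
    split_indices.append(cumsum)
print(split_indices)  # [4, 10] -> where mx.split cuts the packed sequence axis
```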
@@ -284,7 +277,7 @@ class VisionModel(nn.Module):
 
         self.patch_embed = VisionPatchEmbed(config)
         self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
-        self.num_grid_per_side = int(config.num_position_embeddings …
+        self.num_grid_per_side = int(config.num_position_embeddings**0.5)
 
         head_dim = config.hidden_size // config.num_heads
         self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
@@ -310,7 +303,7 @@ class VisionModel(nn.Module):
             num_frames = int(grid_thw[i, 0].item())
             height = int(grid_thw[i, 1].item())
             width = int(grid_thw[i, 2].item())
-
+
             merged_h, merged_w = height // merge_size, width // merge_size
 
             block_rows = mx.arange(merged_h)  # block row indices
@@ -322,8 +315,12 @@ class VisionModel(nn.Module):
             row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
             col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
 
-            row_idx = mx.broadcast_to( …
-            …
+            row_idx = mx.broadcast_to(
+                row_idx, (merged_h, merged_w, merge_size, merge_size)
+            ).reshape(-1)
+            col_idx = mx.broadcast_to(
+                col_idx, (merged_h, merged_w, merge_size, merge_size)
+            ).reshape(-1)
 
             coords = mx.stack([row_idx, col_idx], axis=-1)
 
@@ -334,19 +331,19 @@ class VisionModel(nn.Module):
 
         # Concatenate all coordinate parts
         pos_ids = mx.concatenate(pos_ids_parts, axis=0)
-
+
         embeddings = freq_table[pos_ids]  # lookup rotary embeddings
         embeddings = embeddings.reshape(embeddings.shape[0], -1)
         return embeddings
 
     def fast_pos_embed_interpolate(self, grid_thw: mx.array):
         patch_pos_embeds = []
-
+
         for i in range(grid_thw.shape[0]):
             t = int(grid_thw[i, 0].item())
             h = int(grid_thw[i, 1].item())
             w = int(grid_thw[i, 2].item())
-
+
             # Simple position embedding interpolation
             h_idxs = mx.linspace(0, self.num_grid_per_side - 1, h)
             w_idxs = mx.linspace(0, self.num_grid_per_side - 1, w)
@@ -383,37 +380,41 @@ class VisionModel(nn.Module):
 
             # Repeat for temporal dimension and apply spatial merging
             pos_embed = mx.tile(pos_embed, (t, 1))
-
+
             # Apply spatial merging pattern
             merge_size = self.config.spatial_merge_size
-            pos_embed = pos_embed.reshape( …
+            pos_embed = pos_embed.reshape(
+                t, h // merge_size, merge_size, w // merge_size, merge_size, -1
+            )
             pos_embed = mx.transpose(pos_embed, (0, 1, 3, 2, 4, 5))
             pos_embed = pos_embed.reshape(-1, pos_embed.shape[-1])
-
+
             patch_pos_embeds.append(pos_embed)
-
+
         return mx.concatenate(patch_pos_embeds, axis=0)
 
-    def __call__( …
+    def __call__(
+        self, hidden_states: mx.array, grid_thw: mx.array
+    ) -> Tuple[mx.array, List[mx.array]]:
         hidden_states = self.patch_embed(hidden_states)
-
+
         pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
         hidden_states = hidden_states + pos_embeds
 
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
         seq_len = hidden_states.shape[0]
-
+
         emb = mx.concatenate([rotary_pos_emb, rotary_pos_emb], axis=-1)
         position_embeddings = (mx.cos(emb), mx.sin(emb))
 
-
+        # Create cumulative sequence lengths (following HuggingFace implementation)
         # torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
         seq_lens_per_image = grid_thw[:, 1] * grid_thw[:, 2]  # h * w for each image
         seq_lens = []
         for i, (seq_len, repeats) in enumerate(zip(seq_lens_per_image, grid_thw[:, 0])):
             seq_lens.extend([seq_len] * int(repeats))
         seq_lens = mx.array(seq_lens)
-
+
         # Then compute cumulative sum
         cu_seqlens = mx.cumsum(seq_lens)
         # Pad with 0 at the beginning
@@ -441,7 +442,7 @@ class TextRotaryEmbedding(nn.Module):
         self.config = config
         self.max_seq_len_cached = config.max_position_embeddings
         self.original_max_seq_len = config.max_position_embeddings
-
+
         # MRoPE configuration
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
             self.rope_type = config.rope_scaling.get("rope_type", "default")
@@ -449,17 +450,19 @@ class TextRotaryEmbedding(nn.Module):
         else:
             self.rope_type = "default"
             self.mrope_section = [24, 20, 20]
-
+
         # Store parameters for computing inv_freq on the fly
         self.head_dim = config.head_dim
         self.theta = config.rope_theta
-
+
         # Attention scaling (simplified - may need adjustment based on actual config)
         self.attention_scaling = 1.0
 
     def _get_inv_freq(self):
         """Compute inverse frequencies on the fly"""
-        inv_freq = 1.0 / ( …
+        inv_freq = 1.0 / (
+            self.theta ** (mx.arange(0, self.head_dim, 2).astype(mx.float32) / self.head_dim)
+        )
         # Expand for 3 dimensions (T, H, W)
         return mx.broadcast_to(inv_freq[None, :], (3, len(inv_freq)))
 
@@ -485,36 +488,38 @@ class TextRotaryEmbedding(nn.Module):
         Args:
             x: Input tensor for dtype reference
             position_ids: Position indices, shape (3, batch_size, seq_len) for MRoPE
-
+
         Returns:
             cos, sin: Cosine and sine embeddings
         """
         # Handle 2D position_ids by expanding to 3D for MRoPE
         if position_ids.ndim == 2:
-            position_ids = mx.broadcast_to( …
-            …
+            position_ids = mx.broadcast_to(
+                position_ids[None, ...], (3, position_ids.shape[0], position_ids.shape[1])
+            )
+
         batch_size, seq_len = position_ids.shape[1], position_ids.shape[2]
-
+
         # Expand inverse frequencies: (3, 1, 1, dim//2) -> (3, batch_size, 1, dim//2)
         inv_freq_expanded = mx.broadcast_to(
-            self._get_inv_freq()[:, None, None, :],
-            (3, batch_size, 1, self._get_inv_freq().shape[-1])
+            self._get_inv_freq()[:, None, None, :],
+            (3, batch_size, 1, self._get_inv_freq().shape[-1]),
         )
-
+
         # Expand position ids: (3, batch_size, seq_len) -> (3, batch_size, seq_len, 1)
         position_ids_expanded = position_ids[..., None].astype(mx.float32)
-
+
         # Compute frequencies: (3, batch_size, seq_len, dim//2)
         freqs = inv_freq_expanded * position_ids_expanded
-
+
         # Apply interleaved MRoPE
         freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
-
+
         # Create embeddings
         emb = mx.concatenate([freqs, freqs], axis=-1)  # (batch_size, seq_len, head_dim)
         cos = mx.cos(emb) * self.attention_scaling
         sin = mx.sin(emb) * self.attention_scaling
-
+
         return cos.astype(x.dtype), sin.astype(x.dtype)
 
 
@@ -523,12 +528,12 @@ class TextAttention(nn.Module):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
-
+
         dim = config.hidden_size
         self.n_heads = config.num_attention_heads
         self.n_kv_heads = config.num_key_value_heads
         self.head_dim = config.head_dim
-        self.scale = self.head_dim …
+        self.scale = self.head_dim**-0.5
 
         self.q_proj = nn.Linear(dim, self.n_heads * self.head_dim, bias=config.attention_bias)
         self.k_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=config.attention_bias)
@@ -537,7 +542,7 @@ class TextAttention(nn.Module):
 
         self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
         self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
-
+
         # Initialize rope directly
         self.rope = initialize_rope(
             config.head_dim,
@@ -573,8 +578,23 @@ class TextAttention(nn.Module):
             keys, values = cache.update_and_fetch(keys, values)
         else:
             if cache is not None:
-                …
-                …
+                # Handle different types of rope_deltas: scalar, array, or None
+                if rope_deltas is None:
+                    offset_delta = 0
+                elif isinstance(rope_deltas, (int, float)):
+                    # rope_deltas is a scalar
+                    offset_delta = rope_deltas
+                elif hasattr(rope_deltas, 'size') and rope_deltas.size == 1:
+                    # rope_deltas is an array with single element
+                    offset_delta = rope_deltas.item()
+                elif hasattr(rope_deltas, 'shape') and rope_deltas.shape:
+                    # rope_deltas is an array with multiple elements, take first
+                    offset_delta = rope_deltas.reshape(-1)[0].item()
+                else:
+                    offset_delta = 0
+
+                queries = self.rope(queries, offset=cache.offset + offset_delta)
+                keys = self.rope(keys, offset=cache.offset + offset_delta)
                 keys, values = cache.update_and_fetch(keys, values)
             else:
                 queries = self.rope(queries)
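The new branch above folds rope_deltas into the RoPE offset when a KV cache is present, so decode-time positions continue from the multimodal position ids rather than from the raw token count. A rough sketch of the arithmetic with made-up numbers (the variable names below are illustrative only):

```python
# Hypothetical prompt: text plus image placeholders occupy 130 cached tokens,
# but get_rope_index (rewritten later in this diff) assigned a largest MRoPE
# position id of only 39 to that prompt.
prompt_len = 130
max_position_id = 39

# rope_delta convention: largest position id + 1 - number of prompt tokens
rope_delta = max_position_id + 1 - prompt_len  # -90

# During cached decoding, cache.offset counts stored tokens; the RoPE position
# of the next generated token is shifted by rope_delta.
cache_offset = prompt_len
next_rope_position = cache_offset + rope_delta
print(next_rope_position)  # 40, not 130
```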
@@ -618,7 +638,7 @@ class TextDecoderLayer(nn.Module):
     ) -> mx.array:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
-
+
         hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
@@ -640,11 +660,10 @@ class TextModel(nn.Module):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size
-
+
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
         self.layers = [
-            TextDecoderLayer(config, layer_idx)
-            for layer_idx in range(config.num_hidden_layers)
+            TextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
         ]
         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.rotary_emb = TextRotaryEmbedding(config)
@@ -701,7 +720,9 @@ class TextModel(nn.Module):
                 rope_deltas=rope_deltas,
             )
             if deepstack_visual_embeds is not None and layer_idx < len(deepstack_visual_embeds):
-                hidden_states = self._deepstack_process( …
+                hidden_states = self._deepstack_process(
+                    hidden_states, visual_pos_masks, deepstack_visual_embeds[layer_idx]
+                )
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
@@ -712,17 +733,17 @@ class VEGModel(nn.Module):
         super().__init__()
         self.config = vision_config
         self.visual = VisionModel(vision_config)
-
+
     def __call__(self, pixel_values: mx.array, image_grid_thw: mx.array):
         return self.visual(pixel_values, image_grid_thw)
-
+
     def sanitize(self, weights):
         sanitized = {}
         for k, v in weights.items():
-            if …
+            if "visual." in k:
                 # Remove prefixes to match our model structure
-                clean_key = k.replace( …
-                sanitized[f …
+                clean_key = k.replace("model.visual.", "").replace("visual.", "")
+                sanitized[f"visual.{clean_key}"] = v
         return sanitized
 
 
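VEGModel.sanitize above strips checkpoint prefixes so vision weights load under the local `visual.` namespace. A tiny illustration with hypothetical key names (the keys are invented for the example):

```python
# Hypothetical checkpoint keys, invented for illustration
weights = {
    "model.visual.patch_embed.proj.weight": "w0",
    "visual.blocks.0.attn.qkv.weight": "w1",
    "model.language_model.embed_tokens.weight": "w2",  # skipped: no "visual."
}

sanitized = {}
for k, v in weights.items():
    if "visual." in k:
        clean_key = k.replace("model.visual.", "").replace("visual.", "")
        sanitized[f"visual.{clean_key}"] = v

print(sanitized)
# {'visual.patch_embed.proj.weight': 'w0', 'visual.blocks.0.attn.qkv.weight': 'w1'}
```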
@@ -735,140 +756,164 @@ class LLMModel(nn.Module):
         self.language_model = TextModel(text_config)
         if not text_config.tie_word_embeddings:
             self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)
-
+
     def get_rope_index(
-        … (old lines 740-851, the previous get_rope_index body, are not rendered in the source diff)
+        self,
+        input_ids: Optional[mx.array] = None,
+        image_grid_thw: Optional[mx.array] = None,
+        attention_mask: Optional[mx.array] = None,
+    ) -> Tuple[mx.array, mx.array]:
+        """Simplified version for images only (no video support)."""
+
+        spatial_merge_size = 2
+        image_token_id = 151655
+        vision_start_token_id = 151652
+        mrope_position_deltas = []
+
+        if input_ids is not None and image_grid_thw is not None:
+            total_input_ids = input_ids
+            if attention_mask is None:
+                attention_mask = mx.ones_like(total_input_ids)
+
+            batch_size, seq_len = input_ids.shape
+            position_ids_list = []
+            image_index = 0
+
+            for i in range(batch_size):
+                input_ids_seq = total_input_ids[i]
+                mask_seq = attention_mask[i]
+
+                # Use mask to get valid length
+                valid_length = int(mx.sum(mask_seq).item())
+                input_ids_seq = input_ids_seq[:valid_length]
+
+                image_nums = 0
+                # Find vision start tokens by iterating through the sequence
+                vision_start_positions = []
+                for pos in range(input_ids_seq.shape[0]):
+                    if input_ids_seq[pos].item() == vision_start_token_id:
+                        vision_start_positions.append(pos)
+
+                if len(vision_start_positions) > 0:
+                    for pos in vision_start_positions:
+                        if pos + 1 < input_ids_seq.shape[0]:
+                            if input_ids_seq[pos + 1].item() == image_token_id:
+                                image_nums += 1
+
+                input_tokens = input_ids_seq.tolist()
+                llm_pos_ids_list = []
+                st = 0
+                remain_images = image_nums
+
+                for _ in range(image_nums):
+                    ed_image = input_tokens.index(image_token_id, st)
+
+                    t = image_grid_thw[image_index, 0].item()
+                    h = image_grid_thw[image_index, 1].item()
+                    w = image_grid_thw[image_index, 2].item()
+                    image_index += 1
+                    remain_images -= 1
+                    ed = ed_image
+
+                    llm_grid_t = int(t)
+                    llm_grid_h = int(h) // spatial_merge_size
+                    llm_grid_w = int(w) // spatial_merge_size
+                    text_len = ed - st
+
+                    st_idx = (
+                        llm_pos_ids_list[-1].max().item() + 1 if len(llm_pos_ids_list) > 0 else 0
+                    )
+                    text_pos = mx.arange(text_len).reshape(1, -1)
+                    text_pos = mx.broadcast_to(text_pos, (3, text_len)) + st_idx
+                    llm_pos_ids_list.append(text_pos)
+
+                    # t_index is always 0 because llm_grid_t is always 1 for images
+                    t_index = mx.arange(llm_grid_t).reshape(-1, 1)
+                    t_index = mx.broadcast_to(
+                        t_index, (llm_grid_t, llm_grid_h * llm_grid_w)
+                    ).reshape(-1)
+
+                    h_index = mx.arange(llm_grid_h).reshape(1, -1, 1)
+                    h_index = mx.broadcast_to(
+                        h_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+                    ).reshape(-1)
+
+                    w_index = mx.arange(llm_grid_w).reshape(1, 1, -1)
+                    w_index = mx.broadcast_to(
+                        w_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+                    ).reshape(-1)
+
+                    vision_pos = mx.stack([t_index, h_index, w_index]) + text_len + st_idx
+                    llm_pos_ids_list.append(vision_pos)
+                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+                if st < len(input_tokens):
+                    st_idx = (
+                        llm_pos_ids_list[-1].max().item() + 1 if len(llm_pos_ids_list) > 0 else 0
+                    )
+                    text_len = len(input_tokens) - st
+                    text_pos = mx.arange(text_len).reshape(1, -1)
+                    text_pos = mx.broadcast_to(text_pos, (3, text_len)) + st_idx
+                    llm_pos_ids_list.append(text_pos)
+
+                llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
+
+                # Create position_ids for this batch item, pad to seq_len
+                batch_position_ids = mx.ones((3, seq_len), dtype=input_ids.dtype)
+                valid_length = min(seq_len, llm_positions.shape[1])
+
+                # Create new arrays for each dimension
+                pos_dim0 = mx.concatenate(
+                    [
+                        llm_positions[0, :valid_length],
+                        mx.ones(seq_len - valid_length, dtype=input_ids.dtype),
+                    ]
+                )
+                pos_dim1 = mx.concatenate(
+                    [
+                        llm_positions[1, :valid_length],
+                        mx.ones(seq_len - valid_length, dtype=input_ids.dtype),
+                    ]
+                )
+                pos_dim2 = mx.concatenate(
+                    [
+                        llm_positions[2, :valid_length],
+                        mx.ones(seq_len - valid_length, dtype=input_ids.dtype),
+                    ]
+                )
+
+                batch_position_ids = mx.stack([pos_dim0, pos_dim1, pos_dim2])
+                position_ids_list.append(batch_position_ids)
+
+                mrope_position_deltas.append(
+                    llm_positions.max().item() + 1 - len(total_input_ids[i])
+                )
+
+            # Stack all batch position_ids
+            position_ids = mx.stack(position_ids_list, axis=1)  # Shape: (3, batch_size, seq_len)
+            mrope_position_deltas = mx.array(mrope_position_deltas).reshape(-1, 1)
+            return position_ids, mrope_position_deltas
+        else:
+            if attention_mask is not None:
+                position_ids = mx.cumsum(attention_mask.astype(mx.int32), axis=-1) - 1
+                position_ids = mx.where(attention_mask == 0, 1, position_ids)
+                position_ids = mx.expand_dims(position_ids, axis=0)
+                position_ids = mx.broadcast_to(
+                    position_ids, (3, position_ids.shape[1], position_ids.shape[2])
+                )
+                max_position_ids = mx.max(
+                    mx.max(position_ids, axis=0, keepdims=False), axis=-1, keepdims=True
+                )
+                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
             else:
-                … (old lines 853-860 are not rendered in the source diff)
-                mrope_position_deltas = mx.reshape(mrope_position_deltas, (-1,))
-        else:
-            seq_len = input_ids.shape[1]
-            batch_size = input_ids.shape[0]
-            position_ids = mx.arange(seq_len).reshape(1, 1, -1)
-            position_ids = mx.broadcast_to(position_ids, (3, batch_size, seq_len))
-            # 1D zeros for rope deltas
-            mrope_position_deltas = mx.zeros((batch_size,), dtype=input_ids.dtype)
-
-        return position_ids, mrope_position_deltas
-
+                seq_len = input_ids.shape[1]
+                batch_size = input_ids.shape[0]
+                position_ids = mx.arange(seq_len).reshape(1, 1, -1)
+                position_ids = mx.broadcast_to(position_ids, (3, batch_size, seq_len))
+                mrope_position_deltas = mx.zeros((batch_size, 1), dtype=input_ids.dtype)
+
+            return position_ids, mrope_position_deltas
+
     def __call__(
         self,
         inputs: mx.array = None,
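To make the rewritten get_rope_index above concrete, here is a hand-computed example of the three-axis (temporal, height, width) position ids it assigns. The sequence layout and grid are made up; only the indexing scheme mirrors the code:

```python
# Made-up layout: 3 text tokens, one image of grid_thw (1, 4, 4) that becomes
# 2 x 2 = 4 merged image tokens (spatial_merge_size 2), then 2 text tokens.
text_len, llm_grid_t, llm_grid_h, llm_grid_w = 3, 1, 2, 2

t_pos, h_pos, w_pos = [], [], []

# Leading text: all three axes advance together.
for p in range(text_len):
    t_pos.append(p)
    h_pos.append(p)
    w_pos.append(p)

# Image block: offset by the preceding text length; t stays fixed while h/w
# enumerate the merged grid.
for hh in range(llm_grid_h):
    for ww in range(llm_grid_w):
        t_pos.append(text_len)
        h_pos.append(text_len + hh)
        w_pos.append(text_len + ww)

# Trailing text: resume after the largest position used so far.
nxt = max(max(t_pos), max(h_pos), max(w_pos)) + 1
for p in range(2):
    t_pos.append(nxt + p)
    h_pos.append(nxt + p)
    w_pos.append(nxt + p)

print(t_pos)  # [0, 1, 2, 3, 3, 3, 3, 5, 6]
print(h_pos)  # [0, 1, 2, 3, 3, 4, 4, 5, 6]
print(w_pos)  # [0, 1, 2, 3, 4, 3, 4, 5, 6]
# rope delta for this prompt: max position (6) + 1 - 9 tokens = -2
```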
@@ -896,35 +941,41 @@ class LLMModel(nn.Module):
             return self.language_model.embed_tokens.as_linear(out)
         else:
             return self.lm_head(out)
-
+
     def sanitize(self, weights):
         sanitized = {}
         for k, v in weights.items():
-            if not ( …
+            if not ("visual." in k):
                 # Handle key mapping from combined model to LLM-only model
                 clean_key = k
-
+
                 # Remove model. prefix if present
-                if clean_key.startswith( …
+                if clean_key.startswith("model."):
                     clean_key = clean_key[6:]  # Remove 'model.'
-
+
                 # Map language_ prefixed keys to language_model structure
-                if clean_key.startswith( …
-                    if clean_key.startswith( …
-                        clean_key = …
-                        …
-                        …
-                    elif clean_key.startswith( …
-                        clean_key = …
-                        …
+                if clean_key.startswith("language_"):
+                    if clean_key.startswith("language_layers."):
+                        clean_key = (
+                            "language_model.layers." + clean_key[16:]
+                        )  # Map to language_model.layers.
+                    elif clean_key.startswith("language_embed_tokens."):
+                        clean_key = (
+                            "language_model.embed_tokens." + clean_key[22:]
+                        )  # Map to language_model.embed_tokens.
+                    elif clean_key.startswith("language_norm."):
+                        clean_key = (
+                            "language_model.norm." + clean_key[14:]
+                        )  # Map to language_model.norm.
+
                 sanitized[clean_key] = v
-
+
         # Handle tied embeddings - remove lm_head if using tied embeddings
         if self.args.tie_word_embeddings:
             sanitized.pop("lm_head.weight", None)
-
+
         return sanitized
-
+
     @property
     def layers(self):
         return self.language_model.layers
@@ -938,39 +989,36 @@ class Qwen3VLModel(nn.Module):
         self.config = args
         self.visual = VisionModel(args.vision_config)
         self.language_model = TextModel(args.text_config)
-
+
     def sanitize(self, weights):
         # Map weights to match the combined model structure
         sanitized = {}
         for k, v in weights.items():
             # Remove 'model.' prefix if present to match our structure
-            clean_key = k.replace( …
+            clean_key = k.replace("model.", "") if k.startswith("model.") else k
             sanitized[clean_key] = v
         return sanitized
 
-    def get_image_features(
-        self,
-        pixel_values: mx.array,
-        image_grid_thw: Optional[mx.array] = None
-    ):
+    def get_image_features(self, pixel_values: mx.array, image_grid_thw: Optional[mx.array] = None):
         image_embeds, deepstack_visual_embeds = self.visual(pixel_values, image_grid_thw)
         # Split based on grid dimensions
         if image_grid_thw is not None:
-            split_sizes = ( …
+            split_sizes = (
+                mx.prod(image_grid_thw, axis=-1) // (self.visual.spatial_merge_size**2)
+            ).tolist()
             # Convert sizes to indices for mx.split (cumulative sum, excluding the last)
             split_indices = []
             cumsum = 0
             for size in split_sizes[:-1]:  # Exclude last element
                 cumsum += size
                 split_indices.append(cumsum)
-
+
             if split_indices:  # Only split if we have indices
                 image_embeds = mx.split(image_embeds, split_indices)
             else:
                 image_embeds = [image_embeds]  # Single image case
         return image_embeds, deepstack_visual_embeds
 
-
     def __call__(
         self,
         input_ids: mx.array = None,
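The split_sizes expression in get_image_features above converts each image's patch grid into its count of merged visual tokens. With hypothetical grid values, the arithmetic looks like this:

```python
import mlx.core as mx

# Hypothetical (t, h, w) patch grids for two images
image_grid_thw = mx.array([[1, 32, 24], [1, 16, 16]])
spatial_merge_size = 2

# t*h*w patches per image, divided by merge_size**2 after 2x2 spatial merging
split_sizes = (mx.prod(image_grid_thw, axis=-1) // (spatial_merge_size**2)).tolist()
print(split_sizes)  # [192, 64] merged visual tokens per image
```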
@@ -989,26 +1037,25 @@ class Qwen3VLModel(nn.Module):
         inputs_embeds = self.language_model.embed_tokens(input_ids)
 
         # Process images
-
+
         if pixel_values is not None:
             image_embeds, deepstack_visual_embeds = self.get_image_features(
                 pixel_values, image_grid_thw
             )
-
+
             # Create masks and embed visual features
             if isinstance(image_embeds, list):
                 image_embeds = mx.concatenate(image_embeds, axis=0)
-
+
             # Find image token positions and replace with visual embeddings
-            image_mask = …
+            image_mask = input_ids == self.args.image_token_id
             visual_pos_masks = image_mask
-
+
             # Replace image tokens with visual embeddings
             inputs_embeds = inputs_embeds.at[image_mask].set(
                 image_embeds.astype(inputs_embeds.dtype)
             )
 
-
         outputs = self.language_model(
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
@@ -1026,28 +1073,28 @@ class Qwen3VLModel(nn.Module):
 def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, image_grid_thw):
     """
     Handle the processing of multimodal embeddings including image features and position encoding.
-
+
     This function processes vision and text inputs to create unified embeddings that can be fed
     into the language model. It handles:
     - Vision feature extraction from pixel values
     - Deepstack visual embedding collection
    - Image token replacement in text embeddings
     - Position encoding setup for MRoPE (Multi-dimensional RoPE)
-
+
     Args:
         vision_model: The vision encoder model (VEGModel instance)
-        llm_model: The language model (LLMModel instance)
+        llm_model: The language model (LLMModel instance)
         input_ids: Tokenized text input with image token placeholders [batch_size, seq_len]
         pixel_values: Preprocessed image pixel data [num_patches, feature_dim]
         image_grid_thw: Grid dimensions for each image [num_images, 3] (time, height, width)
-
+
     Returns:
         tuple: (inputs_embeds, deepstack_visual_embeds, visual_pos_masks, cos, sin, rope_deltas)
         - inputs_embeds: Combined text and image embeddings [batch_size, seq_len, hidden_size]
        - deepstack_visual_embeds: Multi-layer visual features for deepstack processing
        - visual_pos_masks: Boolean mask indicating image token positions
        - cos: Cosine values for rotary position encoding
-        - sin: Sine values for rotary position encoding
+        - sin: Sine values for rotary position encoding
        - rope_deltas: Position offset deltas for rope computation
     """
     inputs_embeds = llm_model.language_model.embed_tokens(input_ids.squeeze(0))
@@ -1056,74 +1103,80 @@ def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, image_grid_thw):
     cos = None
     sin = None
     rope_deltas = 0
-
+
     if pixel_values is not None:
         if pixel_values.ndim == 4:
             pixel_values = mx.expand_dims(pixel_values, axis=2)
-
+
         # Process each image individually to prevent feature mixing
         image_embeds_list = []
         all_deepstack_embeds = []
-
+
         # Calculate cumulative indices for each image
         cumulative_patches = 0
-
+
         for i in range(image_grid_thw.shape[0]):
             # Calculate number of patches for current image
             current_patches = int(image_grid_thw[i, 1] * image_grid_thw[i, 2])
             start_idx = cumulative_patches
             end_idx = cumulative_patches + current_patches
             cumulative_patches += current_patches
-
+
             single_pixel_values = pixel_values[start_idx:end_idx]
-            single_grid_thw = image_grid_thw[i:i+1]
-
+            single_grid_thw = image_grid_thw[i : i + 1]
+
             # Use vision model directly
             single_embeds, single_deepstack = vision_model(single_pixel_values, single_grid_thw)
-
+
             # Split based on grid dimensions
             if single_grid_thw is not None:
-                split_sizes = ( …
+                split_sizes = (
+                    mx.prod(single_grid_thw, axis=-1) // (vision_model.visual.spatial_merge_size**2)
+                ).tolist()
                 split_indices = []
                 cumsum = 0
                 for size in split_sizes[:-1]:
                     cumsum += size
                     split_indices.append(cumsum)
-
+
                 if split_indices:
                     single_embeds = mx.split(single_embeds, split_indices)
                 else:
                     single_embeds = [single_embeds]
-
+
             image_embeds_list.extend(single_embeds)
-
+
             # Collect deepstack embeddings
             if i == 0:
                 all_deepstack_embeds = single_deepstack
             else:
                 # Concatenate deepstack embeddings from different images
                 for j in range(len(all_deepstack_embeds)):
-                    all_deepstack_embeds[j] = mx.concatenate( …
-                    …
+                    all_deepstack_embeds[j] = mx.concatenate(
+                        [all_deepstack_embeds[j], single_deepstack[j]], axis=0
+                    )
+
         deepstack_visual_embeds = all_deepstack_embeds
-
+
         # Concatenate all image embeddings for processing
         image_embeds = mx.concatenate(image_embeds_list, axis=0)
-
+
         # Find all image token positions
         image_token_id = 151655  # Default image token ID
-        image_mask = …
+        image_mask = input_ids.squeeze(0) == image_token_id
         image_mask_np = np.array(image_mask)
         image_token_positions = np.where(image_mask_np)[0]
-
+
         # Verify we have the correct number of image tokens
         expected_total_tokens = sum(embed.shape[0] for embed in image_embeds_list)
-        assert …
-
+        assert (
+            len(image_token_positions) == expected_total_tokens
+        ), f"Expected {expected_total_tokens} image tokens, got {len(image_token_positions)}"
+
         # Replace image tokens with image embeddings
         seq_len = inputs_embeds.shape[0]
         result = inputs_embeds
-
+
         # Replace image tokens with image embeddings sequentially
         embed_idx = 0
         for img_embed in image_embeds_list:
@@ -1133,7 +1186,7 @@ def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, image_grid_thw):
                 result = mx.where(
                     mx.expand_dims(pos_mask, axis=-1),
                     mx.expand_dims(img_embed[patch_idx], axis=0).astype(inputs_embeds.dtype),
-                    result
+                    result,
                 )
             embed_idx += 1
 
@@ -1142,10 +1195,10 @@ def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, image_grid_thw):
         cos, sin = llm_model.language_model.rotary_emb(inputs_embeds, position_ids)
         if inputs_embeds.ndim == 2:
             inputs_embeds = mx.expand_dims(inputs_embeds, axis=0)
-
+
         if image_mask is not None:
             visual_pos_masks = image_mask
-
+
     return inputs_embeds, deepstack_visual_embeds, visual_pos_masks, cos, sin, rope_deltas
 
 
@@ -1156,7 +1209,9 @@ class Model(nn.Module):
         self.args = args
         self.model = Qwen3VLModel(args)
         if not args.text_config.tie_word_embeddings:
-            self.lm_head = nn.Linear( …
+            self.lm_head = nn.Linear(
+                args.text_config.hidden_size, args.text_config.vocab_size, bias=False
+            )
 
     def __call__(
         self,
@@ -1164,7 +1219,7 @@ class Model(nn.Module):
         mask: mx.array = None,
         cache=None,
         inputs_embeds: Optional[mx.array] = None,
-        pixel_values: Optional[mx.array] = None,
+        pixel_values: Optional[mx.array] = None,
         image_grid_thw: Optional[mx.array] = None,
         visual_pos_masks: Optional[mx.array] = None,
         deepstack_visual_embeds: Optional[List[mx.array]] = None,
@@ -1195,13 +1250,13 @@ class Model(nn.Module):
         sanitized = {}
         for k, v in weights.items():
             sanitized[k] = v
-
+
         # Handle tied embeddings - remove lm_head if using tied embeddings
         if self.args.text_config.tie_word_embeddings:
             sanitized.pop("lm_head.weight", None)
-
+
         return sanitized
 
     @property
     def layers(self):
-        return self.model.language_model.layers
+        return self.model.language_model.layers