nexaai 1.0.19rc16__cp310-cp310-macosx_14_0_universal2.whl → 1.0.19rc18__cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of nexaai might be problematic.

@@ -120,28 +120,24 @@ class VisionPatchEmbed(nn.Module):
 
  kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
  self.proj = nn.Conv3d(
- self.in_channels,
- self.embed_dim,
- kernel_size=kernel_size,
- stride=kernel_size,
- bias=True
+ self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True
  )
 
  def __call__(self, hidden_states: mx.array) -> mx.array:
  target_dtype = self.proj.weight.dtype
-
+
  # Reshape to 5D: [batch, channels, temporal, height, width] (PyTorch format)
  # This matches the PyTorch ground truth exactly
  hidden_states = hidden_states.reshape(
  -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
  )
-
+
  # Convert to MLX format: [batch, temporal, height, width, channels]
  hidden_states = hidden_states.transpose(0, 2, 3, 4, 1)
-
+
  # Apply conv3d with target dtype and reshape to match PyTorch output
  hidden_states = self.proj(hidden_states.astype(target_dtype)).reshape(-1, self.embed_dim)
-
+
  return hidden_states
 
 
@@ -163,20 +159,20 @@ class VisionRotaryEmbedding(nn.Module):
  class VisionPatchMerger(nn.Module):
  def __init__(self, config: VisionConfig, use_postshuffle_norm=False):
  super().__init__()
- self.hidden_size = config.hidden_size * (config.spatial_merge_size ** 2)
+ self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
  self.use_postshuffle_norm = use_postshuffle_norm
-
+
  norm_size = self.hidden_size if use_postshuffle_norm else config.hidden_size
- self.ln_q = nn.LayerNorm(norm_size, eps=1e-6)
+ self.norm = nn.LayerNorm(norm_size, eps=1e-6)
  self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
  self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)
 
  def __call__(self, x: mx.array) -> mx.array:
  if self.use_postshuffle_norm:
- x = self.ln_q(x.reshape(-1, self.hidden_size)).reshape(-1, self.hidden_size)
+ x = self.norm(x.reshape(-1, self.hidden_size)).reshape(-1, self.hidden_size)
  else:
- x = self.ln_q(x).reshape(-1, self.hidden_size)
-
+ x = self.norm(x).reshape(-1, self.hidden_size)
+
  x = self.linear_fc2(nn.gelu(self.linear_fc1(x)))
  return x
 
@@ -187,8 +183,8 @@ class VisionAttention(nn.Module):
  self.dim = config.hidden_size
  self.num_heads = config.num_heads
  self.head_dim = self.dim // self.num_heads
- self.scaling = self.head_dim ** -0.5
-
+ self.scaling = self.head_dim**-0.5
+
  self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
  self.proj = nn.Linear(self.dim, self.dim)
 
@@ -204,51 +200,48 @@ class VisionAttention(nn.Module):
  qkv = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1)
  qkv = qkv.transpose(1, 0, 2, 3)
  query_states, key_states, value_states = qkv[0], qkv[1], qkv[2]
-
+
  cos, sin = position_embeddings
- query_states, key_states = apply_rotary_pos_emb_vision(
- query_states, key_states, cos, sin
- )
+ query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
 
  query_states = query_states.transpose(1, 0, 2)
  key_states = key_states.transpose(1, 0, 2)
  value_states = value_states.transpose(1, 0, 2)
-
+
  query_states = mx.expand_dims(query_states, axis=0)
  key_states = mx.expand_dims(key_states, axis=0)
  value_states = mx.expand_dims(value_states, axis=0)
-
+
  lengths = cu_seqlens[1:] - cu_seqlens[:-1]
-
+
  split_indices = []
  cumsum = 0
  for length in lengths[:-1]:
  cumsum += int(length)
  split_indices.append(cumsum)
-
+
  if split_indices:
  q_splits = mx.split(query_states, split_indices, axis=1)
  k_splits = mx.split(key_states, split_indices, axis=1)
  v_splits = mx.split(value_states, split_indices, axis=1)
  else:
  q_splits = [query_states]
- k_splits = [key_states]
+ k_splits = [key_states]
  v_splits = [value_states]
-
+
  attn_outputs = []
  for q, k, v in zip(q_splits, k_splits, v_splits):
  attn_out = scaled_dot_product_attention(
- q, k, v,
- scale=self.scaling, mask=None, cache=None
+ q, k, v, scale=self.scaling, mask=None, cache=None
  )
  attn_outputs.append(attn_out)
-
+
  attn_output = mx.concatenate(attn_outputs, axis=1)
-
+
  attn_output = attn_output[0].transpose(1, 0, 2)
  attn_output = attn_output.reshape(seq_length, -1)
  attn_output = self.proj(attn_output)
-
+
  return attn_output
 
 
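For reference, the per-image blocks in the attention above come straight from cu_seqlens; a minimal plain-Python sketch of that bookkeeping (illustrative only, with made-up lengths, not code from the package):

    # Toy cumulative sequence lengths for three packed images of 4, 6 and 4 patches.
    cu_seqlens = [0, 4, 10, 14]

    # Per-image lengths, as in `lengths = cu_seqlens[1:] - cu_seqlens[:-1]`.
    lengths = [b - a for a, b in zip(cu_seqlens[:-1], cu_seqlens[1:])]  # [4, 6, 4]

    # Split boundaries handed to mx.split: cumulative sums of all but the last length.
    split_indices, cumsum = [], 0
    for length in lengths[:-1]:
        cumsum += length
        split_indices.append(cumsum)

    print(split_indices)  # [4, 10] -> three attention blocks of widths 4, 6 and 4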
@@ -284,7 +277,7 @@ class VisionModel(nn.Module):
 
  self.patch_embed = VisionPatchEmbed(config)
  self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
- self.num_grid_per_side = int(config.num_position_embeddings ** 0.5)
+ self.num_grid_per_side = int(config.num_position_embeddings**0.5)
 
  head_dim = config.hidden_size // config.num_heads
  self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
@@ -310,7 +303,7 @@ class VisionModel(nn.Module):
  num_frames = int(grid_thw[i, 0].item())
  height = int(grid_thw[i, 1].item())
  width = int(grid_thw[i, 2].item())
-
+
  merged_h, merged_w = height // merge_size, width // merge_size
 
  block_rows = mx.arange(merged_h) # block row indices
@@ -322,8 +315,12 @@ class VisionModel(nn.Module):
  row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
  col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
 
- row_idx = mx.broadcast_to(row_idx, (merged_h, merged_w, merge_size, merge_size)).reshape(-1)
- col_idx = mx.broadcast_to(col_idx, (merged_h, merged_w, merge_size, merge_size)).reshape(-1)
+ row_idx = mx.broadcast_to(
+ row_idx, (merged_h, merged_w, merge_size, merge_size)
+ ).reshape(-1)
+ col_idx = mx.broadcast_to(
+ col_idx, (merged_h, merged_w, merge_size, merge_size)
+ ).reshape(-1)
 
  coords = mx.stack([row_idx, col_idx], axis=-1)
 
@@ -334,19 +331,19 @@ class VisionModel(nn.Module):
 
  # Concatenate all coordinate parts
  pos_ids = mx.concatenate(pos_ids_parts, axis=0)
-
+
  embeddings = freq_table[pos_ids] # lookup rotary embeddings
  embeddings = embeddings.reshape(embeddings.shape[0], -1)
  return embeddings
 
  def fast_pos_embed_interpolate(self, grid_thw: mx.array):
  patch_pos_embeds = []
-
+
  for i in range(grid_thw.shape[0]):
  t = int(grid_thw[i, 0].item())
  h = int(grid_thw[i, 1].item())
  w = int(grid_thw[i, 2].item())
-
+
  # Simple position embedding interpolation
  h_idxs = mx.linspace(0, self.num_grid_per_side - 1, h)
  w_idxs = mx.linspace(0, self.num_grid_per_side - 1, w)
@@ -383,37 +380,41 @@ class VisionModel(nn.Module):
 
  # Repeat for temporal dimension and apply spatial merging
  pos_embed = mx.tile(pos_embed, (t, 1))
-
+
  # Apply spatial merging pattern
  merge_size = self.config.spatial_merge_size
- pos_embed = pos_embed.reshape(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
+ pos_embed = pos_embed.reshape(
+ t, h // merge_size, merge_size, w // merge_size, merge_size, -1
+ )
  pos_embed = mx.transpose(pos_embed, (0, 1, 3, 2, 4, 5))
  pos_embed = pos_embed.reshape(-1, pos_embed.shape[-1])
-
+
  patch_pos_embeds.append(pos_embed)
-
+
  return mx.concatenate(patch_pos_embeds, axis=0)
 
- def __call__(self, hidden_states: mx.array, grid_thw: mx.array) -> Tuple[mx.array, List[mx.array]]:
+ def __call__(
+ self, hidden_states: mx.array, grid_thw: mx.array
+ ) -> Tuple[mx.array, List[mx.array]]:
  hidden_states = self.patch_embed(hidden_states)
-
+
  pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
  hidden_states = hidden_states + pos_embeds
 
  rotary_pos_emb = self.rot_pos_emb(grid_thw)
  seq_len = hidden_states.shape[0]
-
+
  emb = mx.concatenate([rotary_pos_emb, rotary_pos_emb], axis=-1)
  position_embeddings = (mx.cos(emb), mx.sin(emb))
 
- # Create cumulative sequence lengths (following HuggingFace implementation)
+ # Create cumulative sequence lengths (following HuggingFace implementation)
  # torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
  seq_lens_per_image = grid_thw[:, 1] * grid_thw[:, 2] # h * w for each image
  seq_lens = []
  for i, (seq_len, repeats) in enumerate(zip(seq_lens_per_image, grid_thw[:, 0])):
  seq_lens.extend([seq_len] * int(repeats))
  seq_lens = mx.array(seq_lens)
-
+
  # Then compute cumulative sum
  cu_seqlens = mx.cumsum(seq_lens)
  # Pad with 0 at the beginning
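A small worked example of the cumulative-sequence-length construction above, mirroring the torch.repeat_interleave comment (the grid values are made up):

    # Two images with (t, h, w) patch grids of (1, 4, 6) and (1, 2, 2).
    grid_thw = [(1, 4, 6), (1, 2, 2)]

    # h * w repeated t times per image ...
    seq_lens = []
    for t, h, w in grid_thw:
        seq_lens.extend([h * w] * t)      # [24, 4]

    # ... then a cumulative sum, padded with a leading 0.
    cu_seqlens, total = [0], 0
    for n in seq_lens:
        total += n
        cu_seqlens.append(total)          # [0, 24, 28]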
@@ -441,7 +442,7 @@ class TextRotaryEmbedding(nn.Module):
  self.config = config
  self.max_seq_len_cached = config.max_position_embeddings
  self.original_max_seq_len = config.max_position_embeddings
-
+
  # MRoPE configuration
  if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
  self.rope_type = config.rope_scaling.get("rope_type", "default")
@@ -449,17 +450,19 @@ class TextRotaryEmbedding(nn.Module):
  else:
  self.rope_type = "default"
  self.mrope_section = [24, 20, 20]
-
+
  # Store parameters for computing inv_freq on the fly
  self.head_dim = config.head_dim
  self.theta = config.rope_theta
-
+
  # Attention scaling (simplified - may need adjustment based on actual config)
  self.attention_scaling = 1.0
 
  def _get_inv_freq(self):
  """Compute inverse frequencies on the fly"""
- inv_freq = 1.0 / (self.theta ** (mx.arange(0, self.head_dim, 2).astype(mx.float32) / self.head_dim))
+ inv_freq = 1.0 / (
+ self.theta ** (mx.arange(0, self.head_dim, 2).astype(mx.float32) / self.head_dim)
+ )
  # Expand for 3 dimensions (T, H, W)
  return mx.broadcast_to(inv_freq[None, :], (3, len(inv_freq)))
 
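The reformatted expression above is the usual RoPE schedule, roughly inv_freq[j] = theta ** (-2j / head_dim), broadcast over the three MRoPE axes; a hedged numpy sketch (head_dim and theta are placeholder values, not the package's config):

    import numpy as np

    head_dim, theta = 128, 1_000_000.0                     # placeholders
    j = np.arange(0, head_dim, 2, dtype=np.float32)
    inv_freq = 1.0 / (theta ** (j / head_dim))             # shape (head_dim // 2,)

    # Expand to the three MRoPE axes (T, H, W), as in _get_inv_freq above.
    inv_freq_thw = np.broadcast_to(inv_freq[None, :], (3, inv_freq.size))
    print(inv_freq_thw.shape)                              # (3, 64)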
@@ -485,36 +488,38 @@ class TextRotaryEmbedding(nn.Module):
  Args:
  x: Input tensor for dtype reference
  position_ids: Position indices, shape (3, batch_size, seq_len) for MRoPE
-
+
  Returns:
  cos, sin: Cosine and sine embeddings
  """
  # Handle 2D position_ids by expanding to 3D for MRoPE
  if position_ids.ndim == 2:
- position_ids = mx.broadcast_to(position_ids[None, ...], (3, position_ids.shape[0], position_ids.shape[1]))
-
+ position_ids = mx.broadcast_to(
+ position_ids[None, ...], (3, position_ids.shape[0], position_ids.shape[1])
+ )
+
  batch_size, seq_len = position_ids.shape[1], position_ids.shape[2]
-
+
  # Expand inverse frequencies: (3, 1, 1, dim//2) -> (3, batch_size, 1, dim//2)
  inv_freq_expanded = mx.broadcast_to(
- self._get_inv_freq()[:, None, None, :],
- (3, batch_size, 1, self._get_inv_freq().shape[-1])
+ self._get_inv_freq()[:, None, None, :],
+ (3, batch_size, 1, self._get_inv_freq().shape[-1]),
  )
-
+
  # Expand position ids: (3, batch_size, seq_len) -> (3, batch_size, seq_len, 1)
  position_ids_expanded = position_ids[..., None].astype(mx.float32)
-
+
  # Compute frequencies: (3, batch_size, seq_len, dim//2)
  freqs = inv_freq_expanded * position_ids_expanded
-
+
  # Apply interleaved MRoPE
  freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
-
+
  # Create embeddings
  emb = mx.concatenate([freqs, freqs], axis=-1) # (batch_size, seq_len, head_dim)
  cos = mx.cos(emb) * self.attention_scaling
  sin = mx.sin(emb) * self.attention_scaling
-
+
  return cos.astype(x.dtype), sin.astype(x.dtype)
 
 
@@ -523,12 +528,12 @@ class TextAttention(nn.Module):
  super().__init__()
  self.config = config
  self.layer_idx = layer_idx
-
+
  dim = config.hidden_size
  self.n_heads = config.num_attention_heads
  self.n_kv_heads = config.num_key_value_heads
  self.head_dim = config.head_dim
- self.scale = self.head_dim ** -0.5
+ self.scale = self.head_dim**-0.5
 
  self.q_proj = nn.Linear(dim, self.n_heads * self.head_dim, bias=config.attention_bias)
  self.k_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=config.attention_bias)
@@ -537,7 +542,7 @@ class TextAttention(nn.Module):
 
  self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
  self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
-
+
  # Initialize rope directly
  self.rope = initialize_rope(
  config.head_dim,
@@ -573,8 +578,23 @@ class TextAttention(nn.Module):
  keys, values = cache.update_and_fetch(keys, values)
  else:
  if cache is not None:
- queries = self.rope(queries, offset=cache.offset+rope_deltas)
- keys = self.rope(keys, offset=cache.offset+rope_deltas)
+ # Handle different types of rope_deltas: scalar, array, or None
+ if rope_deltas is None:
+ offset_delta = 0
+ elif isinstance(rope_deltas, (int, float)):
+ # rope_deltas is a scalar
+ offset_delta = rope_deltas
+ elif hasattr(rope_deltas, 'size') and rope_deltas.size == 1:
+ # rope_deltas is an array with single element
+ offset_delta = rope_deltas.item()
+ elif hasattr(rope_deltas, 'shape') and rope_deltas.shape:
+ # rope_deltas is an array with multiple elements, take first
+ offset_delta = rope_deltas.reshape(-1)[0].item()
+ else:
+ offset_delta = 0
+
+ queries = self.rope(queries, offset=cache.offset + offset_delta)
+ keys = self.rope(keys, offset=cache.offset + offset_delta)
  keys, values = cache.update_and_fetch(keys, values)
  else:
  queries = self.rope(queries)
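The new branch above normalizes rope_deltas before adding it to the KV-cache offset; read on its own, the dispatch amounts to a small helper like the following (hypothetical name, a sketch rather than an API of the package):

    def normalize_offset_delta(rope_deltas):
        """Collapse None / scalar / array-like rope_deltas into one plain offset."""
        if rope_deltas is None:
            return 0
        if isinstance(rope_deltas, (int, float)):
            return rope_deltas                        # already a scalar
        if hasattr(rope_deltas, "size") and rope_deltas.size == 1:
            return rope_deltas.item()                 # single-element array
        if hasattr(rope_deltas, "shape") and rope_deltas.shape:
            return rope_deltas.reshape(-1)[0].item()  # multi-element array: take first
        return 0

    # usage, as in the diff: offset = cache.offset + normalize_offset_delta(rope_deltas)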
@@ -618,7 +638,7 @@ class TextDecoderLayer(nn.Module):
  ) -> mx.array:
  residual = hidden_states
  hidden_states = self.input_layernorm(hidden_states)
-
+
  hidden_states, _ = self.self_attn(
  hidden_states=hidden_states,
  attention_mask=attention_mask,
@@ -640,11 +660,10 @@ class TextModel(nn.Module):
  super().__init__()
  self.config = config
  self.vocab_size = config.vocab_size
-
+
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
  self.layers = [
- TextDecoderLayer(config, layer_idx)
- for layer_idx in range(config.num_hidden_layers)
+ TextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
  ]
  self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  self.rotary_emb = TextRotaryEmbedding(config)
@@ -701,7 +720,9 @@ class TextModel(nn.Module):
  rope_deltas=rope_deltas,
  )
  if deepstack_visual_embeds is not None and layer_idx < len(deepstack_visual_embeds):
- hidden_states = self._deepstack_process(hidden_states, visual_pos_masks, deepstack_visual_embeds[layer_idx])
+ hidden_states = self._deepstack_process(
+ hidden_states, visual_pos_masks, deepstack_visual_embeds[layer_idx]
+ )
  hidden_states = self.norm(hidden_states)
  return hidden_states
 
@@ -712,17 +733,17 @@ class VEGModel(nn.Module):
  super().__init__()
  self.config = vision_config
  self.visual = VisionModel(vision_config)
-
+
  def __call__(self, pixel_values: mx.array, image_grid_thw: mx.array):
  return self.visual(pixel_values, image_grid_thw)
-
+
  def sanitize(self, weights):
  sanitized = {}
  for k, v in weights.items():
- if 'visual.' in k:
+ if "visual." in k:
  # Remove prefixes to match our model structure
- clean_key = k.replace('model.visual.', '').replace('visual.', '')
- sanitized[f'visual.{clean_key}'] = v
+ clean_key = k.replace("model.visual.", "").replace("visual.", "")
+ sanitized[f"visual.{clean_key}"] = v
  return sanitized
 
 
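The sanitize step above only strips checkpoint prefixes so the vision weights land under visual.*; a tiny illustration with a made-up key:

    k = "model.visual.patch_embed.proj.weight"     # hypothetical checkpoint key
    clean_key = k.replace("model.visual.", "").replace("visual.", "")
    print(f"visual.{clean_key}")                   # visual.patch_embed.proj.weight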
@@ -735,140 +756,164 @@ class LLMModel(nn.Module):
  self.language_model = TextModel(text_config)
  if not text_config.tie_word_embeddings:
  self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)
-
+
  def get_rope_index(
- self,
- input_ids: Optional[mx.array] = None,
- image_grid_thw: Optional[mx.array] = None,
- attention_mask: Optional[mx.array] = None,
- ) -> Tuple[mx.array, mx.array]:
- """Simplified version for images only (no video support)."""
-
- spatial_merge_size = 2
- image_token_id = 151655
- vision_start_token_id = 151652
- mrope_position_deltas = []
-
- if input_ids is not None and image_grid_thw is not None:
- total_input_ids = input_ids
- if attention_mask is None:
- attention_mask = mx.ones_like(total_input_ids)
-
- batch_size, seq_len = input_ids.shape
- position_ids_list = []
- image_index = 0
-
- for i in range(batch_size):
- input_ids_seq = total_input_ids[i]
- mask_seq = attention_mask[i]
-
- # Use mask to get valid length
- valid_length = int(mx.sum(mask_seq).item())
- input_ids_seq = input_ids_seq[:valid_length]
-
- image_nums = 0
- # Find vision start tokens by iterating through the sequence
- vision_start_positions = []
- for pos in range(input_ids_seq.shape[0]):
- if input_ids_seq[pos].item() == vision_start_token_id:
- vision_start_positions.append(pos)
-
- if len(vision_start_positions) > 0:
- for pos in vision_start_positions:
- if pos + 1 < input_ids_seq.shape[0]:
- if input_ids_seq[pos + 1].item() == image_token_id:
- image_nums += 1
-
- input_tokens = input_ids_seq.tolist()
- llm_pos_ids_list = []
- st = 0
- remain_images = image_nums
-
- for _ in range(image_nums):
- ed_image = input_tokens.index(image_token_id, st)
-
- t = image_grid_thw[image_index, 0].item()
- h = image_grid_thw[image_index, 1].item()
- w = image_grid_thw[image_index, 2].item()
- image_index += 1
- remain_images -= 1
- ed = ed_image
-
- llm_grid_t = int(t)
- llm_grid_h = int(h) // spatial_merge_size
- llm_grid_w = int(w) // spatial_merge_size
- text_len = ed - st
-
- st_idx = llm_pos_ids_list[-1].max().item() + 1 if len(llm_pos_ids_list) > 0 else 0
- text_pos = mx.arange(text_len).reshape(1, -1)
- text_pos = mx.broadcast_to(text_pos, (3, text_len)) + st_idx
- llm_pos_ids_list.append(text_pos)
-
- # t_index is always 0 because llm_grid_t is always 1 for images
- t_index = mx.arange(llm_grid_t).reshape(-1, 1)
- t_index = mx.broadcast_to(t_index, (llm_grid_t, llm_grid_h * llm_grid_w)).reshape(-1)
-
- h_index = mx.arange(llm_grid_h).reshape(1, -1, 1)
- h_index = mx.broadcast_to(h_index, (llm_grid_t, llm_grid_h, llm_grid_w)).reshape(-1)
-
- w_index = mx.arange(llm_grid_w).reshape(1, 1, -1)
- w_index = mx.broadcast_to(w_index, (llm_grid_t, llm_grid_h, llm_grid_w)).reshape(-1)
-
- vision_pos = mx.stack([t_index, h_index, w_index]) + text_len + st_idx
- llm_pos_ids_list.append(vision_pos)
- st = ed + llm_grid_t * llm_grid_h * llm_grid_w
-
- if st < len(input_tokens):
- st_idx = llm_pos_ids_list[-1].max().item() + 1 if len(llm_pos_ids_list) > 0 else 0
- text_len = len(input_tokens) - st
- text_pos = mx.arange(text_len).reshape(1, -1)
- text_pos = mx.broadcast_to(text_pos, (3, text_len)) + st_idx
- llm_pos_ids_list.append(text_pos)
-
- llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
-
- # Create position_ids for this batch item, pad to seq_len
- batch_position_ids = mx.ones((3, seq_len), dtype=input_ids.dtype)
- valid_length = min(seq_len, llm_positions.shape[1])
-
- # Create new arrays for each dimension
- pos_dim0 = mx.concatenate([llm_positions[0, :valid_length],
- mx.ones(seq_len - valid_length, dtype=input_ids.dtype)])
- pos_dim1 = mx.concatenate([llm_positions[1, :valid_length],
- mx.ones(seq_len - valid_length, dtype=input_ids.dtype)])
- pos_dim2 = mx.concatenate([llm_positions[2, :valid_length],
- mx.ones(seq_len - valid_length, dtype=input_ids.dtype)])
-
- batch_position_ids = mx.stack([pos_dim0, pos_dim1, pos_dim2])
- position_ids_list.append(batch_position_ids)
-
- mrope_position_deltas.append(llm_positions.max().item() + 1 - len(total_input_ids[i]))
-
- # Stack all batch position_ids
- position_ids = mx.stack(position_ids_list, axis=1) # Shape: (3, batch_size, seq_len)
- # Ensure rope deltas are 1D: (batch,)
- mrope_position_deltas = mx.array(mrope_position_deltas).reshape(-1)
- return position_ids, mrope_position_deltas
+ self,
+ input_ids: Optional[mx.array] = None,
+ image_grid_thw: Optional[mx.array] = None,
+ attention_mask: Optional[mx.array] = None,
+ ) -> Tuple[mx.array, mx.array]:
+ """Simplified version for images only (no video support)."""
+
+ spatial_merge_size = 2
+ image_token_id = 151655
+ vision_start_token_id = 151652
+ mrope_position_deltas = []
+
+ if input_ids is not None and image_grid_thw is not None:
+ total_input_ids = input_ids
+ if attention_mask is None:
+ attention_mask = mx.ones_like(total_input_ids)
+
+ batch_size, seq_len = input_ids.shape
+ position_ids_list = []
+ image_index = 0
+
+ for i in range(batch_size):
+ input_ids_seq = total_input_ids[i]
+ mask_seq = attention_mask[i]
+
+ # Use mask to get valid length
+ valid_length = int(mx.sum(mask_seq).item())
+ input_ids_seq = input_ids_seq[:valid_length]
+
+ image_nums = 0
+ # Find vision start tokens by iterating through the sequence
+ vision_start_positions = []
+ for pos in range(input_ids_seq.shape[0]):
+ if input_ids_seq[pos].item() == vision_start_token_id:
+ vision_start_positions.append(pos)
+
+ if len(vision_start_positions) > 0:
+ for pos in vision_start_positions:
+ if pos + 1 < input_ids_seq.shape[0]:
+ if input_ids_seq[pos + 1].item() == image_token_id:
+ image_nums += 1
+
+ input_tokens = input_ids_seq.tolist()
+ llm_pos_ids_list = []
+ st = 0
+ remain_images = image_nums
+
+ for _ in range(image_nums):
+ ed_image = input_tokens.index(image_token_id, st)
+
+ t = image_grid_thw[image_index, 0].item()
+ h = image_grid_thw[image_index, 1].item()
+ w = image_grid_thw[image_index, 2].item()
+ image_index += 1
+ remain_images -= 1
+ ed = ed_image
+
+ llm_grid_t = int(t)
+ llm_grid_h = int(h) // spatial_merge_size
+ llm_grid_w = int(w) // spatial_merge_size
+ text_len = ed - st
+
+ st_idx = (
+ llm_pos_ids_list[-1].max().item() + 1 if len(llm_pos_ids_list) > 0 else 0
+ )
+ text_pos = mx.arange(text_len).reshape(1, -1)
+ text_pos = mx.broadcast_to(text_pos, (3, text_len)) + st_idx
+ llm_pos_ids_list.append(text_pos)
+
+ # t_index is always 0 because llm_grid_t is always 1 for images
+ t_index = mx.arange(llm_grid_t).reshape(-1, 1)
+ t_index = mx.broadcast_to(
+ t_index, (llm_grid_t, llm_grid_h * llm_grid_w)
+ ).reshape(-1)
+
+ h_index = mx.arange(llm_grid_h).reshape(1, -1, 1)
+ h_index = mx.broadcast_to(
+ h_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+ ).reshape(-1)
+
+ w_index = mx.arange(llm_grid_w).reshape(1, 1, -1)
+ w_index = mx.broadcast_to(
+ w_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+ ).reshape(-1)
+
+ vision_pos = mx.stack([t_index, h_index, w_index]) + text_len + st_idx
+ llm_pos_ids_list.append(vision_pos)
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+ if st < len(input_tokens):
+ st_idx = (
+ llm_pos_ids_list[-1].max().item() + 1 if len(llm_pos_ids_list) > 0 else 0
+ )
+ text_len = len(input_tokens) - st
+ text_pos = mx.arange(text_len).reshape(1, -1)
+ text_pos = mx.broadcast_to(text_pos, (3, text_len)) + st_idx
+ llm_pos_ids_list.append(text_pos)
+
+ llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
+
+ # Create position_ids for this batch item, pad to seq_len
+ batch_position_ids = mx.ones((3, seq_len), dtype=input_ids.dtype)
+ valid_length = min(seq_len, llm_positions.shape[1])
+
+ # Create new arrays for each dimension
+ pos_dim0 = mx.concatenate(
+ [
+ llm_positions[0, :valid_length],
+ mx.ones(seq_len - valid_length, dtype=input_ids.dtype),
+ ]
+ )
+ pos_dim1 = mx.concatenate(
+ [
+ llm_positions[1, :valid_length],
+ mx.ones(seq_len - valid_length, dtype=input_ids.dtype),
+ ]
+ )
+ pos_dim2 = mx.concatenate(
+ [
+ llm_positions[2, :valid_length],
+ mx.ones(seq_len - valid_length, dtype=input_ids.dtype),
+ ]
+ )
+
+ batch_position_ids = mx.stack([pos_dim0, pos_dim1, pos_dim2])
+ position_ids_list.append(batch_position_ids)
+
+ mrope_position_deltas.append(
+ llm_positions.max().item() + 1 - len(total_input_ids[i])
+ )
+
+ # Stack all batch position_ids
+ position_ids = mx.stack(position_ids_list, axis=1) # Shape: (3, batch_size, seq_len)
+ mrope_position_deltas = mx.array(mrope_position_deltas).reshape(-1, 1)
+ return position_ids, mrope_position_deltas
+ else:
+ if attention_mask is not None:
+ position_ids = mx.cumsum(attention_mask.astype(mx.int32), axis=-1) - 1
+ position_ids = mx.where(attention_mask == 0, 1, position_ids)
+ position_ids = mx.expand_dims(position_ids, axis=0)
+ position_ids = mx.broadcast_to(
+ position_ids, (3, position_ids.shape[1], position_ids.shape[2])
+ )
+ max_position_ids = mx.max(
+ mx.max(position_ids, axis=0, keepdims=False), axis=-1, keepdims=True
+ )
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
  else:
- if attention_mask is not None:
- position_ids = mx.cumsum(attention_mask.astype(mx.int32), axis=-1) - 1
- position_ids = mx.where(attention_mask == 0, 1, position_ids)
- position_ids = mx.expand_dims(position_ids, axis=0)
- position_ids = mx.broadcast_to(position_ids, (3, position_ids.shape[1], position_ids.shape[2]))
- # Compute max position per batch, ensure 1D shape (batch,)
- max_position_ids = mx.max(mx.max(position_ids, axis=0, keepdims=False), axis=-1, keepdims=False)
- mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
- mrope_position_deltas = mx.reshape(mrope_position_deltas, (-1,))
- else:
- seq_len = input_ids.shape[1]
- batch_size = input_ids.shape[0]
- position_ids = mx.arange(seq_len).reshape(1, 1, -1)
- position_ids = mx.broadcast_to(position_ids, (3, batch_size, seq_len))
- # 1D zeros for rope deltas
- mrope_position_deltas = mx.zeros((batch_size,), dtype=input_ids.dtype)
-
- return position_ids, mrope_position_deltas
-
+ seq_len = input_ids.shape[1]
+ batch_size = input_ids.shape[0]
+ position_ids = mx.arange(seq_len).reshape(1, 1, -1)
+ position_ids = mx.broadcast_to(position_ids, (3, batch_size, seq_len))
+ mrope_position_deltas = mx.zeros((batch_size, 1), dtype=input_ids.dtype)
+
+ return position_ids, mrope_position_deltas
+
  def __call__(
  self,
  inputs: mx.array = None,
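As a rough trace of the MRoPE layout that get_rope_index above produces, here is a toy sequence of 3 text tokens, one image whose merged grid is 2 x 2, then 2 trailing text tokens (hand-computed, not actual model output):

    text_len, grid_t, grid_h, grid_w = 3, 1, 2, 2

    # Leading text: all three axes share ordinary positions 0..text_len-1.
    t_row = h_row = w_row = list(range(text_len))            # [0, 1, 2]

    # Image block: grid coordinates offset by the text length.
    t_img = [0 + text_len] * (grid_h * grid_w)               # [3, 3, 3, 3]
    h_img = [h + text_len for h in (0, 0, 1, 1)]             # [3, 3, 4, 4]
    w_img = [w + text_len for w in (0, 1, 0, 1)]             # [3, 4, 3, 4]

    # Trailing text resumes one past the largest position used so far.
    start = max(t_img + h_img + w_img) + 1                   # 5
    tail = [start, start + 1]                                # [5, 6]

    # mrope delta = (max position + 1) - sequence length = 7 - 9 = -2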
@@ -896,35 +941,41 @@ class LLMModel(nn.Module):
  return self.language_model.embed_tokens.as_linear(out)
  else:
  return self.lm_head(out)
-
+
  def sanitize(self, weights):
  sanitized = {}
  for k, v in weights.items():
- if not ('visual.' in k):
+ if not ("visual." in k):
  # Handle key mapping from combined model to LLM-only model
  clean_key = k
-
+
  # Remove model. prefix if present
- if clean_key.startswith('model.'):
+ if clean_key.startswith("model."):
  clean_key = clean_key[6:] # Remove 'model.'
-
+
  # Map language_ prefixed keys to language_model structure
- if clean_key.startswith('language_'):
- if clean_key.startswith('language_layers.'):
- clean_key = 'language_model.layers.' + clean_key[16:] # Map to language_model.layers.
- elif clean_key.startswith('language_embed_tokens.'):
- clean_key = 'language_model.embed_tokens.' + clean_key[22:] # Map to language_model.embed_tokens.
- elif clean_key.startswith('language_norm.'):
- clean_key = 'language_model.norm.' + clean_key[14:] # Map to language_model.norm.
-
+ if clean_key.startswith("language_"):
+ if clean_key.startswith("language_layers."):
+ clean_key = (
+ "language_model.layers." + clean_key[16:]
+ ) # Map to language_model.layers.
+ elif clean_key.startswith("language_embed_tokens."):
+ clean_key = (
+ "language_model.embed_tokens." + clean_key[22:]
+ ) # Map to language_model.embed_tokens.
+ elif clean_key.startswith("language_norm."):
+ clean_key = (
+ "language_model.norm." + clean_key[14:]
+ ) # Map to language_model.norm.
+
  sanitized[clean_key] = v
-
+
  # Handle tied embeddings - remove lm_head if using tied embeddings
  if self.args.tie_word_embeddings:
  sanitized.pop("lm_head.weight", None)
-
+
  return sanitized
-
+
  @property
  def layers(self):
  return self.language_model.layers
@@ -938,39 +989,36 @@ class Qwen3VLModel(nn.Module):
  self.config = args
  self.visual = VisionModel(args.vision_config)
  self.language_model = TextModel(args.text_config)
-
+
  def sanitize(self, weights):
  # Map weights to match the combined model structure
  sanitized = {}
  for k, v in weights.items():
  # Remove 'model.' prefix if present to match our structure
- clean_key = k.replace('model.', '') if k.startswith('model.') else k
+ clean_key = k.replace("model.", "") if k.startswith("model.") else k
  sanitized[clean_key] = v
  return sanitized
 
- def get_image_features(
- self,
- pixel_values: mx.array,
- image_grid_thw: Optional[mx.array] = None
- ):
+ def get_image_features(self, pixel_values: mx.array, image_grid_thw: Optional[mx.array] = None):
  image_embeds, deepstack_visual_embeds = self.visual(pixel_values, image_grid_thw)
  # Split based on grid dimensions
  if image_grid_thw is not None:
- split_sizes = (mx.prod(image_grid_thw, axis=-1) // (self.visual.spatial_merge_size ** 2)).tolist()
+ split_sizes = (
+ mx.prod(image_grid_thw, axis=-1) // (self.visual.spatial_merge_size**2)
+ ).tolist()
  # Convert sizes to indices for mx.split (cumulative sum, excluding the last)
  split_indices = []
  cumsum = 0
  for size in split_sizes[:-1]: # Exclude last element
  cumsum += size
  split_indices.append(cumsum)
-
+
  if split_indices: # Only split if we have indices
  image_embeds = mx.split(image_embeds, split_indices)
  else:
  image_embeds = [image_embeds] # Single image case
  return image_embeds, deepstack_visual_embeds
 
-
  def __call__(
  self,
  input_ids: mx.array = None,
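A quick check of the per-image token counts used for the split above (example grid values only):

    spatial_merge_size = 2
    image_grid_thw = [(1, 28, 28), (1, 14, 20)]   # made-up (t, h, w) patch grids

    split_sizes = [
        (t * h * w) // (spatial_merge_size**2) for t, h, w in image_grid_thw
    ]
    print(split_sizes)   # [196, 70] merged tokens per image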
@@ -989,26 +1037,25 @@ class Qwen3VLModel(nn.Module):
  inputs_embeds = self.language_model.embed_tokens(input_ids)
 
  # Process images
-
+
  if pixel_values is not None:
  image_embeds, deepstack_visual_embeds = self.get_image_features(
  pixel_values, image_grid_thw
  )
-
+
  # Create masks and embed visual features
  if isinstance(image_embeds, list):
  image_embeds = mx.concatenate(image_embeds, axis=0)
-
+
  # Find image token positions and replace with visual embeddings
- image_mask = (input_ids == self.args.image_token_id)
+ image_mask = input_ids == self.args.image_token_id
  visual_pos_masks = image_mask
-
+
  # Replace image tokens with visual embeddings
  inputs_embeds = inputs_embeds.at[image_mask].set(
  image_embeds.astype(inputs_embeds.dtype)
  )
 
-
  outputs = self.language_model(
  inputs_embeds=inputs_embeds,
  attention_mask=attention_mask,
@@ -1026,28 +1073,28 @@ class Qwen3VLModel(nn.Module):
  def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, image_grid_thw):
  """
  Handle the processing of multimodal embeddings including image features and position encoding.
-
+
  This function processes vision and text inputs to create unified embeddings that can be fed
  into the language model. It handles:
  - Vision feature extraction from pixel values
  - Deepstack visual embedding collection
  - Image token replacement in text embeddings
  - Position encoding setup for MRoPE (Multi-dimensional RoPE)
-
+
  Args:
  vision_model: The vision encoder model (VEGModel instance)
- llm_model: The language model (LLMModel instance)
+ llm_model: The language model (LLMModel instance)
  input_ids: Tokenized text input with image token placeholders [batch_size, seq_len]
  pixel_values: Preprocessed image pixel data [num_patches, feature_dim]
  image_grid_thw: Grid dimensions for each image [num_images, 3] (time, height, width)
-
+
  Returns:
  tuple: (inputs_embeds, deepstack_visual_embeds, visual_pos_masks, cos, sin, rope_deltas)
  - inputs_embeds: Combined text and image embeddings [batch_size, seq_len, hidden_size]
  - deepstack_visual_embeds: Multi-layer visual features for deepstack processing
  - visual_pos_masks: Boolean mask indicating image token positions
  - cos: Cosine values for rotary position encoding
- - sin: Sine values for rotary position encoding
+ - sin: Sine values for rotary position encoding
  - rope_deltas: Position offset deltas for rope computation
  """
  inputs_embeds = llm_model.language_model.embed_tokens(input_ids.squeeze(0))
@@ -1056,74 +1103,80 @@ def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, i
  cos = None
  sin = None
  rope_deltas = 0
-
+
  if pixel_values is not None:
  if pixel_values.ndim == 4:
  pixel_values = mx.expand_dims(pixel_values, axis=2)
-
+
  # Process each image individually to prevent feature mixing
  image_embeds_list = []
  all_deepstack_embeds = []
-
+
  # Calculate cumulative indices for each image
  cumulative_patches = 0
-
+
  for i in range(image_grid_thw.shape[0]):
  # Calculate number of patches for current image
  current_patches = int(image_grid_thw[i, 1] * image_grid_thw[i, 2])
  start_idx = cumulative_patches
  end_idx = cumulative_patches + current_patches
  cumulative_patches += current_patches
-
+
  single_pixel_values = pixel_values[start_idx:end_idx]
- single_grid_thw = image_grid_thw[i:i+1]
-
+ single_grid_thw = image_grid_thw[i : i + 1]
+
  # Use vision model directly
  single_embeds, single_deepstack = vision_model(single_pixel_values, single_grid_thw)
-
+
  # Split based on grid dimensions
  if single_grid_thw is not None:
- split_sizes = (mx.prod(single_grid_thw, axis=-1) // (vision_model.visual.spatial_merge_size ** 2)).tolist()
+ split_sizes = (
+ mx.prod(single_grid_thw, axis=-1) // (vision_model.visual.spatial_merge_size**2)
+ ).tolist()
  split_indices = []
  cumsum = 0
  for size in split_sizes[:-1]:
  cumsum += size
  split_indices.append(cumsum)
-
+
  if split_indices:
  single_embeds = mx.split(single_embeds, split_indices)
  else:
  single_embeds = [single_embeds]
-
+
  image_embeds_list.extend(single_embeds)
-
+
  # Collect deepstack embeddings
  if i == 0:
  all_deepstack_embeds = single_deepstack
  else:
  # Concatenate deepstack embeddings from different images
  for j in range(len(all_deepstack_embeds)):
- all_deepstack_embeds[j] = mx.concatenate([all_deepstack_embeds[j], single_deepstack[j]], axis=0)
-
+ all_deepstack_embeds[j] = mx.concatenate(
+ [all_deepstack_embeds[j], single_deepstack[j]], axis=0
+ )
+
  deepstack_visual_embeds = all_deepstack_embeds
-
+
  # Concatenate all image embeddings for processing
  image_embeds = mx.concatenate(image_embeds_list, axis=0)
-
+
  # Find all image token positions
  image_token_id = 151655 # Default image token ID
- image_mask = (input_ids.squeeze(0) == image_token_id)
+ image_mask = input_ids.squeeze(0) == image_token_id
  image_mask_np = np.array(image_mask)
  image_token_positions = np.where(image_mask_np)[0]
-
+
  # Verify we have the correct number of image tokens
  expected_total_tokens = sum(embed.shape[0] for embed in image_embeds_list)
- assert len(image_token_positions) == expected_total_tokens, f"Expected {expected_total_tokens} image tokens, got {len(image_token_positions)}"
-
+ assert (
+ len(image_token_positions) == expected_total_tokens
+ ), f"Expected {expected_total_tokens} image tokens, got {len(image_token_positions)}"
+
  # Replace image tokens with image embeddings
  seq_len = inputs_embeds.shape[0]
  result = inputs_embeds
-
+
  # Replace image tokens with image embeddings sequentially
  embed_idx = 0
  for img_embed in image_embeds_list:
@@ -1133,7 +1186,7 @@ def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, i
  result = mx.where(
  mx.expand_dims(pos_mask, axis=-1),
  mx.expand_dims(img_embed[patch_idx], axis=0).astype(inputs_embeds.dtype),
- result
+ result,
  )
  embed_idx += 1
 
@@ -1142,10 +1195,10 @@ def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, i
  cos, sin = llm_model.language_model.rotary_emb(inputs_embeds, position_ids)
  if inputs_embeds.ndim == 2:
  inputs_embeds = mx.expand_dims(inputs_embeds, axis=0)
-
+
  if image_mask is not None:
  visual_pos_masks = image_mask
-
+
  return inputs_embeds, deepstack_visual_embeds, visual_pos_masks, cos, sin, rope_deltas
 
 
@@ -1156,7 +1209,9 @@ class Model(nn.Module):
  self.args = args
  self.model = Qwen3VLModel(args)
  if not args.text_config.tie_word_embeddings:
- self.lm_head = nn.Linear(args.text_config.hidden_size, args.text_config.vocab_size, bias=False)
+ self.lm_head = nn.Linear(
+ args.text_config.hidden_size, args.text_config.vocab_size, bias=False
+ )
 
  def __call__(
  self,
@@ -1164,7 +1219,7 @@ class Model(nn.Module):
  mask: mx.array = None,
  cache=None,
  inputs_embeds: Optional[mx.array] = None,
- pixel_values: Optional[mx.array] = None,
+ pixel_values: Optional[mx.array] = None,
  image_grid_thw: Optional[mx.array] = None,
  visual_pos_masks: Optional[mx.array] = None,
  deepstack_visual_embeds: Optional[List[mx.array]] = None,
@@ -1195,13 +1250,13 @@ class Model(nn.Module):
  sanitized = {}
  for k, v in weights.items():
  sanitized[k] = v
-
+
  # Handle tied embeddings - remove lm_head if using tied embeddings
  if self.args.text_config.tie_word_embeddings:
  sanitized.pop("lm_head.weight", None)
-
+
  return sanitized
 
  @property
  def layers(self):
- return self.model.language_model.layers
+ return self.model.language_model.layers