nexaai 1.0.19rc15__cp310-cp310-macosx_14_0_universal2.whl → 1.0.19rc17__cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of nexaai might be problematic.
@@ -120,28 +120,24 @@ class VisionPatchEmbed(nn.Module):
 
 kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
 self.proj = nn.Conv3d(
- self.in_channels,
- self.embed_dim,
- kernel_size=kernel_size,
- stride=kernel_size,
- bias=True
+ self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True
 )
 
 def __call__(self, hidden_states: mx.array) -> mx.array:
 target_dtype = self.proj.weight.dtype
-
+
 # Reshape to 5D: [batch, channels, temporal, height, width] (PyTorch format)
 # This matches the PyTorch ground truth exactly
 hidden_states = hidden_states.reshape(
 -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
 )
-
+
 # Convert to MLX format: [batch, temporal, height, width, channels]
 hidden_states = hidden_states.transpose(0, 2, 3, 4, 1)
-
+
 # Apply conv3d with target dtype and reshape to match PyTorch output
 hidden_states = self.proj(hidden_states.astype(target_dtype)).reshape(-1, self.embed_dim)
-
+
 return hidden_states
 
 
@@ -163,20 +159,20 @@ class VisionRotaryEmbedding(nn.Module):
 class VisionPatchMerger(nn.Module):
 def __init__(self, config: VisionConfig, use_postshuffle_norm=False):
 super().__init__()
- self.hidden_size = config.hidden_size * (config.spatial_merge_size ** 2)
+ self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
 self.use_postshuffle_norm = use_postshuffle_norm
-
+
 norm_size = self.hidden_size if use_postshuffle_norm else config.hidden_size
- self.ln_q = nn.LayerNorm(norm_size, eps=1e-6)
+ self.norm = nn.LayerNorm(norm_size, eps=1e-6)
 self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
 self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)
 
 def __call__(self, x: mx.array) -> mx.array:
 if self.use_postshuffle_norm:
- x = self.ln_q(x.reshape(-1, self.hidden_size)).reshape(-1, self.hidden_size)
+ x = self.norm(x.reshape(-1, self.hidden_size)).reshape(-1, self.hidden_size)
 else:
- x = self.ln_q(x).reshape(-1, self.hidden_size)
-
+ x = self.norm(x).reshape(-1, self.hidden_size)
+
 x = self.linear_fc2(nn.gelu(self.linear_fc1(x)))
 return x
 
@@ -187,8 +183,8 @@ class VisionAttention(nn.Module):
 self.dim = config.hidden_size
 self.num_heads = config.num_heads
 self.head_dim = self.dim // self.num_heads
- self.scaling = self.head_dim ** -0.5
-
+ self.scaling = self.head_dim**-0.5
+
 self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
 self.proj = nn.Linear(self.dim, self.dim)
 
@@ -204,51 +200,48 @@ class VisionAttention(nn.Module):
 qkv = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1)
 qkv = qkv.transpose(1, 0, 2, 3)
 query_states, key_states, value_states = qkv[0], qkv[1], qkv[2]
-
+
 cos, sin = position_embeddings
- query_states, key_states = apply_rotary_pos_emb_vision(
- query_states, key_states, cos, sin
- )
+ query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
 
 query_states = query_states.transpose(1, 0, 2)
 key_states = key_states.transpose(1, 0, 2)
 value_states = value_states.transpose(1, 0, 2)
-
+
 query_states = mx.expand_dims(query_states, axis=0)
 key_states = mx.expand_dims(key_states, axis=0)
 value_states = mx.expand_dims(value_states, axis=0)
-
+
 lengths = cu_seqlens[1:] - cu_seqlens[:-1]
-
+
 split_indices = []
 cumsum = 0
 for length in lengths[:-1]:
 cumsum += int(length)
 split_indices.append(cumsum)
-
+
 if split_indices:
 q_splits = mx.split(query_states, split_indices, axis=1)
 k_splits = mx.split(key_states, split_indices, axis=1)
 v_splits = mx.split(value_states, split_indices, axis=1)
 else:
 q_splits = [query_states]
- k_splits = [key_states]
+ k_splits = [key_states]
 v_splits = [value_states]
-
+
 attn_outputs = []
 for q, k, v in zip(q_splits, k_splits, v_splits):
 attn_out = scaled_dot_product_attention(
- q, k, v,
- scale=self.scaling, mask=None, cache=None
+ q, k, v, scale=self.scaling, mask=None, cache=None
 )
 attn_outputs.append(attn_out)
-
+
 attn_output = mx.concatenate(attn_outputs, axis=1)
-
+
 attn_output = attn_output[0].transpose(1, 0, 2)
 attn_output = attn_output.reshape(seq_length, -1)
 attn_output = self.proj(attn_output)
-
+
 return attn_output
 
 
@@ -284,7 +277,7 @@ class VisionModel(nn.Module):
 
 self.patch_embed = VisionPatchEmbed(config)
 self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
- self.num_grid_per_side = int(config.num_position_embeddings ** 0.5)
+ self.num_grid_per_side = int(config.num_position_embeddings**0.5)
 
 head_dim = config.hidden_size // config.num_heads
 self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
@@ -310,7 +303,7 @@ class VisionModel(nn.Module):
 num_frames = int(grid_thw[i, 0].item())
 height = int(grid_thw[i, 1].item())
 width = int(grid_thw[i, 2].item())
-
+
 merged_h, merged_w = height // merge_size, width // merge_size
 
 block_rows = mx.arange(merged_h) # block row indices
@@ -322,8 +315,12 @@ class VisionModel(nn.Module):
 row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
 col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
 
- row_idx = mx.broadcast_to(row_idx, (merged_h, merged_w, merge_size, merge_size)).reshape(-1)
- col_idx = mx.broadcast_to(col_idx, (merged_h, merged_w, merge_size, merge_size)).reshape(-1)
+ row_idx = mx.broadcast_to(
+ row_idx, (merged_h, merged_w, merge_size, merge_size)
+ ).reshape(-1)
+ col_idx = mx.broadcast_to(
+ col_idx, (merged_h, merged_w, merge_size, merge_size)
+ ).reshape(-1)
 
 coords = mx.stack([row_idx, col_idx], axis=-1)
 
@@ -334,19 +331,19 @@ class VisionModel(nn.Module):
 
 # Concatenate all coordinate parts
 pos_ids = mx.concatenate(pos_ids_parts, axis=0)
-
+
 embeddings = freq_table[pos_ids] # lookup rotary embeddings
 embeddings = embeddings.reshape(embeddings.shape[0], -1)
 return embeddings
 
 def fast_pos_embed_interpolate(self, grid_thw: mx.array):
 patch_pos_embeds = []
-
+
 for i in range(grid_thw.shape[0]):
 t = int(grid_thw[i, 0].item())
 h = int(grid_thw[i, 1].item())
 w = int(grid_thw[i, 2].item())
-
+
 # Simple position embedding interpolation
 h_idxs = mx.linspace(0, self.num_grid_per_side - 1, h)
 w_idxs = mx.linspace(0, self.num_grid_per_side - 1, w)
@@ -383,37 +380,41 @@ class VisionModel(nn.Module):
 
 # Repeat for temporal dimension and apply spatial merging
 pos_embed = mx.tile(pos_embed, (t, 1))
-
+
 # Apply spatial merging pattern
 merge_size = self.config.spatial_merge_size
- pos_embed = pos_embed.reshape(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
+ pos_embed = pos_embed.reshape(
+ t, h // merge_size, merge_size, w // merge_size, merge_size, -1
+ )
 pos_embed = mx.transpose(pos_embed, (0, 1, 3, 2, 4, 5))
 pos_embed = pos_embed.reshape(-1, pos_embed.shape[-1])
-
+
 patch_pos_embeds.append(pos_embed)
-
+
 return mx.concatenate(patch_pos_embeds, axis=0)
 
- def __call__(self, hidden_states: mx.array, grid_thw: mx.array) -> Tuple[mx.array, List[mx.array]]:
+ def __call__(
+ self, hidden_states: mx.array, grid_thw: mx.array
+ ) -> Tuple[mx.array, List[mx.array]]:
 hidden_states = self.patch_embed(hidden_states)
-
+
 pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
 hidden_states = hidden_states + pos_embeds
 
 rotary_pos_emb = self.rot_pos_emb(grid_thw)
 seq_len = hidden_states.shape[0]
-
+
 emb = mx.concatenate([rotary_pos_emb, rotary_pos_emb], axis=-1)
 position_embeddings = (mx.cos(emb), mx.sin(emb))
 
- # Create cumulative sequence lengths (following HuggingFace implementation)
+ # Create cumulative sequence lengths (following HuggingFace implementation)
 # torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
 seq_lens_per_image = grid_thw[:, 1] * grid_thw[:, 2] # h * w for each image
 seq_lens = []
 for i, (seq_len, repeats) in enumerate(zip(seq_lens_per_image, grid_thw[:, 0])):
 seq_lens.extend([seq_len] * int(repeats))
 seq_lens = mx.array(seq_lens)
-
+
 # Then compute cumulative sum
 cu_seqlens = mx.cumsum(seq_lens)
 # Pad with 0 at the beginning
@@ -441,7 +442,7 @@ class TextRotaryEmbedding(nn.Module):
 self.config = config
 self.max_seq_len_cached = config.max_position_embeddings
 self.original_max_seq_len = config.max_position_embeddings
-
+
 # MRoPE configuration
 if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
 self.rope_type = config.rope_scaling.get("rope_type", "default")
@@ -449,17 +450,19 @@ class TextRotaryEmbedding(nn.Module):
 else:
 self.rope_type = "default"
 self.mrope_section = [24, 20, 20]
-
+
 # Store parameters for computing inv_freq on the fly
 self.head_dim = config.head_dim
 self.theta = config.rope_theta
-
+
 # Attention scaling (simplified - may need adjustment based on actual config)
 self.attention_scaling = 1.0
 
 def _get_inv_freq(self):
 """Compute inverse frequencies on the fly"""
- inv_freq = 1.0 / (self.theta ** (mx.arange(0, self.head_dim, 2).astype(mx.float32) / self.head_dim))
+ inv_freq = 1.0 / (
+ self.theta ** (mx.arange(0, self.head_dim, 2).astype(mx.float32) / self.head_dim)
+ )
 # Expand for 3 dimensions (T, H, W)
 return mx.broadcast_to(inv_freq[None, :], (3, len(inv_freq)))
 
@@ -485,36 +488,38 @@ class TextRotaryEmbedding(nn.Module):
 Args:
 x: Input tensor for dtype reference
 position_ids: Position indices, shape (3, batch_size, seq_len) for MRoPE
-
+
 Returns:
 cos, sin: Cosine and sine embeddings
 """
 # Handle 2D position_ids by expanding to 3D for MRoPE
 if position_ids.ndim == 2:
- position_ids = mx.broadcast_to(position_ids[None, ...], (3, position_ids.shape[0], position_ids.shape[1]))
-
+ position_ids = mx.broadcast_to(
+ position_ids[None, ...], (3, position_ids.shape[0], position_ids.shape[1])
+ )
+
 batch_size, seq_len = position_ids.shape[1], position_ids.shape[2]
-
+
 # Expand inverse frequencies: (3, 1, 1, dim//2) -> (3, batch_size, 1, dim//2)
 inv_freq_expanded = mx.broadcast_to(
- self._get_inv_freq()[:, None, None, :],
- (3, batch_size, 1, self._get_inv_freq().shape[-1])
+ self._get_inv_freq()[:, None, None, :],
+ (3, batch_size, 1, self._get_inv_freq().shape[-1]),
 )
-
+
 # Expand position ids: (3, batch_size, seq_len) -> (3, batch_size, seq_len, 1)
 position_ids_expanded = position_ids[..., None].astype(mx.float32)
-
+
 # Compute frequencies: (3, batch_size, seq_len, dim//2)
 freqs = inv_freq_expanded * position_ids_expanded
-
+
 # Apply interleaved MRoPE
 freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
-
+
 # Create embeddings
 emb = mx.concatenate([freqs, freqs], axis=-1) # (batch_size, seq_len, head_dim)
 cos = mx.cos(emb) * self.attention_scaling
 sin = mx.sin(emb) * self.attention_scaling
-
+
 return cos.astype(x.dtype), sin.astype(x.dtype)
 
 
@@ -523,12 +528,12 @@ class TextAttention(nn.Module):
 super().__init__()
 self.config = config
 self.layer_idx = layer_idx
-
+
 dim = config.hidden_size
 self.n_heads = config.num_attention_heads
 self.n_kv_heads = config.num_key_value_heads
 self.head_dim = config.head_dim
- self.scale = self.head_dim ** -0.5
+ self.scale = self.head_dim**-0.5
 
 self.q_proj = nn.Linear(dim, self.n_heads * self.head_dim, bias=config.attention_bias)
 self.k_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=config.attention_bias)
@@ -537,7 +542,7 @@ class TextAttention(nn.Module):
 
 self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
 self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
-
+
 # Initialize rope directly
 self.rope = initialize_rope(
 config.head_dim,
@@ -573,8 +578,9 @@ class TextAttention(nn.Module):
 keys, values = cache.update_and_fetch(keys, values)
 else:
 if cache is not None:
- queries = self.rope(queries, offset=cache.offset+rope_deltas)
- keys = self.rope(keys, offset=cache.offset+rope_deltas)
+ offset_delta = rope_deltas.item() if rope_deltas is not None and rope_deltas.size == 1 else (rope_deltas.reshape(-1)[0].item() if rope_deltas is not None else 0)
+ queries = self.rope(queries, offset=cache.offset + offset_delta)
+ keys = self.rope(keys, offset=cache.offset + offset_delta)
 keys, values = cache.update_and_fetch(keys, values)
 else:
 queries = self.rope(queries)
@@ -618,7 +624,7 @@ class TextDecoderLayer(nn.Module):
 ) -> mx.array:
 residual = hidden_states
 hidden_states = self.input_layernorm(hidden_states)
-
+
 hidden_states, _ = self.self_attn(
 hidden_states=hidden_states,
 attention_mask=attention_mask,
@@ -640,11 +646,10 @@ class TextModel(nn.Module):
 super().__init__()
 self.config = config
 self.vocab_size = config.vocab_size
-
+
 self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
 self.layers = [
- TextDecoderLayer(config, layer_idx)
- for layer_idx in range(config.num_hidden_layers)
+ TextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
 ]
 self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 self.rotary_emb = TextRotaryEmbedding(config)
@@ -701,7 +706,9 @@ class TextModel(nn.Module):
 rope_deltas=rope_deltas,
 )
 if deepstack_visual_embeds is not None and layer_idx < len(deepstack_visual_embeds):
- hidden_states = self._deepstack_process(hidden_states, visual_pos_masks, deepstack_visual_embeds[layer_idx])
+ hidden_states = self._deepstack_process(
+ hidden_states, visual_pos_masks, deepstack_visual_embeds[layer_idx]
+ )
 hidden_states = self.norm(hidden_states)
 return hidden_states
 
@@ -712,17 +719,17 @@ class VEGModel(nn.Module):
 super().__init__()
 self.config = vision_config
 self.visual = VisionModel(vision_config)
-
+
 def __call__(self, pixel_values: mx.array, image_grid_thw: mx.array):
 return self.visual(pixel_values, image_grid_thw)
-
+
 def sanitize(self, weights):
 sanitized = {}
 for k, v in weights.items():
- if 'visual.' in k:
+ if "visual." in k:
 # Remove prefixes to match our model structure
- clean_key = k.replace('model.visual.', '').replace('visual.', '')
- sanitized[f'visual.{clean_key}'] = v
+ clean_key = k.replace("model.visual.", "").replace("visual.", "")
+ sanitized[f"visual.{clean_key}"] = v
 return sanitized
 
 
@@ -735,140 +742,164 @@ class LLMModel(nn.Module):
 self.language_model = TextModel(text_config)
 if not text_config.tie_word_embeddings:
 self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)
-
+
 def get_rope_index(
- self,
- input_ids: Optional[mx.array] = None,
- image_grid_thw: Optional[mx.array] = None,
- attention_mask: Optional[mx.array] = None,
- ) -> Tuple[mx.array, mx.array]:
- """Simplified version for images only (no video support)."""
-
- spatial_merge_size = 2
- image_token_id = 151655
- vision_start_token_id = 151652
- mrope_position_deltas = []
-
- if input_ids is not None and image_grid_thw is not None:
- total_input_ids = input_ids
- if attention_mask is None:
- attention_mask = mx.ones_like(total_input_ids)
-
- batch_size, seq_len = input_ids.shape
- position_ids_list = []
- image_index = 0
-
- for i in range(batch_size):
- input_ids_seq = total_input_ids[i]
- mask_seq = attention_mask[i]
-
- # Use mask to get valid length
- valid_length = int(mx.sum(mask_seq).item())
- input_ids_seq = input_ids_seq[:valid_length]
-
- image_nums = 0
- # Find vision start tokens by iterating through the sequence
- vision_start_positions = []
- for pos in range(input_ids_seq.shape[0]):
- if input_ids_seq[pos].item() == vision_start_token_id:
- vision_start_positions.append(pos)
-
- if len(vision_start_positions) > 0:
- for pos in vision_start_positions:
- if pos + 1 < input_ids_seq.shape[0]:
- if input_ids_seq[pos + 1].item() == image_token_id:
- image_nums += 1
-
- input_tokens = input_ids_seq.tolist()
- llm_pos_ids_list = []
- st = 0
- remain_images = image_nums
-
- for _ in range(image_nums):
- ed_image = input_tokens.index(image_token_id, st)
-
- t = image_grid_thw[image_index, 0].item()
- h = image_grid_thw[image_index, 1].item()
- w = image_grid_thw[image_index, 2].item()
- image_index += 1
- remain_images -= 1
- ed = ed_image
-
- llm_grid_t = int(t)
- llm_grid_h = int(h) // spatial_merge_size
- llm_grid_w = int(w) // spatial_merge_size
- text_len = ed - st
-
- st_idx = llm_pos_ids_list[-1].max().item() + 1 if len(llm_pos_ids_list) > 0 else 0
- text_pos = mx.arange(text_len).reshape(1, -1)
- text_pos = mx.broadcast_to(text_pos, (3, text_len)) + st_idx
- llm_pos_ids_list.append(text_pos)
-
- # t_index is always 0 because llm_grid_t is always 1 for images
- t_index = mx.arange(llm_grid_t).reshape(-1, 1)
- t_index = mx.broadcast_to(t_index, (llm_grid_t, llm_grid_h * llm_grid_w)).reshape(-1)
-
- h_index = mx.arange(llm_grid_h).reshape(1, -1, 1)
- h_index = mx.broadcast_to(h_index, (llm_grid_t, llm_grid_h, llm_grid_w)).reshape(-1)
-
- w_index = mx.arange(llm_grid_w).reshape(1, 1, -1)
- w_index = mx.broadcast_to(w_index, (llm_grid_t, llm_grid_h, llm_grid_w)).reshape(-1)
-
- vision_pos = mx.stack([t_index, h_index, w_index]) + text_len + st_idx
- llm_pos_ids_list.append(vision_pos)
- st = ed + llm_grid_t * llm_grid_h * llm_grid_w
-
- if st < len(input_tokens):
- st_idx = llm_pos_ids_list[-1].max().item() + 1 if len(llm_pos_ids_list) > 0 else 0
- text_len = len(input_tokens) - st
- text_pos = mx.arange(text_len).reshape(1, -1)
- text_pos = mx.broadcast_to(text_pos, (3, text_len)) + st_idx
- llm_pos_ids_list.append(text_pos)
-
- llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
-
- # Create position_ids for this batch item, pad to seq_len
- batch_position_ids = mx.ones((3, seq_len), dtype=input_ids.dtype)
- valid_length = min(seq_len, llm_positions.shape[1])
-
- # Create new arrays for each dimension
- pos_dim0 = mx.concatenate([llm_positions[0, :valid_length],
- mx.ones(seq_len - valid_length, dtype=input_ids.dtype)])
- pos_dim1 = mx.concatenate([llm_positions[1, :valid_length],
- mx.ones(seq_len - valid_length, dtype=input_ids.dtype)])
- pos_dim2 = mx.concatenate([llm_positions[2, :valid_length],
- mx.ones(seq_len - valid_length, dtype=input_ids.dtype)])
-
- batch_position_ids = mx.stack([pos_dim0, pos_dim1, pos_dim2])
- position_ids_list.append(batch_position_ids)
-
- mrope_position_deltas.append(llm_positions.max().item() + 1 - len(total_input_ids[i]))
-
- # Stack all batch position_ids
- position_ids = mx.stack(position_ids_list, axis=1) # Shape: (3, batch_size, seq_len)
- # Ensure rope deltas are 1D: (batch,)
- mrope_position_deltas = mx.array(mrope_position_deltas).reshape(-1)
- return position_ids, mrope_position_deltas
+ self,
+ input_ids: Optional[mx.array] = None,
+ image_grid_thw: Optional[mx.array] = None,
+ attention_mask: Optional[mx.array] = None,
+ ) -> Tuple[mx.array, mx.array]:
+ """Simplified version for images only (no video support)."""
+
+ spatial_merge_size = 2
+ image_token_id = 151655
+ vision_start_token_id = 151652
+ mrope_position_deltas = []
+
+ if input_ids is not None and image_grid_thw is not None:
+ total_input_ids = input_ids
+ if attention_mask is None:
+ attention_mask = mx.ones_like(total_input_ids)
+
+ batch_size, seq_len = input_ids.shape
+ position_ids_list = []
+ image_index = 0
+
+ for i in range(batch_size):
+ input_ids_seq = total_input_ids[i]
+ mask_seq = attention_mask[i]
+
+ # Use mask to get valid length
+ valid_length = int(mx.sum(mask_seq).item())
+ input_ids_seq = input_ids_seq[:valid_length]
+
+ image_nums = 0
+ # Find vision start tokens by iterating through the sequence
+ vision_start_positions = []
+ for pos in range(input_ids_seq.shape[0]):
+ if input_ids_seq[pos].item() == vision_start_token_id:
+ vision_start_positions.append(pos)
+
+ if len(vision_start_positions) > 0:
+ for pos in vision_start_positions:
+ if pos + 1 < input_ids_seq.shape[0]:
+ if input_ids_seq[pos + 1].item() == image_token_id:
+ image_nums += 1
+
+ input_tokens = input_ids_seq.tolist()
+ llm_pos_ids_list = []
+ st = 0
+ remain_images = image_nums
+
+ for _ in range(image_nums):
+ ed_image = input_tokens.index(image_token_id, st)
+
+ t = image_grid_thw[image_index, 0].item()
+ h = image_grid_thw[image_index, 1].item()
+ w = image_grid_thw[image_index, 2].item()
+ image_index += 1
+ remain_images -= 1
+ ed = ed_image
+
+ llm_grid_t = int(t)
+ llm_grid_h = int(h) // spatial_merge_size
+ llm_grid_w = int(w) // spatial_merge_size
+ text_len = ed - st
+
+ st_idx = (
+ llm_pos_ids_list[-1].max().item() + 1 if len(llm_pos_ids_list) > 0 else 0
+ )
+ text_pos = mx.arange(text_len).reshape(1, -1)
+ text_pos = mx.broadcast_to(text_pos, (3, text_len)) + st_idx
+ llm_pos_ids_list.append(text_pos)
+
+ # t_index is always 0 because llm_grid_t is always 1 for images
+ t_index = mx.arange(llm_grid_t).reshape(-1, 1)
+ t_index = mx.broadcast_to(
+ t_index, (llm_grid_t, llm_grid_h * llm_grid_w)
+ ).reshape(-1)
+
+ h_index = mx.arange(llm_grid_h).reshape(1, -1, 1)
+ h_index = mx.broadcast_to(
+ h_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+ ).reshape(-1)
+
+ w_index = mx.arange(llm_grid_w).reshape(1, 1, -1)
+ w_index = mx.broadcast_to(
+ w_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+ ).reshape(-1)
+
+ vision_pos = mx.stack([t_index, h_index, w_index]) + text_len + st_idx
+ llm_pos_ids_list.append(vision_pos)
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+ if st < len(input_tokens):
+ st_idx = (
+ llm_pos_ids_list[-1].max().item() + 1 if len(llm_pos_ids_list) > 0 else 0
+ )
+ text_len = len(input_tokens) - st
+ text_pos = mx.arange(text_len).reshape(1, -1)
+ text_pos = mx.broadcast_to(text_pos, (3, text_len)) + st_idx
+ llm_pos_ids_list.append(text_pos)
+
+ llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
+
+ # Create position_ids for this batch item, pad to seq_len
+ batch_position_ids = mx.ones((3, seq_len), dtype=input_ids.dtype)
+ valid_length = min(seq_len, llm_positions.shape[1])
+
+ # Create new arrays for each dimension
+ pos_dim0 = mx.concatenate(
+ [
+ llm_positions[0, :valid_length],
+ mx.ones(seq_len - valid_length, dtype=input_ids.dtype),
+ ]
+ )
+ pos_dim1 = mx.concatenate(
+ [
+ llm_positions[1, :valid_length],
+ mx.ones(seq_len - valid_length, dtype=input_ids.dtype),
+ ]
+ )
+ pos_dim2 = mx.concatenate(
+ [
+ llm_positions[2, :valid_length],
+ mx.ones(seq_len - valid_length, dtype=input_ids.dtype),
+ ]
+ )
+
+ batch_position_ids = mx.stack([pos_dim0, pos_dim1, pos_dim2])
+ position_ids_list.append(batch_position_ids)
+
+ mrope_position_deltas.append(
+ llm_positions.max().item() + 1 - len(total_input_ids[i])
+ )
+
+ # Stack all batch position_ids
+ position_ids = mx.stack(position_ids_list, axis=1) # Shape: (3, batch_size, seq_len)
+ mrope_position_deltas = mx.array(mrope_position_deltas).reshape(-1, 1)
+ return position_ids, mrope_position_deltas
+ else:
+ if attention_mask is not None:
+ position_ids = mx.cumsum(attention_mask.astype(mx.int32), axis=-1) - 1
+ position_ids = mx.where(attention_mask == 0, 1, position_ids)
+ position_ids = mx.expand_dims(position_ids, axis=0)
+ position_ids = mx.broadcast_to(
+ position_ids, (3, position_ids.shape[1], position_ids.shape[2])
+ )
+ max_position_ids = mx.max(
+ mx.max(position_ids, axis=0, keepdims=False), axis=-1, keepdims=True
+ )
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
 else:
- if attention_mask is not None:
- position_ids = mx.cumsum(attention_mask.astype(mx.int32), axis=-1) - 1
- position_ids = mx.where(attention_mask == 0, 1, position_ids)
- position_ids = mx.expand_dims(position_ids, axis=0)
- position_ids = mx.broadcast_to(position_ids, (3, position_ids.shape[1], position_ids.shape[2]))
- # Compute max position per batch, ensure 1D shape (batch,)
- max_position_ids = mx.max(mx.max(position_ids, axis=0, keepdims=False), axis=-1, keepdims=False)
- mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
- mrope_position_deltas = mx.reshape(mrope_position_deltas, (-1,))
- else:
- seq_len = input_ids.shape[1]
- batch_size = input_ids.shape[0]
- position_ids = mx.arange(seq_len).reshape(1, 1, -1)
- position_ids = mx.broadcast_to(position_ids, (3, batch_size, seq_len))
- # 1D zeros for rope deltas
- mrope_position_deltas = mx.zeros((batch_size,), dtype=input_ids.dtype)
-
- return position_ids, mrope_position_deltas
-
+ seq_len = input_ids.shape[1]
+ batch_size = input_ids.shape[0]
+ position_ids = mx.arange(seq_len).reshape(1, 1, -1)
+ position_ids = mx.broadcast_to(position_ids, (3, batch_size, seq_len))
+ mrope_position_deltas = mx.zeros((batch_size, 1), dtype=input_ids.dtype)
+
+ return position_ids, mrope_position_deltas
+
 def __call__(
 self,
 inputs: mx.array = None,
@@ -896,35 +927,41 @@ class LLMModel(nn.Module):
 return self.language_model.embed_tokens.as_linear(out)
 else:
 return self.lm_head(out)
-
+
 def sanitize(self, weights):
 sanitized = {}
 for k, v in weights.items():
- if not ('visual.' in k):
+ if not ("visual." in k):
 # Handle key mapping from combined model to LLM-only model
 clean_key = k
-
+
 # Remove model. prefix if present
- if clean_key.startswith('model.'):
+ if clean_key.startswith("model."):
 clean_key = clean_key[6:] # Remove 'model.'
-
+
 # Map language_ prefixed keys to language_model structure
- if clean_key.startswith('language_'):
- if clean_key.startswith('language_layers.'):
- clean_key = 'language_model.layers.' + clean_key[16:] # Map to language_model.layers.
- elif clean_key.startswith('language_embed_tokens.'):
- clean_key = 'language_model.embed_tokens.' + clean_key[22:] # Map to language_model.embed_tokens.
- elif clean_key.startswith('language_norm.'):
- clean_key = 'language_model.norm.' + clean_key[14:] # Map to language_model.norm.
-
+ if clean_key.startswith("language_"):
+ if clean_key.startswith("language_layers."):
+ clean_key = (
+ "language_model.layers." + clean_key[16:]
+ ) # Map to language_model.layers.
+ elif clean_key.startswith("language_embed_tokens."):
+ clean_key = (
+ "language_model.embed_tokens." + clean_key[22:]
+ ) # Map to language_model.embed_tokens.
+ elif clean_key.startswith("language_norm."):
+ clean_key = (
+ "language_model.norm." + clean_key[14:]
+ ) # Map to language_model.norm.
+
 sanitized[clean_key] = v
-
+
 # Handle tied embeddings - remove lm_head if using tied embeddings
 if self.args.tie_word_embeddings:
 sanitized.pop("lm_head.weight", None)
-
+
 return sanitized
-
+
 @property
 def layers(self):
 return self.language_model.layers
@@ -938,39 +975,36 @@ class Qwen3VLModel(nn.Module):
 self.config = args
 self.visual = VisionModel(args.vision_config)
 self.language_model = TextModel(args.text_config)
-
+
 def sanitize(self, weights):
 # Map weights to match the combined model structure
 sanitized = {}
 for k, v in weights.items():
 # Remove 'model.' prefix if present to match our structure
- clean_key = k.replace('model.', '') if k.startswith('model.') else k
+ clean_key = k.replace("model.", "") if k.startswith("model.") else k
 sanitized[clean_key] = v
 return sanitized
 
- def get_image_features(
- self,
- pixel_values: mx.array,
- image_grid_thw: Optional[mx.array] = None
- ):
+ def get_image_features(self, pixel_values: mx.array, image_grid_thw: Optional[mx.array] = None):
 image_embeds, deepstack_visual_embeds = self.visual(pixel_values, image_grid_thw)
 # Split based on grid dimensions
 if image_grid_thw is not None:
- split_sizes = (mx.prod(image_grid_thw, axis=-1) // (self.visual.spatial_merge_size ** 2)).tolist()
+ split_sizes = (
+ mx.prod(image_grid_thw, axis=-1) // (self.visual.spatial_merge_size**2)
+ ).tolist()
 # Convert sizes to indices for mx.split (cumulative sum, excluding the last)
 split_indices = []
 cumsum = 0
 for size in split_sizes[:-1]: # Exclude last element
 cumsum += size
 split_indices.append(cumsum)
-
+
 if split_indices: # Only split if we have indices
 image_embeds = mx.split(image_embeds, split_indices)
 else:
 image_embeds = [image_embeds] # Single image case
 return image_embeds, deepstack_visual_embeds
 
-
 def __call__(
 self,
 input_ids: mx.array = None,
@@ -989,26 +1023,25 @@ class Qwen3VLModel(nn.Module):
 inputs_embeds = self.language_model.embed_tokens(input_ids)
 
 # Process images
-
+
 if pixel_values is not None:
 image_embeds, deepstack_visual_embeds = self.get_image_features(
 pixel_values, image_grid_thw
 )
-
+
 # Create masks and embed visual features
 if isinstance(image_embeds, list):
 image_embeds = mx.concatenate(image_embeds, axis=0)
-
+
 # Find image token positions and replace with visual embeddings
- image_mask = (input_ids == self.args.image_token_id)
+ image_mask = input_ids == self.args.image_token_id
 visual_pos_masks = image_mask
-
+
 # Replace image tokens with visual embeddings
 inputs_embeds = inputs_embeds.at[image_mask].set(
 image_embeds.astype(inputs_embeds.dtype)
 )
 
-
 outputs = self.language_model(
 inputs_embeds=inputs_embeds,
 attention_mask=attention_mask,
@@ -1026,28 +1059,28 @@ class Qwen3VLModel(nn.Module):
 def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, image_grid_thw):
 """
 Handle the processing of multimodal embeddings including image features and position encoding.
-
+
 This function processes vision and text inputs to create unified embeddings that can be fed
 into the language model. It handles:
 - Vision feature extraction from pixel values
 - Deepstack visual embedding collection
 - Image token replacement in text embeddings
 - Position encoding setup for MRoPE (Multi-dimensional RoPE)
-
+
 Args:
 vision_model: The vision encoder model (VEGModel instance)
- llm_model: The language model (LLMModel instance)
+ llm_model: The language model (LLMModel instance)
 input_ids: Tokenized text input with image token placeholders [batch_size, seq_len]
 pixel_values: Preprocessed image pixel data [num_patches, feature_dim]
 image_grid_thw: Grid dimensions for each image [num_images, 3] (time, height, width)
-
+
 Returns:
 tuple: (inputs_embeds, deepstack_visual_embeds, visual_pos_masks, cos, sin, rope_deltas)
 - inputs_embeds: Combined text and image embeddings [batch_size, seq_len, hidden_size]
 - deepstack_visual_embeds: Multi-layer visual features for deepstack processing
 - visual_pos_masks: Boolean mask indicating image token positions
 - cos: Cosine values for rotary position encoding
- - sin: Sine values for rotary position encoding
+ - sin: Sine values for rotary position encoding
 - rope_deltas: Position offset deltas for rope computation
 """
 inputs_embeds = llm_model.language_model.embed_tokens(input_ids.squeeze(0))
@@ -1056,74 +1089,80 @@ def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, i
 cos = None
 sin = None
 rope_deltas = 0
-
+
 if pixel_values is not None:
 if pixel_values.ndim == 4:
 pixel_values = mx.expand_dims(pixel_values, axis=2)
-
+
 # Process each image individually to prevent feature mixing
 image_embeds_list = []
 all_deepstack_embeds = []
-
+
 # Calculate cumulative indices for each image
 cumulative_patches = 0
-
+
 for i in range(image_grid_thw.shape[0]):
 # Calculate number of patches for current image
 current_patches = int(image_grid_thw[i, 1] * image_grid_thw[i, 2])
 start_idx = cumulative_patches
 end_idx = cumulative_patches + current_patches
 cumulative_patches += current_patches
-
+
 single_pixel_values = pixel_values[start_idx:end_idx]
- single_grid_thw = image_grid_thw[i:i+1]
-
+ single_grid_thw = image_grid_thw[i : i + 1]
+
 # Use vision model directly
 single_embeds, single_deepstack = vision_model(single_pixel_values, single_grid_thw)
-
+
 # Split based on grid dimensions
 if single_grid_thw is not None:
- split_sizes = (mx.prod(single_grid_thw, axis=-1) // (vision_model.visual.spatial_merge_size ** 2)).tolist()
+ split_sizes = (
+ mx.prod(single_grid_thw, axis=-1) // (vision_model.visual.spatial_merge_size**2)
+ ).tolist()
 split_indices = []
 cumsum = 0
 for size in split_sizes[:-1]:
 cumsum += size
 split_indices.append(cumsum)
-
+
 if split_indices:
 single_embeds = mx.split(single_embeds, split_indices)
 else:
 single_embeds = [single_embeds]
-
+
 image_embeds_list.extend(single_embeds)
-
+
 # Collect deepstack embeddings
 if i == 0:
 all_deepstack_embeds = single_deepstack
 else:
 # Concatenate deepstack embeddings from different images
 for j in range(len(all_deepstack_embeds)):
- all_deepstack_embeds[j] = mx.concatenate([all_deepstack_embeds[j], single_deepstack[j]], axis=0)
-
+ all_deepstack_embeds[j] = mx.concatenate(
+ [all_deepstack_embeds[j], single_deepstack[j]], axis=0
+ )
+
 deepstack_visual_embeds = all_deepstack_embeds
-
+
 # Concatenate all image embeddings for processing
 image_embeds = mx.concatenate(image_embeds_list, axis=0)
-
+
 # Find all image token positions
 image_token_id = 151655 # Default image token ID
- image_mask = (input_ids.squeeze(0) == image_token_id)
+ image_mask = input_ids.squeeze(0) == image_token_id
 image_mask_np = np.array(image_mask)
 image_token_positions = np.where(image_mask_np)[0]
-
+
 # Verify we have the correct number of image tokens
 expected_total_tokens = sum(embed.shape[0] for embed in image_embeds_list)
- assert len(image_token_positions) == expected_total_tokens, f"Expected {expected_total_tokens} image tokens, got {len(image_token_positions)}"
-
+ assert (
+ len(image_token_positions) == expected_total_tokens
+ ), f"Expected {expected_total_tokens} image tokens, got {len(image_token_positions)}"
+
 # Replace image tokens with image embeddings
 seq_len = inputs_embeds.shape[0]
 result = inputs_embeds
-
+
 # Replace image tokens with image embeddings sequentially
 embed_idx = 0
 for img_embed in image_embeds_list:
@@ -1133,7 +1172,7 @@ def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, i
 result = mx.where(
 mx.expand_dims(pos_mask, axis=-1),
 mx.expand_dims(img_embed[patch_idx], axis=0).astype(inputs_embeds.dtype),
- result
+ result,
 )
 embed_idx += 1
 
@@ -1142,10 +1181,10 @@ def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, i
 cos, sin = llm_model.language_model.rotary_emb(inputs_embeds, position_ids)
 if inputs_embeds.ndim == 2:
 inputs_embeds = mx.expand_dims(inputs_embeds, axis=0)
-
+
 if image_mask is not None:
 visual_pos_masks = image_mask
-
+
 return inputs_embeds, deepstack_visual_embeds, visual_pos_masks, cos, sin, rope_deltas
 
 
@@ -1156,7 +1195,9 @@ class Model(nn.Module):
 self.args = args
 self.model = Qwen3VLModel(args)
 if not args.text_config.tie_word_embeddings:
- self.lm_head = nn.Linear(args.text_config.hidden_size, args.text_config.vocab_size, bias=False)
+ self.lm_head = nn.Linear(
+ args.text_config.hidden_size, args.text_config.vocab_size, bias=False
+ )
 
 def __call__(
 self,
@@ -1164,7 +1205,7 @@ class Model(nn.Module):
 mask: mx.array = None,
 cache=None,
 inputs_embeds: Optional[mx.array] = None,
- pixel_values: Optional[mx.array] = None,
+ pixel_values: Optional[mx.array] = None,
 image_grid_thw: Optional[mx.array] = None,
 visual_pos_masks: Optional[mx.array] = None,
 deepstack_visual_embeds: Optional[List[mx.array]] = None,
@@ -1195,13 +1236,13 @@ class Model(nn.Module):
 sanitized = {}
 for k, v in weights.items():
 sanitized[k] = v
-
+
 # Handle tied embeddings - remove lm_head if using tied embeddings
 if self.args.text_config.tie_word_embeddings:
 sanitized.pop("lm_head.weight", None)
-
+
 return sanitized
 
 @property
 def layers(self):
- return self.model.language_model.layers
+ return self.model.language_model.layers