fount-vlm-nell-02 0.3.11 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
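For orientation, the file tree above is the mlx_vlm package republished under a new distribution name. The snippet below is a minimal usage sketch, not taken from this wheel, assuming the build preserves the upstream mlx-vlm Python API (load, generate, apply_chat_template, load_config); the checkpoint id and image path are placeholders.

from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

# Placeholder checkpoint; any MLX-format VLM repo id or local path would do.
model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit"
model, processor = load(model_path)
config = load_config(model_path)

image = ["example.jpg"]  # placeholder image path
prompt = apply_chat_template(processor, config, "Describe this image.", num_images=len(image))
output = generate(model, processor, prompt, image, verbose=False)
print(output)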
mlx_vlm/models/qwen2_vl/qwen2_vl.py
@@ -0,0 +1,180 @@
from typing import Optional

import mlx.core as mx
import mlx.nn as nn

from ..base import InputEmbeddingsFeatures
from .config import ModelConfig
from .language import LanguageModel
from .vision import VisionModel


class Model(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.vision_tower = VisionModel(config.vision_config)
        self.language_model = LanguageModel(config.text_config, config)

    def get_input_embeddings(
        self,
        input_ids: Optional[mx.array] = None,
        pixel_values: Optional[mx.array] = None,
        **kwargs,
    ):
        image_grid_thw = kwargs.get("image_grid_thw", None)
        video_grid_thw = kwargs.get("video_grid_thw", None)
        mask = kwargs.get("mask", None)
        grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw

        if pixel_values is None:
            # Reset position state for text-only generation
            self.language_model._position_ids = None
            self.language_model._rope_deltas = None
            return InputEmbeddingsFeatures(
                inputs_embeds=self.language_model.model.embed_tokens(input_ids)
            )

        dtype = self.vision_tower.patch_embed.proj.weight.dtype
        pixel_values = pixel_values.astype(dtype)

        # Get the input embeddings from the language model
        inputs_embeds = self.language_model.model.embed_tokens(input_ids)

        # Get the output hidden states from the vision model
        hidden_states = self.vision_tower(
            pixel_values, grid_thw, output_hidden_states=False
        )

        # Insert special image tokens in the input_ids
        final_inputs_embeds = self.merge_input_ids_with_image_features(
            self.config.image_token_id,
            self.config.video_token_id,
            hidden_states,
            inputs_embeds,
            input_ids,
        )

        # Pre-calculate position_ids for chunked prefill
        if image_grid_thw is not None or video_grid_thw is not None:
            position_ids, rope_deltas = self.language_model.get_rope_index(
                input_ids, image_grid_thw, video_grid_thw, mask
            )
            self.language_model._position_ids = position_ids
            self.language_model._rope_deltas = rope_deltas

        return InputEmbeddingsFeatures(inputs_embeds=final_inputs_embeds)

    @staticmethod
    def merge_input_ids_with_image_features(
        image_token_id,
        video_token_id,
        image_features,
        inputs_embeds,
        input_ids,
    ):
        """Merge image features into input embeddings at image token positions.

        Args:
            image_features: Vision features from the vision tower [num_features, hidden_dim]
            inputs_embeds: Input embeddings [batch_size, seq_len, hidden_dim]
            input_ids: Input token IDs [batch_size, seq_len]

        Returns:
            Updated input embeddings with image features inserted
        """

        # Positions of <image> tokens in input_ids
        image_positions = input_ids == image_token_id
        if mx.sum(image_positions) == 0:
            image_positions = input_ids == video_token_id

        # Get dimensions
        batch_size, seq_len = input_ids.shape

        # Process each batch item
        batch_outputs = []
        feature_start_idx = 0

        for batch_idx in range(batch_size):
            # Get mask for this batch
            image_mask = image_positions[batch_idx]
            num_positions = mx.sum(image_mask).item()

            if num_positions > 0:
                # Extract features for this batch
                batch_features = image_features[
                    feature_start_idx : feature_start_idx + num_positions
                ]

                # Validate we have the right number of features
                if batch_features.shape[0] != num_positions:
                    raise ValueError(
                        f"Number of image token positions ({num_positions}) does not match "
                        f"number of image features ({batch_features.shape[0]}) for batch {batch_idx}"
                    )

                # Create indices for gathering
                cumsum = mx.cumsum(image_mask.astype(mx.int32))
                feature_indices = mx.where(image_mask, cumsum - 1, 0)

                # Gather features
                gathered_features = batch_features[feature_indices]

                # Combine with original embeddings
                image_mask_expanded = mx.expand_dims(image_mask, axis=-1)
                batch_output = mx.where(
                    image_mask_expanded, gathered_features, inputs_embeds[batch_idx]
                )

                feature_start_idx += num_positions
            else:
                # No image tokens in this batch item
                batch_output = inputs_embeds[batch_idx]

            batch_outputs.append(batch_output)

        # Stack all batch outputs
        return mx.stack(batch_outputs, axis=0)

    @property
    def layers(self):
        return self.language_model.model.layers

    def __call__(
        self,
        input_ids: mx.array,
        pixel_values: Optional[mx.array] = None,
        mask: Optional[mx.array] = None,
        cache=None,
        **kwargs,
    ):

        input_embeddings_features = self.get_input_embeddings(
            input_ids, pixel_values, **kwargs
        )
        kwargs = {
            "pixel_values": pixel_values,
            **kwargs,
        }
        logits = self.language_model(
            input_ids,
            input_embeddings_features.inputs_embeds,
            mask=mask,
            cache=cache,
            **kwargs,
        )
        return logits

    def sanitize(self, weights):
        def transform_key(key):
            if "vision_tower" not in key:
                key = key.replace("visual", "vision_tower")
            if "language_model" not in key:
                if "model" in key:
                    key = key.replace("model", "language_model.model")
                elif "lm_head" in key:
                    key = key.replace("lm_head", "language_model.lm_head")
            return key

        return {transform_key(k): v for k, v in weights.items()}
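To make the image-token merge above concrete, here is a minimal sketch (not part of the package) that calls the static method on toy arrays; the token ids, shapes, and values are invented, and the import path assumes the hunk above is mlx_vlm/models/qwen2_vl/qwen2_vl.py, as the +180 line count and the surrounding hunks suggest.

import mlx.core as mx

from mlx_vlm.models.qwen2_vl.qwen2_vl import Model

# Invented ids: 99 marks an image token, 100 a video token.
image_token_id, video_token_id = 99, 100
input_ids = mx.array([[7, 99, 99, 8, 9]])  # batch of 1, five tokens
inputs_embeds = mx.zeros((1, 5, 4))        # text embeddings, hidden dim 4
image_features = mx.ones((2, 4))           # one feature row per image token

merged = Model.merge_input_ids_with_image_features(
    image_token_id, video_token_id, image_features, inputs_embeds, input_ids
)
print(merged.shape)  # (1, 5, 4); rows 1 and 2 now carry the vision features instead of zeros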
mlx_vlm/models/qwen2_vl/vision.py
@@ -0,0 +1,308 @@
from typing import Optional

import mlx.core as mx
import mlx.nn as nn

from .config import VisionConfig


def check_array_shape(arr):
    shape = arr.shape

    # Check if the shape has 4 or 5 dimensions
    if len(shape) not in [4, 5]:
        return False

    B, out_channels, kH, KW, t = shape

    if t == 3:
        return True

    # Check if out_channels is the largest, and kH and KW are the same
    if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
        return True
    else:
        return False


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return mx.concatenate([-x2, x1], axis=-1)


def apply_rotary_pos_emb_vision(tensor, freqs) -> mx.array:
    orig_dtype = tensor.dtype

    cos = mx.cos(freqs)
    sin = mx.sin(freqs)

    cos = mx.expand_dims(cos, axis=1)  # Equivalent to unsqueeze(1)
    cos = mx.tile(cos, (1, 1, 2))  # Equivalent to repeat(1, 1, 2)
    cos = mx.expand_dims(cos, axis=0)  # Equivalent to [None, ...]

    sin = mx.expand_dims(sin, axis=1)  # Equivalent to unsqueeze(1)
    sin = mx.tile(sin, (1, 1, 2))  # Equivalent to repeat(1, 1, 2)
    sin = mx.expand_dims(sin, axis=0)  # Equivalent to [None, ...]

    output = (tensor * cos) + (rotate_half(tensor) * sin)
    return output.astype(orig_dtype)


class VisionRotaryEmbedding(nn.Module):
    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta

    def __call__(self, seqlen: int) -> mx.array:
        inv_freq = 1.0 / (
            self.theta ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim)
        )
        seq = mx.arange(seqlen.tolist(), dtype=inv_freq.dtype)
        freqs = mx.outer(seq, inv_freq)
        return freqs


class PatchEmbed(nn.Module):
    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        kernel_size = [temporal_patch_size, patch_size, patch_size]
        self.proj = nn.Conv3d(
            in_channels,
            embed_dim,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=False,
        )

    def __call__(self, hidden_states: mx.array) -> mx.array:
        hidden_states = hidden_states.reshape(
            -1,
            self.in_channels,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        ).moveaxis(1, 4)

        hidden_states = self.proj(hidden_states)
        hidden_states = hidden_states.reshape(-1, self.embed_dim)
        return hidden_states


class PatchMerger(nn.Module):
    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.ln_q = nn.LayerNorm(context_dim, eps=1e-6)
        self.mlp = [
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, dim),
        ]

    def __call__(self, x: mx.array) -> mx.array:
        x = self.ln_q(x).reshape(-1, self.hidden_size)
        for layer in self.mlp:
            x = layer(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def __call__(
        self, x: mx.array, cu_seqlens: mx.array, rotary_pos_emb: mx.array = None
    ) -> mx.array:
        seq_length = x.shape[0]
        qkv = (
            self.qkv(x).reshape(seq_length, 3, self.num_heads, -1).transpose(1, 0, 2, 3)
        )
        q, k, v = mx.split(qkv, 3)

        q = apply_rotary_pos_emb_vision(mx.expand_dims(q, 0), rotary_pos_emb)[0]
        k = apply_rotary_pos_emb_vision(mx.expand_dims(k, 0), rotary_pos_emb)[0]
        attention_mask = mx.zeros((seq_length, seq_length), dtype=mx.bool_)

        for i in range(1, len(cu_seqlens)):
            start = int(cu_seqlens[i - 1])
            end = int(cu_seqlens[i])
            attention_mask[start:end, start:end] = True

        q = q.transpose(0, 2, 1, 3)
        k = k.transpose(0, 2, 1, 3)
        v = v.transpose(0, 2, 1, 3)

        output = mx.fast.scaled_dot_product_attention(
            q, k, v, scale=self.scale, mask=attention_mask
        )
        output = output.transpose(0, 2, 1, 3)
        output = output.reshape(seq_length, -1)
        return self.proj(output)


class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.activation_fn = nn.GELU(approx="fast")
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)

    def __call__(self, x: mx.array) -> mx.array:
        x = self.activation_fn(self.fc1(x))
        x = self.fc2(x)
        return x


class Qwen2VLVisionBlock(nn.Module):
    def __init__(self, config: VisionConfig) -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(config.embed_dim, eps=1e-6)
        self.norm2 = nn.LayerNorm(config.embed_dim, eps=1e-6)
        mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)

        self.attn = Attention(dim=config.embed_dim, num_heads=config.num_heads)
        self.mlp = MLP(dim=config.embed_dim, hidden_dim=mlp_hidden_dim)

    def __call__(self, hidden_states, cu_seqlens, rotary_pos_emb) -> mx.array:
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
        )
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        return hidden_states


class VisionModel(nn.Module):
    def __init__(self, config: VisionConfig) -> None:
        super().__init__()
        self.config = config
        self.model_type = config.model_type
        if self.model_type != "qwen2_vl":
            raise ValueError(f"Unsupported model type: {self.model_type}")
        self.spatial_merge_size = config.spatial_merge_size

        self.patch_embed = PatchEmbed(
            patch_size=config.patch_size,
            temporal_patch_size=config.temporal_patch_size,
            in_channels=config.in_channels,
            embed_dim=config.embed_dim,
        )

        head_dim = config.embed_dim // config.num_heads
        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)

        self.blocks = [Qwen2VLVisionBlock(config) for _ in range(config.depth)]
        self.merger = PatchMerger(dim=config.hidden_size, context_dim=config.embed_dim)

    def rot_pos_emb(self, grid_thw):
        pos_ids = []

        for t, h, w in grid_thw:
            h, w = int(h), int(w)  # Ensure h and w are integers
            hpos_ids = mx.expand_dims(mx.arange(h), 1)
            hpos_ids = mx.repeat(hpos_ids, w, axis=1)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = mx.transpose(hpos_ids, (0, 2, 1, 3))
            hpos_ids = hpos_ids.flatten()

            wpos_ids = mx.expand_dims(mx.arange(w), 0)
            wpos_ids = mx.repeat(wpos_ids, h, axis=0)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = mx.transpose(wpos_ids, (0, 2, 1, 3))
            wpos_ids = wpos_ids.flatten()

            stacked_pos_ids = mx.stack([hpos_ids, wpos_ids], axis=-1)
            pos_ids.append(mx.tile(stacked_pos_ids, (t, 1)))

        pos_ids = mx.concatenate(pos_ids, axis=0)
        max_grid_size = mx.max(grid_thw[:, 1:])
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)

        rotary_pos_emb_full = rotary_pos_emb_full[pos_ids]

        return rotary_pos_emb_full.reshape(pos_ids.shape[0], -1)

    def __call__(
        self,
        hidden_states: mx.array,
        grid_thw: mx.array,
        output_hidden_states: Optional[bool] = None,
    ) -> mx.array:
        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # Assuming grid_thw has shape (batch_size, 3)
        batch_size = grid_thw.shape[0]

        # Calculate cu_seqlens for each item in the batch
        cu_seqlens = []
        for i in range(batch_size):
            seq_len = grid_thw[i, 1] * grid_thw[i, 2]
            cu_seqlens.append(mx.repeat(seq_len, grid_thw[i, 0]))

        # Concatenate the cu_seqlens for all items in the batch
        cu_seqlens = mx.concatenate(cu_seqlens)

        cu_seqlens = mx.cumsum(cu_seqlens.astype(mx.int32), axis=0)
        cu_seqlens = mx.pad(cu_seqlens, (1, 0), mode="constant", constant_values=0)

        encoder_states = (hidden_states,) if output_hidden_states else None

        for blk in self.blocks:
            hidden_states = blk(
                hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
            )
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)

        return self.merger(hidden_states)

    def sanitize(self, weights):
        sanitized_weights = {}
        for k, v in weights.items():
            if "position_ids" in k:
                # Remove unused position_ids
                continue
            elif "patch_embed.proj.weight" in k:
                # PyTorch Conv3d weight tensors have shape:
                #   [out_channels, in_channels, t, kH, kW]
                # MLX Conv3d expects the weight to be of shape:
                #   [out_channels, t, kH, kW, in_channels]
                if check_array_shape(v):
                    sanitized_weights[k] = v
                else:
                    sanitized_weights[k] = v.transpose(0, 2, 3, 4, 1)
            else:
                sanitized_weights[k] = v

        return sanitized_weights
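The cu_seqlens bookkeeping above is a prefix sum over patches per image, and the attention mask it feeds is block-diagonal, so each patch attends only within its own image. The sketch below is illustrative only (toy grid sizes, not package code) and mirrors that arithmetic.

import mlx.core as mx

# Toy grids: two images of 2x2 and 4x2 patches, one temporal frame each.
grid_thw = [(1, 2, 2), (1, 4, 2)]

# Cumulative sequence lengths mark the block boundaries: [0, 4, 12].
lens = [t * h * w for t, h, w in grid_thw]
cu_seqlens = [0]
for n in lens:
    cu_seqlens.append(cu_seqlens[-1] + n)

# Block-diagonal mask: a patch attends only to patches of the same image.
total = cu_seqlens[-1]
mask = mx.zeros((total, total), dtype=mx.bool_)
for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
    mask[start:end, start:end] = True

print(cu_seqlens)  # [0, 4, 12]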
mlx_vlm/models/qwen3_omni_moe/__init__.py
@@ -0,0 +1,29 @@
from .audio import AudioModel
from .config import (
    AudioConfig,
    Code2WavConfig,
    CodePredictorConfig,
    ModelConfig,
    TalkerConfig,
    TextConfig,
    ThinkerConfig,
    VisionConfig,
)
from .language import LanguageModel
from .qwen3_omni_moe import Model
from .vision import VisionModel

__all__ = [
    "Model",
    "ModelConfig",
    "LanguageModel",
    "VisionModel",
    "AudioModel",
    "TextConfig",
    "VisionConfig",
    "AudioConfig",
    "ThinkerConfig",
    "TalkerConfig",
    "CodePredictorConfig",
    "Code2WavConfig",
]
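The __init__ above re-exports the model and config classes under a single namespace, which is the shape a model_type-driven loader expects. The helper below is a hypothetical sketch of that resolution pattern, not an API of this package.

import importlib

def get_model_classes(model_type: str):
    # Hypothetical helper: resolve a model package by its config "model_type"
    # and return the entry points it re-exports.
    module = importlib.import_module(f"mlx_vlm.models.{model_type}")
    return module.Model, module.ModelConfig

# e.g. get_model_classes("qwen3_omni_moe") returns the Model and ModelConfig above.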