fount_vlm_nell_02-0.3.11-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
@@ -0,0 +1,297 @@
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+from transformers import AutoProcessor
+
+from ..base import InputEmbeddingsFeatures
+from ..deepseekocr.language import LanguageModel
+from ..deepseekocr.sam import SAMEncoder
+from .config import ModelConfig, SAMViTConfig
+from .processing_deepseekocr import DeepseekOCR2Processor
+from .vision import VisionModel
+
+AutoProcessor.register("deepseekocr_2", DeepseekOCR2Processor)
+
+
+class MlpProjector(nn.Module):
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+        self.config = config
+
+        if config.projector_config.projector_type == "linear":
+            self.layers = nn.Linear(
+                config.projector_config.input_dim, config.projector_config.n_embed
+            )
+        else:
+            raise ValueError(
+                f"Unknown projector type: {config.projector_config.projector_type}"
+            )
+
+    def __call__(self, x):
+        return self.layers(x)
+
+
+class Model(nn.Module):
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+        self.config = config
+        self.vision_model = VisionModel(config.vision_config)
+        sam_config = SAMViTConfig()
+        self.sam_model = SAMEncoder(
+            img_size=sam_config.image_size,
+            patch_size=sam_config.patch_size,
+            embed_dim=sam_config.width,
+            depth=sam_config.layers,
+            num_heads=sam_config.heads,
+            window_size=sam_config.window_size,
+            global_attn_indexes=sam_config.global_attn_indexes,
+            final_out_chans=896,  # OCR-2 uses 896 output channels (vs 1024 in OCR)
+        )
+        self.language_model = LanguageModel(config.text_config)
+        self.projector = MlpProjector(config)
+
+        self.tile_tag = config.tile_tag
+        self.global_view_pos = config.global_view_pos
+
+        # view_separator is loaded from model weights (mapped from view_seperator)
+        # Initialize with zeros - will be overwritten when weights are loaded
+        if self.tile_tag == "2D":
+            # <|view_separator|> - marks end of image features
+            # Note: This must be defined as an mx.array for weight loading to work
+            self.view_separator = mx.zeros((config.projector_config.n_embed,))
+        else:
+            raise ValueError(
+                f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
+            )
+
+    def get_input_embeddings(
+        self,
+        input_ids: Optional[mx.array] = None,
+        pixel_values: Optional[mx.array] = None,
+        images_spatial_crop: Optional[mx.array] = None,
+        images_seq_mask: Optional[mx.array] = None,
+        **kwargs,
+    ):
+        input_embeds = self.language_model.model.embed_tokens(input_ids)
+
+        if pixel_values is None:
+            return InputEmbeddingsFeatures(inputs_embeds=input_embeds)
+
+        # pixel_values is a list: [patches, global_images]
+        if isinstance(pixel_values, list):
+            patches, global_images = pixel_values
+        else:
+            patches = None
+            global_images = pixel_values
+
+        # Check if we have valid pixel values
+        if mx.sum(global_images).item() == 0:
+            return InputEmbeddingsFeatures(inputs_embeds=input_embeds)
+
+        # Process images through SAM -> Qwen2 -> Projector pipeline
+        batch_size = input_ids.shape[0]
+
+        for idx in range(batch_size):
+            all_features = []
+
+            # Check if we have valid patches (non-zero)
+            has_patches = patches is not None and mx.sum(patches).item() != 0
+
+            if has_patches:
+                # Get spatial crop info for this batch item
+                if (
+                    images_spatial_crop is not None
+                    and idx < images_spatial_crop.shape[0]
+                ):
+                    rows, cols = int(images_spatial_crop[idx, 0].item()), int(
+                        images_spatial_crop[idx, 1].item()
+                    )
+                    num_patches = rows * cols
+                else:
+                    num_patches = patches.shape[0]
+
+                # Process each patch through SAM -> Qwen2 -> Projector
+                # patches shape: (num_patches, C, H, W) where H=W=768
+                for patch_idx in range(num_patches):
+                    if patch_idx >= patches.shape[0]:
+                        break
+
+                    patch = patches[patch_idx : patch_idx + 1]  # (1, C, H, W)
+
+                    # Transpose to (B, H, W, C) for MLX conv2d
+                    patch_hwc = patch.transpose(0, 2, 3, 1)
+
+                    # SAM encoder: (1, 768, 768, 3) -> (1, 12, 12, 896)
+                    sam_features = self.sam_model(patch_hwc)
+
+                    # Qwen2 encoder: (1, 12, 12, 896) -> (1, 144, 896)
+                    # Uses query_768 automatically based on 144 input tokens
+                    vision_features = self.vision_model(patch_hwc, sam_features)
+
+                    # Linear projector: (1, 144, 896) -> (1, 144, 1280)
+                    vision_features = self.projector(vision_features)
+
+                    # Remove batch dimension: (144, 1280)
+                    all_features.append(vision_features[0])
+
+            # Process global view through SAM -> Qwen2 -> Projector
+            # global_images is (N, C, H, W) where H=W=1024
+            global_image = global_images[idx : idx + 1]  # (1, C, H, W)
+
+            # Transpose to (B, H, W, C) for MLX conv2d
+            global_hwc = global_image.transpose(0, 2, 3, 1)
+
+            # SAM encoder: (1, 1024, 1024, 3) -> (1, 16, 16, 896)
+            sam_features = self.sam_model(global_hwc)
+
+            # Qwen2 encoder: (1, 16, 16, 896) -> (1, 256, 896)
+            # Uses query_1024 automatically based on 256 input tokens
+            global_features = self.vision_model(global_hwc, sam_features)
+
+            # Linear projector: (1, 256, 896) -> (1, 256, 1280)
+            global_features = self.projector(global_features)
+
+            # Remove batch dimension: (256, 1280)
+            all_features.append(global_features[0])
+
+            # Add view_separator
+            all_features.append(self.view_separator[None, :])
+
+            # Concatenate all features: [local_patches..., global, view_sep]
+            # Shape: (num_patches * 144 + 256 + 1, 1280)
+            vision_features = mx.concatenate(all_features, axis=0)
+
+            # Find positions where images should be placed
+            if images_seq_mask is not None:
+                image_indices = np.where(images_seq_mask[idx])[0].tolist()
+                # Assign image features to those positions
+                if len(image_indices) > 0:
+                    num_positions = len(image_indices)
+                    if vision_features.shape[0] >= num_positions:
+                        input_embeds[idx, image_indices] = vision_features[
+                            :num_positions
+                        ]
+                    else:
+                        # If we have fewer features than expected, pad with the last features
+                        input_embeds[idx, image_indices[: vision_features.shape[0]]] = (
+                            vision_features
+                        )
+
+        return InputEmbeddingsFeatures(inputs_embeds=input_embeds)
+
+    @property
+    def layers(self):
+        return self.language_model.model.layers
+
+    def __call__(
+        self,
+        input_ids: mx.array,
+        pixel_values: Optional[mx.array] = None,
+        mask: Optional[mx.array] = None,
+        cache=None,
+        **kwargs,
+    ):
+        images_spatial_crop = kwargs.get("images_spatial_crop", None)
+        images_seq_mask = kwargs.get("images_seq_mask", None)
+
+        input_embeddings_features = self.get_input_embeddings(
+            input_ids, pixel_values, images_spatial_crop, images_seq_mask
+        )
+
+        logits = self.language_model(
+            input_ids,
+            cache=cache,
+            inputs_embeds=input_embeddings_features.inputs_embeds,
+        )
+        return logits
+
+    @staticmethod
+    def sanitize(weights):
+        def transform_key(key):
+            # Handle Qwen2 encoder weights from HuggingFace format
+            # HuggingFace: model.qwen2_model.model.model.layers.X...
+            # MLX: vision_model.qwen2_encoder.layers.X...
+            if "qwen2_model.model.model.layers" in key:
+                key = key.replace(
+                    "model.qwen2_model.model.model.layers",
+                    "vision_model.qwen2_encoder.layers",
+                )
+
+            # Handle Qwen2 encoder norm
+            if "qwen2_model.model.model.norm" in key:
+                key = key.replace(
+                    "model.qwen2_model.model.model.norm",
+                    "vision_model.qwen2_encoder.norm",
+                )
+
+            # Handle query weights (learnable queries for Qwen2 encoder)
+            # For 1024x1024 images, SAM outputs 16x16=256 features, so use query_1024
+            # query_1024: (256, 896) - used for 1024x1024 images
+            # query_768: (144, 896) - used for 768x768 images
+            if "model.qwen2_model.query_1024.weight" in key:
+                key = key.replace(
+                    "model.qwen2_model.query_1024.weight",
+                    "vision_model.qwen2_encoder.query_1024",
+                )
+            elif "model.qwen2_model.query_1024" in key:
+                key = key.replace(
+                    "model.qwen2_model.query_1024",
+                    "vision_model.qwen2_encoder.query_1024",
+                )
+            # Also handle query_768 for smaller images (not currently used but keep for future)
+            if "model.qwen2_model.query_768.weight" in key:
+                key = key.replace(
+                    "model.qwen2_model.query_768.weight",
+                    "vision_model.qwen2_encoder.query_768",
+                )
+            elif "model.qwen2_model.query_768" in key:
+                key = key.replace(
+                    "model.qwen2_model.query_768",
+                    "vision_model.qwen2_encoder.query_768",
+                )
+
+            # Language model layers
+            if (
+                "model.layers" in key
+                and "language_model" not in key
+                and "qwen2" not in key
+            ):
+                key = key.replace("model.layers", "language_model.model.layers")
+
+            if (
+                "model.embed_tokens" in key
+                and "language_model" not in key
+                and "qwen2" not in key
+            ):
+                key = key.replace(
+                    "model.embed_tokens", "language_model.model.embed_tokens"
+                )
+
+            if (
+                "model.norm" in key
+                and "language_model" not in key
+                and "qwen2" not in key
+            ):
+                key = key.replace("model.norm", "language_model.model.norm")
+
+            if "model.vision_model" in key:
+                key = key.replace("model.vision_model", "vision_model")
+
+            if "model.sam_model" in key:
+                key = key.replace("model.sam_model", "sam_model")
+
+            if "model.projector" in key:
+                key = key.replace("model.projector", "projector")
+
+            # Note: HuggingFace has typo "view_seperator" (e instead of a)
+            if "model.view_seperator" in key:
+                key = key.replace("model.view_seperator", "view_separator")
+
+            if "lm_head.weight" in key and "language_model" not in key:
+                key = key.replace("lm_head.weight", "language_model.lm_head.weight")
+
+            return key
+
+        return {transform_key(k): v for k, v in weights.items()}
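
The 297-line hunk above matches mlx_vlm/models/deepseekocr_2/deepseekocr_2.py in the file listing (the only entry with +297 lines). As a minimal sketch of how the key remapping in Model.sanitize behaves, the snippet below pushes a few HuggingFace-style checkpoint keys through it; the import path and the placeholder array shapes are assumptions for illustration only, since sanitize renames keys and never reads the values.

import mlx.core as mx

from mlx_vlm.models.deepseekocr_2.deepseekocr_2 import Model  # assumed module path

# Placeholder tensors: sanitize() only rewrites the key strings, so the
# shapes here are illustrative rather than the real checkpoint shapes.
hf_weights = {
    "model.qwen2_model.model.model.layers.0.self_attn.q_proj.weight": mx.zeros((4, 4)),
    "model.qwen2_model.query_1024": mx.zeros((256, 896)),
    "model.view_seperator": mx.zeros((1280,)),  # upstream spelling, remapped below
    "lm_head.weight": mx.zeros((4, 4)),
}

for key in Model.sanitize(hf_weights):
    print(key)

# Expected output, following transform_key above:
#   vision_model.qwen2_encoder.layers.0.self_attn.q_proj.weight
#   vision_model.qwen2_encoder.query_1024
#   view_separator
#   language_model.lm_head.weight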