fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
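The per-file line counts above can be cross-checked against the published artifact. Below is a minimal sketch, assuming the wheel has first been fetched locally (for example with "pip download fount-vlm-nell-02==0.3.11 --no-deps") under its standard wheel filename; it prints each member of the archive with its line count for comparison with the listing above.

import zipfile

wheel_path = "fount_vlm_nell_02-0.3.11-py3-none-any.whl"

with zipfile.ZipFile(wheel_path) as wheel:
    for name in wheel.namelist():
        # Count lines per member to compare against the "+N" figures above.
        line_count = len(wheel.read(name).splitlines())
        print(f"{name} +{line_count}")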
mlx_vlm/models/glm4v_moe/config.py
@@ -0,0 +1,81 @@
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional
+
+from ..base import BaseModelConfig
+
+
+@dataclass
+class TextConfig(BaseModelConfig):
+    model_type: str
+    vocab_size: int
+    hidden_size: int
+    intermediate_size: int
+    max_position_embeddings: int
+    moe_intermediate_size: int
+    norm_topk_prob: bool
+    num_attention_heads: int
+    n_group: int
+    head_dim: int
+    topk_group: int
+    n_shared_experts: int
+    n_routed_experts: int
+    routed_scaling_factor: float
+    num_experts_per_tok: int
+    first_k_dense_replace: int
+    num_hidden_layers: int
+    num_key_value_heads: int
+    rms_norm_eps: float
+    use_qk_norm: bool
+    attention_bias: bool
+    partial_rotary_factor: float
+    rope_theta: float = None
+    rope_parameters: Dict = None
+    rope_scaling: Dict = field(
+        default_factory=lambda: {"type": "default", "mrope_section": [64, 32, 32]}
+    )
+    tie_word_embeddings: bool = None
+    scoring_func: str = "sigmoid"
+    topk_method: str = "noaux_tc"
+
+
+@dataclass
+class VisionConfig(BaseModelConfig):
+    model_type: str
+    depth: int
+    hidden_size: int
+    intermediate_size: int
+    num_heads: int
+    patch_size: int
+    window_size: int = 112
+    image_size: int = 336
+    in_channels: int = 3
+    rms_norm_eps: float = 1e-05
+    attention_bias: bool = False
+    attention_dropout: float = 0.0
+    hidden_act: str = "silu"
+    initializer_range: float = 0.02
+    out_hidden_size: int = 4096
+    spatial_merge_size: int = 2
+    temporal_patch_size: int = 2
+
+
+@dataclass
+class ModelConfig(BaseModelConfig):
+    text_config: TextConfig
+    vision_config: VisionConfig
+    model_type: str
+    vocab_size: int = 257152
+    ignore_index: int = -100
+    image_token_index: int = 151363
+    image_token_id: int = 151363
+    video_token_index: int = 151364
+    video_token_id: int = 151364
+    vision_start_token_id: int = 151339
+    vision_end_token_id: int = 151340
+    hidden_size: int = 2048
+    pad_token_id: int = 0
+    eos_token_id: Optional[List[int]] = None
+
+    def __post_init__(self):
+        if self.eos_token_id is None:
+            self.eos_token_id = [151329, 151336, 151338]
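As an aside, a minimal sketch of how these config dataclasses behave (the numeric values here are illustrative only, and this assumes BaseModelConfig adds no required fields of its own): required fields must be passed explicitly, later fields fall back to the defaults above, and ModelConfig.__post_init__ fills in eos_token_id when it is omitted.

from mlx_vlm.models.glm4v_moe.config import VisionConfig

# Illustrative values; a real checkpoint's config.json supplies these.
vision = VisionConfig(
    model_type="glm4v_moe",
    depth=24,
    hidden_size=1536,
    intermediate_size=13696,
    num_heads=12,
    patch_size=14,
)
print(vision.image_size)          # 336, from the default above
print(vision.spatial_merge_size)  # 2, from the default above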
mlx_vlm/models/glm4v_moe/glm4v_moe.py
@@ -0,0 +1,176 @@
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from ..base import InputEmbeddingsFeatures
+from .config import ModelConfig
+from .language import LanguageModel
+from .processing import Glm46VMoEProcessor
+from .vision import VisionModel
+
+# Register the processor with the name expected by the model config
+try:
+    from transformers import AutoProcessor
+
+    # Register for both possible processor class names
+    AutoProcessor.register("Glm46VMoEProcessor", Glm46VMoEProcessor)
+except Exception as e:
+    print(f"Error registering glm4v_moe processor: {e}")
+
+
+class Model(nn.Module):
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+        self.config = config
+        self.vision_tower = VisionModel(config.vision_config)
+        self.language_model = LanguageModel(config.text_config, config)
+
+    def get_input_embeddings(
+        self,
+        input_ids: Optional[mx.array] = None,
+        pixel_values: Optional[mx.array] = None,
+        **kwargs,
+    ):
+
+        image_grid_thw = kwargs.pop("image_grid_thw", None)
+        video_grid_thw = kwargs.pop("video_grid_thw", None)
+        grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw
+
+        if pixel_values is None:
+            return InputEmbeddingsFeatures(
+                inputs_embeds=self.language_model.model.embed_tokens(input_ids)
+            )
+
+        dtype = self.vision_tower.patch_embed.proj.weight.dtype
+        pixel_values = pixel_values.astype(dtype)
+
+        # Get the input embeddings from the language model
+        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+        # Get the output hidden states from the vision model
+        hidden_states = self.vision_tower(
+            pixel_values, grid_thw, output_hidden_states=False
+        )
+
+        # Insert special image tokens in the input_ids
+        final_inputs_embeds = self.merge_input_ids_with_image_features(
+            self.config.image_token_id,
+            self.config.video_token_id,
+            hidden_states,
+            inputs_embeds,
+            input_ids,
+        )
+        return InputEmbeddingsFeatures(inputs_embeds=final_inputs_embeds)
+
+    @staticmethod
+    def merge_input_ids_with_image_features(
+        image_token_id,
+        video_token_id,
+        image_features,
+        inputs_embeds,
+        input_ids,
+    ):
+        """Merge image features into input embeddings at image token positions.
+
+        Args:
+            image_token_id: The token ID for image placeholders
+            video_token_id: The token ID for video placeholders (fallback)
+            image_features: Vision features from the vision tower [num_features, hidden_dim]
+            inputs_embeds: Input embeddings [batch_size, seq_len, hidden_dim]
+            input_ids: Input token IDs [batch_size, seq_len]
+            grid_thw: Grid dimensions for each image (optional, not used in simple case)
+
+        Returns:
+            Updated input embeddings with image features inserted
+        """
+        # Find positions of image tokens
+        image_positions = input_ids == image_token_id
+        if mx.sum(image_positions) == 0:
+            image_positions = input_ids == video_token_id
+
+        # Get dimensions
+        batch_size, seq_len = input_ids.shape
+
+        # Process each batch item
+        batch_outputs = []
+        feature_start_idx = 0
+
+        for batch_idx in range(batch_size):
+            # Get mask for this batch
+            image_mask = image_positions[batch_idx]
+            num_positions = mx.sum(image_mask).item()
+
+            if num_positions > 0:
+                # Extract features for this batch
+                batch_features = image_features[
+                    feature_start_idx : feature_start_idx + num_positions
+                ]
+
+                # Validate we have the right number of features
+                if batch_features.shape[0] != num_positions:
+                    raise ValueError(
+                        f"Number of image token positions ({num_positions}) does not match "
+                        f"number of image features ({batch_features.shape[0]}) for batch {batch_idx}"
+                    )
+
+                # Create indices for gathering
+                cumsum = mx.cumsum(image_mask.astype(mx.int32))
+                feature_indices = mx.where(image_mask, cumsum - 1, 0)
+
+                # Gather features
+                gathered_features = batch_features[feature_indices]
+
+                # Combine with original embeddings
+                image_mask_expanded = mx.expand_dims(image_mask, axis=-1)
+                batch_output = mx.where(
+                    image_mask_expanded, gathered_features, inputs_embeds[batch_idx]
+                )
+
+                feature_start_idx += num_positions
+            else:
+                # No image tokens in this batch item
+                batch_output = inputs_embeds[batch_idx]
+
+            batch_outputs.append(batch_output)
+
+        # Stack all batch outputs
+        return mx.stack(batch_outputs, axis=0)
+
+    @property
+    def layers(self):
+        return self.language_model.model.layers
+
+    def __call__(
+        self,
+        input_ids: mx.array,
+        pixel_values: Optional[mx.array] = None,
+        mask: Optional[mx.array] = None,
+        cache=None,
+        **kwargs,
+    ):
+        input_embeddings_features = self.get_input_embeddings(
+            input_ids, pixel_values, **kwargs
+        )
+        logits = self.language_model(
+            input_ids,
+            input_embeddings_features.inputs_embeds,
+            mask=mask,
+            cache=cache,
+            **kwargs,
+        )
+
+        return logits
+
+    def sanitize(self, weights):
+        def transform_key(key):
+            if "visual" in key:
+                if "vision_tower" not in key:
+                    key = key.replace("model.", "").replace("visual", "vision_tower")
+            if "model.language_model" in key:
+                key = key.replace("model.language_model", "language_model.model")
+            if "lm_head" in key and not key.startswith("language_model"):
+                key = key.replace("lm_head", "language_model.lm_head")
+            return key
+
+        return {transform_key(k): v for k, v in weights.items()}