fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0

mlx_vlm/models/florence2/config.py
@@ -0,0 +1,84 @@
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

from ..base import BaseModelConfig


@dataclass
class VisionConfig(BaseModelConfig):
    """Configuration class for Florence2 Vision model (DaViT)."""

    model_type: str = "davit"
    in_chans: int = 3
    num_classes: int = 1000
    depths: List[int] = field(default_factory=lambda: [1, 1, 9, 1])
    dim_embed: List[int] = field(default_factory=lambda: [128, 256, 512, 1024])
    num_heads: List[int] = field(default_factory=lambda: [4, 8, 16, 32])
    num_groups: List[int] = field(default_factory=lambda: [4, 8, 16, 32])
    window_size: int = 12
    mlp_ratio: float = 4.0
    drop_path_rate: float = 0.1
    patch_size: List[int] = field(default_factory=lambda: [7, 3, 3, 3])
    patch_stride: List[int] = field(default_factory=lambda: [4, 2, 2, 2])
    patch_padding: List[int] = field(default_factory=lambda: [3, 1, 1, 1])
    patch_prenorm: List[bool] = field(
        default_factory=lambda: [False, False, False, False]
    )
    qkv_bias: bool = True
    conv_at_attn: bool = True
    conv_at_ffn: bool = True
    hidden_size: int = 768
    image_size: Tuple[int, int] = (768, 768)


@dataclass
class TextConfig(BaseModelConfig):
    d_model: int = 768
    model_type: str = "florence2"
    encoder_attention_heads: int = 8
    decoder_attention_heads: int = 8
    encoder_ffn_dim: int = 3072
    decoder_ffn_dim: int = 3072
    dropout: float = 0.1
    attention_dropout: float = 0.0
    activation_dropout: float = 0.0
    activation_function: str = "gelu"
    init_std: float = 0.02
    encoder_layerdrop: float = 0.0
    decoder_layerdrop: float = 0.0
    scale_embedding: bool = False
    use_cache: bool = True
    max_position_embeddings: int = 1024
    vocab_size: int = 51289
    pad_token_id: int = 1
    bos_token_id: int = 0
    eos_token_id: int = 2
    decoder_start_token_id: int = 2
    encoder_layers: int = 6
    decoder_layers: int = 6


@dataclass
class ModelConfig(BaseModelConfig):
    """Configuration class for Florence2."""

    vision_config: VisionConfig
    text_config: TextConfig
    model_type: str = "florence2"
    vocab_size: int = 50265
    max_position_embeddings: int = 1024
    pad_token_id: int = 1
    bos_token_id: int = 0
    eos_token_id: int = 2
    image_token_id: int = 51289
    image_token_index: int = 51289
    image_feature_source: List[str] = field(
        default_factory=lambda: ["temporal_avg_pool", "spatial_avg_pool"]
    )
    visual_temporal_embedding: Optional[dict] = field(
        default_factory=lambda: {"type": "COSINE", "max_temporal_embeddings": 100}
    )
    image_pos_embed: Optional[dict] = field(
        default_factory=lambda: {"type": "learned_abs_2d", "max_pos_embeddings": 50}
    )
    eos_token_id: Optional[List[int]] = None
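
For orientation, a minimal sketch of how these dataclasses compose. This is illustrative only, not package code, and it assumes the wheel is installed and that BaseModelConfig (from mlx_vlm/models/base.py) adds no required fields of its own:

from mlx_vlm.models.florence2.config import ModelConfig, TextConfig, VisionConfig

# Build a Florence2 configuration purely from the defaults declared above.
config = ModelConfig(vision_config=VisionConfig(), text_config=TextConfig())

print(config.model_type)                   # "florence2"
print(config.vision_config.dim_embed[-1])  # 1024, projected onto the text d_model of 768
print(config.image_feature_source)         # ['temporal_avg_pool', 'spatial_avg_pool']

Note that ModelConfig declares eos_token_id twice; under Python dataclass rules the later annotation (Optional[List[int]] = None) is the one that takes effect.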

mlx_vlm/models/florence2/florence2.py
@@ -0,0 +1,383 @@
import math
from typing import Optional

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map

from ..base import InputEmbeddingsFeatures

# Import to apply Florence2Processor compatibility patch
from . import processing_florence2  # noqa: F401
from .config import ModelConfig
from .language import LanguageModel
from .vision import VisionModel


def shift_tokens_right(
    input_ids: mx.array, pad_token_id: int, decoder_start_token_id: int
) -> mx.array:
    """Shift input tokens right, adding decoder start token at beginning."""
    shifted = mx.roll(input_ids, 1, axis=-1)
    shifted = tree_map(lambda x: x.at[:, 0].set(decoder_start_token_id), shifted)
    shifted = mx.where(shifted == -100, pad_token_id, shifted)
    return shifted


class LearnedPositionEmbedding2D(nn.Module):
    """2D learned position embeddings."""

    def __init__(self, embedding_dim: int = 256, num_pos: int = 50):
        super().__init__()
        self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2)
        self.column_embeddings = nn.Embedding(
            num_pos, embedding_dim - (embedding_dim // 2)
        )

    def __call__(self, x):
        batch_size, height, width, channels = x.shape
        width_pos = mx.arange(width)
        height_pos = mx.arange(height)

        x_emb = self.column_embeddings(width_pos)
        y_emb = self.row_embeddings(height_pos)

        pos = mx.concatenate(
            [
                mx.broadcast_to(x_emb[None, :, :], (height, width, x_emb.shape[-1])),
                mx.broadcast_to(y_emb[:, None, :], (height, width, y_emb.shape[-1])),
            ],
            axis=-1,
        )

        return mx.broadcast_to(pos[None, ...], (batch_size, height, width, channels))


class PositionalEmbeddingCosine1D(nn.Module):
    """
    MLX implementation of 1D cosine positional embeddings.

    Args:
        embed_dim: The dimension of the embeddings
        max_seq_len: The maximum length to precompute the positional encodings
    """

    def __init__(self, embed_dim: int = 512, max_seq_len: int = 1024) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len

        # Generate position indices and dimension indices
        position = mx.arange(max_seq_len)
        dim_pos = mx.arange(0, embed_dim // 2)  # Half the dimensions for sin/cos pairs

        # Calculate frequency bands
        factor = math.log(10000)
        denominator = mx.exp(-factor * dim_pos / embed_dim)

        # Create position-frequency product matrix [max_seq_len, embed_dim//2]
        frequencies = mx.reshape(position, (-1, 1)) * denominator

        # Calculate sin and cos values [max_seq_len, embed_dim//2]
        sin_values = mx.sin(frequencies)
        cos_values = mx.cos(frequencies)

        # Interleave sin and cos values to create final embeddings
        pos_idx_to_embed = mx.zeros((max_seq_len, embed_dim))
        pos_idx_to_embed = mx.concatenate(
            [mx.expand_dims(sin_values, -1), mx.expand_dims(cos_values, -1)], axis=-1
        ).reshape(max_seq_len, embed_dim)

        # Store the positional embeddings
        self.pos_idx_to_embed = pos_idx_to_embed

    def __call__(self, seq_embeds: mx.array) -> mx.array:
        """
        Apply positional embeddings to the input sequence.

        Args:
            seq_embeds: Input sequence embeddings with shape:
                - [T, D] where T is sequence length and D is embedding dimension
                - [B, T, D] where B is batch size

        Returns:
            Positional embeddings matching input shape
        """
        shape_len = len(seq_embeds.shape)
        assert 2 <= shape_len <= 3, "Input must be 2D or 3D tensor"

        len_seq = seq_embeds.shape[-2]
        assert (
            len_seq <= self.max_seq_len
        ), f"Sequence length {len_seq} exceeds maximum length {self.max_seq_len}"

        # Get relevant portion of pre-computed embeddings
        pos_embeds = self.pos_idx_to_embed[:len_seq]

        # Add batch dimension if input is 3D
        if shape_len == 3:
            pos_embeds = mx.expand_dims(pos_embeds, 0)

        return pos_embeds


class Model(nn.Module):
    """Florence-2 model for conditional generation."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # Initialize vision model
        self.vision_tower = VisionModel(config.vision_config)

        # Initialize language model
        self.language_model = LanguageModel(config.text_config)

        # Image projection layers
        image_dim = config.vision_config.dim_embed[-1]
        text_dim = config.text_config.d_model
        self.image_projection = mx.zeros((image_dim, text_dim))

        self.image_proj_norm = nn.LayerNorm(text_dim)

        # Position embeddings
        if config.image_pos_embed["type"] == "learned_abs_2d":
            self.image_pos_embed = LearnedPositionEmbedding2D(
                embedding_dim=image_dim,
                num_pos=config.image_pos_embed["max_pos_embeddings"],
            )
        else:
            raise NotImplementedError(
                f"Position embedding type {config.image_pos_embed['type']} not supported"
            )

        # Temporal embeddings
        if config.visual_temporal_embedding["type"] == "COSINE":
            self.visual_temporal_embed = PositionalEmbeddingCosine1D(
                embed_dim=image_dim,
                max_seq_len=config.visual_temporal_embedding["max_temporal_embeddings"],
            )
        else:
            raise NotImplementedError(
                f"Temporal embedding type {config.visual_temporal_embedding['type']} not supported"
            )

        self.image_feature_source = config.image_feature_source

    def _encode_image(self, pixel_values, extract_features=True):
        """Encode image using vision model and add position embeddings."""
        T = 1  # Single frame for now

        # Get vision features
        if extract_features:
            batch_size, C, H, W = pixel_values.shape
            x = self.vision_tower(pixel_values)
        else:
            x = pixel_values
            batch_size = pixel_values.shape[0]

        # Assuming this is part of a class method, keeping the same structure
        if self.image_pos_embed is not None:
            # Reshape to (batch_size * T, -1, feature_dim)
            x = mx.reshape(x, (batch_size * T, -1, x.shape[-1]))
            num_tokens = x.shape[-2]
            h, w = int(num_tokens**0.5), int(num_tokens**0.5)
            assert h * w == num_tokens, "only support square feature maps for now"
            # Reshape to (batch_size * T, h, w, feature_dim)
            x = mx.reshape(x, (batch_size * T, h, w, x.shape[-1]))
            pos_embed = self.image_pos_embed(x)
            x = x + pos_embed
            # Reshape to (batch_size, T * h * w, feature_dim)
            x = mx.reshape(x, (batch_size, T * h * w, x.shape[-1]))

        if self.visual_temporal_embed is not None:
            # Reshape for temporal embedding
            x_temp = mx.reshape(x, (batch_size, T, -1, x.shape[-1]))
            temporal_input = x_temp[:, :, 0]
            visual_temporal_embed = self.visual_temporal_embed(temporal_input)
            # Expand dims for broadcasting
            visual_temporal_embed = mx.expand_dims(visual_temporal_embed, axis=2)
            x = mx.reshape(x, (batch_size, T, -1, x.shape[-1])) + visual_temporal_embed

        x_feat_dict = {}

        # Spatial average pooling
        x_spatial = mx.reshape(x, (batch_size, T, -1, x.shape[-1]))
        spatial_avg_pool_x = mx.mean(x_spatial, axis=2)
        x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x

        # Temporal average pooling
        x_temporal = mx.reshape(x, (batch_size, T, -1, x.shape[-1]))
        temporal_avg_pool_x = mx.mean(x_temporal, axis=1)
        x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x

        # Last frame features
        x_last = mx.reshape(x, (batch_size, T, -1, x.shape[-1]))
        x = x_last[:, -1]
        x_feat_dict["last_frame"] = x

        # Gather features based on source configuration
        new_x = []
        for _image_feature_source in self.image_feature_source:
            if _image_feature_source not in x_feat_dict:
                raise ValueError(
                    f"invalid image feature source: {_image_feature_source}"
                )
            new_x.append(x_feat_dict[_image_feature_source])

        # Concatenate features
        x = mx.concatenate(new_x, axis=1)

        # Final projection and normalization
        x = x @ self.image_projection
        x = self.image_proj_norm(x)

        return x

    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds=None):
        batch_size, image_token_length, _ = image_features.shape
        image_attention_mask = mx.ones((batch_size, image_token_length))

        if inputs_embeds is None:
            return image_features, image_attention_mask

        task_prefix_embeds = inputs_embeds
        task_prefix_attention_mask = mx.ones((batch_size, task_prefix_embeds.shape[1]))

        if len(task_prefix_attention_mask.shape) == 3:
            task_prefix_attention_mask = task_prefix_attention_mask[:, 0]

        # Concatenate image features and task prefix embeddings
        inputs_embeds = mx.concatenate([image_features, task_prefix_embeds], axis=1)
        attention_mask = mx.concatenate(
            [image_attention_mask, task_prefix_attention_mask], axis=1
        )
        return inputs_embeds, attention_mask

    @property
    def layers(self):
        return self.language_model.model.decoder.layers

    def make_cache(self):
        """Create cache for encoder-decoder model."""
        return self.language_model.make_cache()

    def get_input_embeddings(
        self,
        input_ids: Optional[mx.array] = None,
        pixel_values: Optional[mx.array] = None,
        **kwargs,
    ):

        if input_ids is not None:
            # Filter out image placeholder tokens and only embed the task prompt
            # Create mask for non-image tokens
            non_image_mask = input_ids != self.config.image_token_id

            # Use boolean indexing to filter - convert to list for processing
            batch_size = input_ids.shape[0]

            # For batch_size=1, filter directly
            if batch_size == 1:
                # Get non-image token indices using argwhere-like approach
                mask_flat = non_image_mask[0]
                # Sum up mask to count non-image tokens
                num_non_image = int(mx.sum(mask_flat).item())

                if num_non_image > 0:
                    # Extract non-image tokens by iterating (simple approach)
                    input_list = input_ids[0].tolist()
                    filtered_tokens = [
                        t for t in input_list if t != self.config.image_token_id
                    ]
                    task_input_ids = mx.array([filtered_tokens])
                    inputs_embeds = self.language_model.model.shared(task_input_ids)
                else:
                    inputs_embeds = None
            else:
                # For batch processing, embed all and handle later
                inputs_embeds = self.language_model.model.shared(input_ids)
        else:
            inputs_embeds = None

        attention_mask = None

        # Process image if provided
        if pixel_values is not None:
            image_features = self._encode_image(pixel_values)

            # Merge image features with text embeddings (task prompt only)
            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(
                image_features, inputs_embeds
            )

        # For encoder-decoder models, prepare initial decoder input
        # Use decoder_start_token_id from text_config (default 2 for Florence2/BART)
        decoder_start_token_id = getattr(
            self.config.text_config, "decoder_start_token_id", 2
        )
        decoder_input_ids = mx.array([[decoder_start_token_id]])
        decoder_inputs_embeds = self.language_model.model.shared(decoder_input_ids)

        return InputEmbeddingsFeatures(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,  # Use attention_mask for encoder-decoder
            decoder_inputs_embeds=decoder_inputs_embeds,
        )

    def __call__(
        self,
        input_ids=None,
        pixel_values=None,
        cache=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        labels=None,
        **kwargs,
    ):
        """Forward pass."""
        attention_mask = None
        decoder_inputs_embeds = None

        input_embeddings_features = self.get_input_embeddings(
            input_ids, pixel_values, **kwargs
        )
        inputs_embeds = input_embeddings_features.inputs_embeds
        attention_mask = input_embeddings_features.attention_mask
        # Handle decoder input IDs
        if labels is not None and decoder_input_ids is None:
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id, self.config.bos_token_id
            )

        if decoder_input_ids is None and decoder_inputs_embeds is None:
            # Use decoder_start_token_id from text_config (default 2 for Florence2/BART)
            decoder_start_token_id = getattr(
                self.config.text_config, "decoder_start_token_id", 2
            )
            decoder_input_ids = mx.array([decoder_start_token_id])[None, :]
            decoder_inputs_embeds = self.language_model.model.shared(decoder_input_ids)
            decoder_input_ids = None

        # Forward through language model
        outputs = self.language_model(
            inputs=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_attention_mask=decoder_attention_mask,
            cache=cache,
        )

        return outputs

    @staticmethod
    def sanitize(weights):
        sanitized_weights = {}
        for k, v in weights.items():
            if "final_logits_bias" in k:
                continue
            sanitized_weights[k] = v
        return sanitized_weights
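
To make the visual-token layout in _encode_image concrete, here is a small standalone sketch (illustrative only, not package code). The 24x24 grid is what the default VisionConfig implies (768x768 input with cumulative patch stride 4*2*2*2 = 32), and the feature width 1024 is dim_embed[-1]:

import mlx.core as mx

# Shape walk-through of the default image_feature_source
# ["temporal_avg_pool", "spatial_avg_pool"] with T = 1 frame.
B, T, D = 2, 1, 1024
h = w = 24
x = mx.zeros((B, T, h * w, D))  # stand-in for the DaViT feature map

temporal_avg_pool = mx.mean(x, axis=1)  # (B, h*w, D): 576 tokens
spatial_avg_pool = mx.mean(x, axis=2)   # (B, T, D):   1 token per frame

visual_tokens = mx.concatenate([temporal_avg_pool, spatial_avg_pool], axis=1)
print(visual_tokens.shape)  # (2, 577, 1024); image_projection then maps 1024 -> d_model = 768

These projected tokens are what _merge_input_ids_with_image_features prepends to the embedded task prompt before the combined sequence is passed to the BART-style encoder-decoder language model.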