fount_vlm_nell_02-0.3.11-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0

mlx_vlm/models/fastvlm/vision.py
@@ -0,0 +1,692 @@
from typing import List, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn

from .config import VisionConfig


class NamedSequential(nn.Module):
    def __init__(self):
        super().__init__()
        self._order = []

    def add_module(self, name, module):
        setattr(self, name, module)
        self._order.append(name)

    def __call__(self, x):
        for name in self._order:
            x = getattr(self, name)(x)
        return x


class CallableModuleList(list):
    def __call__(self, x: mx.array):
        for item in self:
            x = item(x)
        return x


class MHSA(nn.Module):
    """Multi-headed Self Attention module.

    Source modified from:
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    """

    def __init__(
        self,
        dim: int,
        head_dim: int = 32,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        assert dim % head_dim == 0, "dim should be divisible by head_dim"
        self.head_dim = head_dim
        self.num_heads = dim // head_dim
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def __call__(self, x: mx.array) -> mx.array:
        # Source: https://github.com/apple/ml-fastvlm/blob/592b4add3c1c8a518e77d95dc6248e76c1dd591f/llava/model/multimodal_encoder/mobileclip/mci.py#L661
        x = x.transpose(0, 3, 1, 2)
        B, C, H, W = x.shape
        N = H * W
        x = x.flatten(start_axis=2).transpose(0, 2, 1)  # (B, N, C)
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dim)
            .transpose(2, 0, 3, 1, 4)
        )
        q, k, v = qkv

        x = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=None)
        x = x.transpose(0, 2, 1, 3).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        x = x.reshape(B, H, W, C)
        return x


class ConvFFN(nn.Module):
    """Convolutional FFN Module."""

    def __init__(
        self,
        in_channels: int,
        hidden_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        act_layer: nn.Module = nn.GELU,
    ) -> None:
        super().__init__()
        out_channels = out_channels or in_channels
        hidden_channels = hidden_channels or in_channels
        self.conv = NamedSequential()
        self.conv.add_module(
            "conv",
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=7,
                padding=3,
                groups=in_channels,
                bias=False,
            ),
        )
        self.conv.add_module(
            "bn",
            nn.BatchNorm(num_features=out_channels),
        )
        self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)

    def __call__(self, x: mx.array) -> mx.array:
        x = self.conv(x)
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


class LayerNormChannel(nn.Module):
    """
    LayerNorm only for Channel Dimension.
    Input: tensor in shape [B, H, W, C]
    """

    def __init__(self, num_features, eps=1e-05) -> None:
        super().__init__()
        self.weight = mx.ones(num_features)
        self.bias = mx.zeros(num_features)
        self.eps = eps

    def __call__(self, x: mx.array) -> mx.array:
        u = x.mean(-1, keepdims=True)
        s = mx.power(x - u, 2).mean(-1, keepdims=True)
        x = (x - u) / mx.sqrt(s + self.eps)
        x = self.weight * x + self.bias
        return x


class AttentionBlock(nn.Module):
    """Implementation of metaformer block with MHSA as token mixer.

    For more details on Metaformer structure, please refer to:
    `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
    """

    def __init__(
        self,
        dim: int,
        mlp_ratio: float = 4.0,
        act_layer: nn.Module = nn.GELU,
        norm_layer: nn.Module = nn.BatchNorm,
    ):
        super().__init__()

        self.norm = norm_layer(num_features=dim)
        self.token_mixer = MHSA(dim=dim)

        assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
            mlp_ratio
        )
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.convffn = ConvFFN(
            in_channels=dim,
            hidden_channels=mlp_hidden_dim,
            act_layer=act_layer,
        )

        self.layer_scale_1 = mx.ones((1, 1, dim))
        self.layer_scale_2 = mx.ones((1, 1, dim))

    def __call__(self, x: mx.array) -> mx.array:
        x = x + self.layer_scale_1 * self.token_mixer(self.norm(x))
        x = x + self.layer_scale_2 * self.convffn(x)
        return x


class RepCPE(nn.Module):
    """Implementation of conditional positional encoding.

    For more details refer to paper:
    `Conditional Positional Encodings for Vision Transformers <https://arxiv.org/pdf/2102.10882.pdf>`_
    """

    def __init__(
        self,
        in_channels: int,
        embed_dim: int = 768,
        spatial_shape=(7, 7),
    ) -> None:
        super().__init__()
        if isinstance(spatial_shape, int):
            spatial_shape = tuple([spatial_shape] * 2)
        assert isinstance(spatial_shape, Tuple), (
            f'"spatial_shape" must by a sequence or int, '
            f"get {type(spatial_shape)} instead."
        )
        assert len(spatial_shape) == 2, (
            f'Length of "spatial_shape" should be 2, '
            f"got {len(spatial_shape)} instead."
        )

        self.reparam_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=spatial_shape,
            stride=1,
            padding=int(spatial_shape[0] // 2),
            groups=embed_dim,
            bias=True,
        )

    def __call__(self, x: mx.array) -> mx.array:
        return self.reparam_conv(x)


class ReparamLargeKernelConv(nn.Module):
    """Building Block of RepLKNet

    This class defines overparameterized large kernel conv block
    introduced in `RepLKNet <https://arxiv.org/abs/2203.06717>`_

    Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        groups: int,
        activation: nn.Module = nn.GELU(),
    ) -> None:
        super(ReparamLargeKernelConv, self).__init__()
        self.activation = activation
        self.lkb_reparam = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            dilation=1,
            groups=groups,
            bias=True,
        )

    def __call__(self, x: mx.array) -> mx.array:
        return self.activation(self.lkb_reparam(x))


class PatchEmbed(nn.Module):
    """Convolutional patch embedding layer."""

    def __init__(
        self,
        patch_size: int,
        stride: int,
        in_channels: int,
        embed_dim: int,
    ) -> None:
        super().__init__()
        self.proj = CallableModuleList()
        self.proj.append(
            ReparamLargeKernelConv(
                in_channels=in_channels,
                out_channels=embed_dim,
                kernel_size=patch_size,
                stride=stride,
                groups=in_channels,
            )
        )
        self.proj.append(
            MobileOneBlock(
                in_channels=embed_dim,
                out_channels=embed_dim,
                kernel_size=1,
                stride=1,
                padding=0,
                groups=1,
                use_se=False,
            )
        )

    def __call__(self, x: mx.array) -> mx.array:
        return self.proj(x)


class RepMixer(nn.Module):
    """Reparameterizable token mixer.

    For more details, please refer to Apple's paper:
    `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization <https://arxiv.org/pdf/2303.14189.pdf>`_
    """

    def __init__(
        self,
        dim,
        kernel_size=3,
    ):
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size

        self.reparam_conv = nn.Conv2d(
            in_channels=self.dim,
            out_channels=self.dim,
            kernel_size=self.kernel_size,
            stride=1,
            padding=self.kernel_size // 2,
            groups=self.dim,
            bias=True,
        )

    def __call__(self, x: mx.array) -> mx.array:
        return self.reparam_conv(x)


class RepMixerBlock(nn.Module):
    """Implementation of Metaformer block with RepMixer as token mixer.

    For more details on Metaformer structure, please refer to:
    `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int = 3,
        mlp_ratio: float = 4.0,
        act_layer: nn.Module = nn.GELU,
    ):
        super().__init__()

        self.token_mixer = RepMixer(dim, kernel_size=kernel_size)

        assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
            mlp_ratio
        )
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.convffn = ConvFFN(
            in_channels=dim,
            hidden_channels=mlp_hidden_dim,
            act_layer=act_layer,
        )
        self.layer_scale = mx.ones((1, 1, dim))

    def __call__(self, x: mx.array) -> mx.array:
        x = self.token_mixer(x)
        x = x + self.layer_scale * self.convffn(x)
        return x


def basic_blocks(
    dim: int,
    block_index: int,
    num_blocks: List[int],
    token_mixer_type: str,
    kernel_size: int = 3,
    mlp_ratio: float = 4.0,
    act_layer: nn.Module = nn.GELU,
    norm_layer: nn.Module = nn.BatchNorm,
):
    blocks = CallableModuleList()
    for _ in range(num_blocks[block_index]):
        if token_mixer_type == "repmixer":
            blocks.append(
                RepMixerBlock(
                    dim,
                    kernel_size=kernel_size,
                    mlp_ratio=mlp_ratio,
                    act_layer=act_layer,
                )
            )
        elif token_mixer_type == "attention":
            blocks.append(
                AttentionBlock(
                    dim,
                    mlp_ratio=mlp_ratio,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                )
            )
        else:
            raise ValueError(
                "Token mixer type: {} not supported".format(token_mixer_type)
            )
    return blocks


def build_fast_vit_network(config: VisionConfig):
    network = []
    for i in range(len(config.layers)):
        spatial_shape = config.pos_embs_shapes[i]
        if spatial_shape is not None:
            position_embeddings = RepCPE(
                in_channels=config.embed_dims[i],
                embed_dim=config.embed_dims[i],
                spatial_shape=spatial_shape,
            )
            network.append(position_embeddings)

        stage = basic_blocks(
            config.embed_dims[i],
            i,
            config.layers,
            token_mixer_type=config.token_mixers[i],
            kernel_size=config.repmixer_kernel_size,
            mlp_ratio=config.mlp_ratios[i],
            norm_layer=LayerNormChannel,
        )
        network.append(stage)

        if i >= len(config.layers) - 1:
            break

        # Patch merging/downsampling between stages.
        if config.downsamples[i] or config.embed_dims[i] != config.embed_dims[i + 1]:
            network.append(
                PatchEmbed(
                    patch_size=config.down_patch_size,
                    stride=config.down_stride,
                    in_channels=config.embed_dims[i],
                    embed_dim=config.embed_dims[i + 1],
                )
            )
    return network


class SEBlock(nn.Module):
    """Squeeze and Excite module.

    MLX implementation of `Squeeze-and-Excitation Networks` -
    https://arxiv.org/pdf/1709.01507.pdf
    """

    def __init__(self, in_channels: int, rd_ratio: float = 0.0625):
        """Construct a Squeeze and Excite Module.

        Args:
            in_channels: Number of input channels.
            rd_ratio: Input channel reduction ratio.
        """
        super().__init__()
        self.reduce = nn.Conv2d(
            in_channels=in_channels,
            out_channels=int(in_channels * rd_ratio),
            kernel_size=1,
            stride=1,
            bias=True,
        )
        self.expand = nn.Conv2d(
            in_channels=int(in_channels * rd_ratio),
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            bias=True,
        )

    def __call__(self, inputs: mx.array) -> mx.array:
        _, h, w, c = inputs.shape
        x = nn.AvgPool2d(kernel_size=[h, w])(inputs)
        x = self.reduce(x)
        x = nn.layers.relu(x)
        x = self.expand(x)
        x = mx.sigmoid(x)
        x = x.reshape(-1, 1, 1, c)
        return inputs * x


class MobileOneBlock(nn.Module):
    """MobileOne building block.

    This implementation only uses the inference time CNN architecture and uses FastViTHD conventions.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        use_se: bool = False,
    ):
        super().__init__()
        self.groups = groups
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Check if SE-ReLU is requested
        if use_se:
            self.se = SEBlock(out_channels)
        else:
            self.se = nn.Identity()

        self.activation = nn.GELU()
        self.reparam_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=True,
        )

    def __call__(self, x: mx.array) -> mx.array:
        return self.activation(self.se(self.reparam_conv(x)))


class ConvolutionalStem(nn.Module):
    def __init__(self, config: VisionConfig):
        super().__init__()
        in_channels = 3
        out_channels = config.embed_dims[0]
        self.blocks = CallableModuleList(
            [
                MobileOneBlock(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    groups=1,
                ),
                MobileOneBlock(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    groups=out_channels,
                ),
                MobileOneBlock(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    groups=1,
                ),
            ]
        )

    def __call__(self, x: mx.array) -> mx.array:
        return self.blocks(x)


class FastViTHDModel(nn.Module):
    """
    Based on https://github.com/apple/ml-fastvlm/blob/592b4add3c1c8a518e77d95dc6248e76c1dd591f/llava/model/multimodal_encoder/mobileclip/mci.py
    Hardcoded, for now, for:
    - FastViTHD variant
    - Use inference_mode (i.e., modules contain the convolutional reparameterized versions of the architecture)
    """

    def __init__(self, config: VisionConfig):
        super().__init__()
        if config.pos_embs_shapes is None:
            config.pos_embs_shapes = [None] * len(config.layers)
        self.config = config

        # We follow the nomenclature from mci.py
        self.patch_embed = ConvolutionalStem(config)
        self.network = build_fast_vit_network(config)
        self.conv_exp = MobileOneBlock(
            in_channels=config.embed_dims[-1],
            out_channels=int(config.embed_dims[-1] * config.cls_ratio),
            kernel_size=3,
            stride=1,
            padding=1,
            groups=config.embed_dims[-1],
            use_se=True,
        )
        self.head = nn.Linear(
            int(config.embed_dims[-1] * config.cls_ratio), config.num_classes
        )

    def __call__(
        self,
        x: mx.array,
        output_hidden_states: Optional[bool] = None,
    ):
        x = self.patch_embed(x)

        encoder_states = (x,) if output_hidden_states else None
        for layer in self.network:
            x = layer(x)
            if output_hidden_states:
                encoder_states = encoder_states + (x,)

        x = self.conv_exp(x)
        cls_out = self.head(x)

        return cls_out, x, encoder_states


class GlobalPool2D(nn.Module):
    """This class implements global pooling with linear projection."""

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.proj = mx.zeros((in_dim, out_dim))

    def __call__(self, x: mx.array) -> mx.array:
        assert (
            x.ndim == 4
        ), "Input should be 4-dimensional (Batch x in_dim x in_height x in_width). Got: {}".format(
            x.shape
        )

        # [batch, in_height, in_width, in_dim] --> [batch, in_dim]
        x = x.mean(axis=[1, 2])
        # [batch, in_dim] x [in_dim, out_dim] --> [batch, out_dim]
        x = x @ self.proj
        return x


class VisionModel(nn.Module):
    def __init__(self, config: VisionConfig):
        super().__init__()

        self.model_type = config.model_type
        if self.model_type not in ["llava_qwen2", "fastvlm"]:
            raise ValueError(f"Unsupported model type: {self.model_type}")

        self.vision_model = FastViTHDModel(config)

        # Replace projection head, same as in
        # https://github.com/apple/ml-fastvlm/blob/592b4add3c1c8a518e77d95dc6248e76c1dd591f/llava/model/multimodal_encoder/mobileclip/__init__.py#L49
        if config.projection_dim is not None:
            in_dim = int(config.embed_dims[-1] * config.cls_ratio)
            self.vision_model.head = GlobalPool2D(in_dim, config.projection_dim)

    def __call__(
        self, x: mx.array, output_hidden_states: Optional[bool] = None
    ) -> mx.array:
        return self.vision_model(x, output_hidden_states)

    def sanitize(self, weights):
        # Only transpose during conversion from transformers
        W, C = weights[
            "vision_tower.vision_model.patch_embed.blocks.1.reparam_conv.weight"
        ].shape[-2:]
        skip_transpose = W > C

        def is_conv(k):
            if skip_transpose:
                return False
            if ".reparam_conv.weight" in k:
                return True
            if ".conv.weight" in k:
                return True
            if ".fc1.weight" in k:
                return True
            if ".fc2.weight" in k:
                return True
            if ".lkb_reparam.weight" in k:
                return True
            if ".reduce.weight" in k:
                return True
            if ".expand.weight" in k:
                return True
            return False

        sanitized_weights = {}
        for k, v in weights.items():
            if is_conv(k):
                # PyTorch conv2d weight tensors have shape:
                # [out_channels, in_channels, kH, KW]
                # MLX conv2d expects the weight be of shape:
                # [out_channels, kH, KW, in_channels]
                if v.ndim == 4:
                    sanitized_weights[k] = v.transpose(0, 2, 3, 1)
                else:
                    sanitized_weights[k] = v
            elif "layer_scale" in k and not skip_transpose:
                sanitized_weights[k] = v.transpose(1, 2, 0)
            elif "num_batches_tracked" in k:
                # I don't think we need this
                continue
            else:
                sanitized_weights[k] = v
        return sanitized_weights
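
For reference, the conv-weight layout change that the `sanitize` method above performs can be illustrated with a minimal standalone sketch (not part of the package; the array shape below is hypothetical):

import mlx.core as mx

# PyTorch Conv2d stores weights as [out_channels, in_channels, kH, kW];
# MLX's Conv2d expects [out_channels, kH, kW, in_channels].
w_pt = mx.zeros((96, 3, 7, 7))        # hypothetical kernel in PyTorch layout
w_mlx = w_pt.transpose(0, 2, 3, 1)    # same permutation sanitize() applies
print(w_mlx.shape)                    # (96, 7, 7, 3)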