fount-vlm-nell-02 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
|
@@ -0,0 +1,286 @@
|
|
|
1
|
+
from typing import Optional, Tuple
|
|
2
|
+
|
|
3
|
+
import mlx.core as mx
|
|
4
|
+
import mlx.nn as nn
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from ..base import interpolate
|
|
8
|
+
from .config import AdapterConfig, VisionConfig, VitConfig
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _gelu_from_name(name: str) -> nn.Module:
    """Build the MLX GELU variant matching a HuggingFace-style ``hidden_act``.

    MLX's ``nn.GELU`` offers ``approx="none"`` (exact, erf based),
    ``approx="precise"`` (the tanh approximation used by PyTorch's
    ``gelu(approximate="tanh")``) and ``approx="fast"`` (a cheaper
    sigmoid-based approximation).

    Args:
        name: Activation name from the model config,
            e.g. ``"gelu_pytorch_tanh"`` or ``"gelu"``.

    Returns:
        An ``nn.GELU`` module. Tanh-approximation names map to
        ``approx="precise"``; any other name yields the exact GELU.
    """
    if name in ("gelu_pytorch_tanh", "gelu_new", "gelu_tanh"):
        # The tanh formula is MLX's "precise" mode; "fast" is a different
        # (sigmoid) approximation and would not match the reference model.
        return nn.GELU(approx="precise")
    return nn.GELU()
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ViTMLP(nn.Module):
    """Feed-forward sub-layer of a ViT block: ``w2(act(w1(x)))``.

    Both projections carry a bias; the activation is selected from
    ``hidden_act`` via ``_gelu_from_name``.
    """

    def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str):
        super().__init__()
        # Attribute names define the parameter paths used when loading weights.
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=True)
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=True)
        self.act = _gelu_from_name(hidden_act)

    def __call__(self, x: mx.array) -> mx.array:
        expanded = self.w1(x)
        activated = self.act(expanded)
        return self.w2(activated)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ViTMultiHeadDotProductAttention(nn.Module):
    """Multi-head dot-product attention with optional grouped key/value heads.

    Used as the self-attention of each ViT block and, when called with a
    separate ``inputs_kv``, as the attention-pooling layer of the adapter.

    Args:
        hidden_size: Output width of the ``wo`` projection.
        num_heads: Number of query heads.
        num_key_value_heads: Number of key/value heads. When it differs from
            ``num_heads`` the K/V heads are repeated
            ``num_heads // num_key_value_heads`` times, so it is expected to
            divide ``num_heads``.
        head_dim: Width of each attention head.
        input_dim: Width of the inputs to ``wq``/``wk``/``wv``; falls back to
            ``hidden_size`` when falsy.
        use_bias: Whether ``wq``/``wk``/``wv`` carry a bias (``wo`` always does).
        float32_attention: Upcast q/k/v to float32 for the score, softmax and
            weighted-sum computation, casting the result back afterwards.
    """

    def __init__(
        self,
        *,
        hidden_size: int,
        num_heads: int,
        num_key_value_heads: int,
        head_dim: int,
        input_dim: Optional[int] = None,
        use_bias: bool = True,
        float32_attention: bool = True,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.scale = head_dim**-0.5
        self.float32_attention = float32_attention

        input_dim = input_dim or hidden_size
        q_width = self.num_heads * self.head_dim
        kv_width = self.num_key_value_heads * self.head_dim
        self.wq = nn.Linear(input_dim, q_width, bias=use_bias)
        self.wk = nn.Linear(input_dim, kv_width, bias=use_bias)
        self.wv = nn.Linear(input_dim, kv_width, bias=use_bias)
        self.wo = nn.Linear(q_width, self.hidden_size, bias=True)

    def __call__(
        self,
        inputs_q: mx.array,
        inputs_kv: Optional[mx.array] = None,
        attn_mask: Optional[mx.array] = None,
    ) -> mx.array:
        """Attend from ``inputs_q`` to ``inputs_kv`` (or to itself).

        Args:
            inputs_q: Query inputs of shape ``(batch, q_len, input_dim)``.
            inputs_kv: Key/value inputs of shape ``(batch, kv_len, input_dim)``;
                when ``None``, self-attention over ``inputs_q`` is computed.
            attn_mask: Optional mask broadcastable against the scores
                ``(batch, num_heads, q_len, kv_len)``. Scores where the mask
                is false are replaced by ``-1e9`` before the softmax.

        Returns:
            Array of shape ``(batch, q_len, hidden_size)``.
        """
        if inputs_kv is None:
            inputs_kv = inputs_q

        xq = self.wq(inputs_q)
        xk = self.wk(inputs_kv)
        xv = self.wv(inputs_kv)

        bsz, q_len, _ = xq.shape
        kv_len = xk.shape[1]

        xq = xq.reshape(bsz, q_len, self.num_heads, self.head_dim)
        xk = xk.reshape(bsz, kv_len, self.num_key_value_heads, self.head_dim)
        xv = xv.reshape(bsz, kv_len, self.num_key_value_heads, self.head_dim)

        if self.num_heads != self.num_key_value_heads:
            # Grouped K/V: repeat each K/V head for its group of query heads.
            xk = mx.repeat(xk, self.num_key_value_groups, axis=2)
            xv = mx.repeat(xv, self.num_key_value_groups, axis=2)

        # (batch, seq, heads, dim) -> (batch, heads, seq, dim)
        q = xq.transpose(0, 2, 1, 3)
        k = xk.transpose(0, 2, 1, 3)
        v = xv.transpose(0, 2, 1, 3)

        dtype = q.dtype
        if self.float32_attention:
            q = q.astype(mx.float32)
            k = k.astype(mx.float32)
            v = v.astype(mx.float32)

        scores = mx.matmul(q, k.transpose(0, 1, 3, 2)) * self.scale
        if attn_mask is not None:
            # A scalar fill broadcasts inside mx.where (taking scores.dtype),
            # so no scores-sized temporary has to be allocated.
            scores = mx.where(attn_mask, scores, -1e9)

        weights = mx.softmax(scores, axis=-1)
        out = mx.matmul(weights, v).astype(dtype)
        out = out.transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
        return self.wo(out)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class Molmo2VisionBlock(nn.Module):
    """Pre-norm ViT block.

    Applies ``x + attention(norm(x))`` followed by
    ``x + feed_forward(norm(x))``, each with its own LayerNorm.
    """

    def __init__(self, config: VitConfig):
        super().__init__()
        width = config.hidden_size
        eps = config.layer_norm_eps

        self.attention = ViTMultiHeadDotProductAttention(
            hidden_size=width,
            num_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            float32_attention=config.float32_attention,
            input_dim=width,
        )
        self.feed_forward = ViTMLP(width, config.intermediate_size, config.hidden_act)
        self.attention_norm = nn.LayerNorm(width, eps=eps)
        self.ffn_norm = nn.LayerNorm(width, eps=eps)

    def __call__(self, x: mx.array) -> mx.array:
        attended = self.attention(self.attention_norm(x))
        x = x + attended
        projected = self.feed_forward(self.ffn_norm(x))
        return x + projected
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class Molmo2VisionTransformer(nn.Module):
    """Patch-embedding ViT that returns the hidden state of every block.

    Inputs are already-flattened patches; a learned positional table is added
    (bicubically resized when the patch grid differs from the table's grid).
    """

    def __init__(self, config: VitConfig):
        super().__init__()
        self.config = config
        self.num_prefix_tokens = 0

        # Positional table stored flat as (image_num_pos, hidden_size); it is
        # zero-initialised here and presumably replaced by loaded weights.
        self.positional_embedding = mx.zeros((config.image_num_pos, config.hidden_size))
        # Each patch arrives flattened as patch_size * patch_size * 3 values.
        flat_patch_width = config.image_patch_size * config.image_patch_size * 3
        self.patch_embedding = nn.Linear(flat_patch_width, config.hidden_size, bias=True)
        self.transformer = [
            Molmo2VisionBlock(config) for _ in range(config.num_hidden_layers)
        ]

    def add_pos_emb(self, x: mx.array, patch_num: Tuple[int, int]) -> mx.array:
        """Add positional embeddings for a ``patch_num = (rows, cols)`` grid to ``x``."""
        table = self.positional_embedding
        side = int(table.shape[0] ** 0.5)
        grid = mx.reshape(table, (side, side, table.shape[1]))

        rows, cols = patch_num
        if not (grid.shape[0] == rows and grid.shape[1] == cols):
            # interpolate is fed channels-first input: (1, C, H, W).
            channels_first = mx.transpose(grid[None, ...], (0, 3, 1, 2))
            channels_first = interpolate(
                channels_first, (rows, cols), mode="cubic", align_corners=False
            )
            grid = mx.transpose(channels_first, (0, 2, 3, 1))[0]

        flat = mx.reshape(grid, (-1, grid.shape[-1]))
        return x + flat[None, :, :].astype(x.dtype)

    def __call__(
        self,
        x: mx.array,
        patch_num: Optional[Tuple[int, int]] = None,
    ):
        """Embed patches and return a list with one hidden state per block."""
        grid_size = self.config.image_num_patch if patch_num is None else patch_num

        h = self.add_pos_emb(self.patch_embedding(x), grid_size)

        per_layer = []
        for layer in self.transformer:
            h = layer(h)
            per_layer.append(h)
        return per_layer
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
class ImageProjectorMLP(nn.Module):
    """Bias-free SwiGLU projector: ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.w1 = nn.Linear(input_dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, output_dim, bias=False)
        self.w3 = nn.Linear(input_dim, hidden_dim, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        gate = nn.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
class VisionModel(nn.Module):
    """Molmo2 vision tower: ViT encoder, attention pooling and MLP projector.

    Features of the ViT layers listed in ``adapter_config.vit_layers`` are
    concatenated, gathered into pooling windows via ``pooled_patches_idx``,
    reduced to one vector per window by attention pooling, and projected to
    the text model's width.
    """

    def __init__(self, config: VisionConfig):
        super().__init__()
        self.config = config
        self.model_type = "molmo2"
        self.vit_config: VitConfig = config.vit_config
        self.adapter_config: AdapterConfig = config.adapter_config

        self.image_vit = Molmo2VisionTransformer(self.vit_config)

        # Normalise negative (count-from-the-end) layer indices to absolute ones.
        num_layers = self.vit_config.num_hidden_layers
        self.vit_layers = [
            layer if layer >= 0 else layer + num_layers
            for layer in self.adapter_config.vit_layers
        ]

        # Pooling attends over the concatenation of the selected layers.
        pool_dim = self.vit_config.hidden_size * len(self.vit_layers)

        self.image_pooling_2d = ViTMultiHeadDotProductAttention(
            hidden_size=self.adapter_config.hidden_size,
            num_heads=self.adapter_config.num_attention_heads,
            num_key_value_heads=self.adapter_config.num_key_value_heads,
            head_dim=self.adapter_config.head_dim,
            input_dim=pool_dim,
            float32_attention=self.adapter_config.float32_attention,
        )

        self.image_projector = ImageProjectorMLP(
            self.adapter_config.hidden_size,
            self.adapter_config.intermediate_size,
            self.adapter_config.text_hidden_size,
        )

    def encode_image(self, images: mx.array) -> mx.array:
        """Run the ViT over every crop and concatenate the selected layers.

        Args:
            images: Flattened patches of shape
                ``(batch, num_crops, num_patch, patch_dim)``.

        Returns:
            Array of shape ``(batch, num_crops, num_patch, feature_dim)``, where
            ``feature_dim`` is the concatenated width of the selected layers.
        """
        batch_size, num_crops, num_patch, patch_dim = images.shape
        images = images.reshape(batch_size * num_crops, num_patch, patch_dim)
        hidden_states = self.image_vit(images)

        features = [hidden_states[layer] for layer in self.vit_layers]
        image_features = mx.concatenate(features, axis=-1)
        return image_features.reshape(batch_size, num_crops, num_patch, -1)

    def __call__(
        self,
        images: mx.array,
        pooled_patches_idx: mx.array,
    ) -> mx.array:
        """Encode images and return one projected embedding per pooled token.

        Args:
            images: ``(batch, num_crops, num_patch, patch_dim)`` patches.
            pooled_patches_idx: Integer array of shape
                ``(batch, num_tokens, window)``. It indexes each batch
                element's flattened ``num_crops * num_patch`` features;
                negative entries mark padding slots.

        Returns:
            Array of shape ``(num_valid_tokens, text_hidden_size)``, holding
            only tokens whose window has at least one non-negative index,
            in row-major ``(batch, token)`` order.
        """
        batch_size = images.shape[0]

        image_features = self.encode_image(images)
        dim = image_features.shape[-1]

        valid = pooled_patches_idx >= 0
        valid_token = mx.any(valid, axis=-1)

        # Gather window members from each batch element's flattened features;
        # padding indices are clipped to 0 here and zeroed out below.
        flat_features = image_features.reshape(batch_size, -1, dim)
        idx = mx.clip(pooled_patches_idx, 0, None)
        batch_idx = mx.broadcast_to(mx.arange(batch_size)[:, None, None], idx.shape)

        gathered = flat_features[mx.reshape(batch_idx, (-1,)), mx.reshape(idx, (-1,))]
        to_pool = gathered.reshape(
            pooled_patches_idx.shape[0],
            pooled_patches_idx.shape[1],
            pooled_patches_idx.shape[2],
            dim,
        )

        to_pool = to_pool * valid[..., None].astype(to_pool.dtype)
        to_pool = to_pool.reshape(-1, pooled_patches_idx.shape[-1], dim)

        if self.adapter_config.pooling_attention_mask:
            attn_mask = valid.reshape(-1, 1, 1, valid.shape[-1])
            # Mean over valid members only; all-padding windows divide by 1.
            denom = valid.reshape(-1, to_pool.shape[-2]).astype(mx.float32).sum(axis=-1)
            denom = mx.where(denom == 0, mx.ones_like(denom), denom)
            query = to_pool.sum(axis=-2, keepdims=True) / denom[:, None, None].astype(
                to_pool.dtype
            )
        else:
            attn_mask = None
            query = mx.mean(to_pool, axis=-2, keepdims=True)

        pooled = self.image_pooling_2d(query, to_pool, attn_mask=attn_mask)
        pooled = pooled.reshape(batch_size, -1, pooled.shape[-1])
        pooled = self.image_projector(pooled)

        pooled = pooled.reshape(-1, pooled.shape[-1])

        # MLX doesn't support boolean indexing, so convert the mask to integer
        # indices on the host (this forces evaluation of valid_token).
        valid_indices = np.flatnonzero(np.array(valid_token))
        return pooled[mx.array(valid_indices)]
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from .config import ModelConfig, TextConfig, VisionConfig
|
|
2
|
+
from .image_crops import (
|
|
3
|
+
adaptive_avg_pool2d,
|
|
4
|
+
overlap_crop_image,
|
|
5
|
+
reconstruct_from_crops,
|
|
6
|
+
select_tiling,
|
|
7
|
+
)
|
|
8
|
+
from .moondream2 import ImageProcessor, Model
|
|
9
|
+
from .vision import VisionModel
|
|
10
|
+
from .language import LanguageModel
|
|
11
|
+
from . import processing_moondream # Registers the AutoProcessor patch
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
from ..base import BaseModelConfig
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class TextConfig(BaseModelConfig):
    """Hyper-parameters of Moondream's Phi-style language model."""

    model_type: str = "phi"
    hidden_size: int = 2048
    num_hidden_layers: int = 24
    intermediate_size: int = 8192
    num_attention_heads: int = 32
    num_key_value_heads: int = 32
    vocab_size: int = 51200
    max_position_embeddings: int = 2048
    rope_theta: float = 10000.0
    layer_norm_eps: float = 1e-5
    # Partial RoPE: fraction of each head's dims that are rotated. With the
    # default 64-dim heads (2048 / 32), only the first 32 dims are rotated.
    partial_rotary_factor: float = 0.5
    # Number of leading prefix positions: 1 BOS token + 729 image patches.
    prefix_attn_len: int = 730
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class VisionConfig(BaseModelConfig):
    """Hyper-parameters of Moondream's vision encoder.

    Trailing comments give the corresponding names from the reference
    implementation's config.
    """

    model_type: str = "moondream_vision"
    hidden_size: int = 1152  # enc_dim
    num_hidden_layers: int = 27  # enc_n_layers
    intermediate_size: int = 4304  # enc_ff_dim
    num_attention_heads: int = 16  # enc_n_heads
    image_size: int = 378  # crop_size
    patch_size: int = 14  # enc_patch_size
    num_channels: int = 3  # in_channels
    layer_norm_eps: float = 1e-5
    # Multi-crop tiling settings (reserved for the full multi-crop pipeline).
    max_crops: int = 12
    overlap_margin: int = 4
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass
class ModelConfig(BaseModelConfig):
    """Top-level Moondream configuration bundling text and vision configs."""

    text_config: Optional[TextConfig] = None
    vision_config: Optional[VisionConfig] = None
    model_type: str = "moondream1"
    # Inner (hidden) width of the vision-to-text projection MLP.
    proj_inner_dim: int = 8192
    # Placeholder id for image features, which are placed after the BOS token.
    image_token_index: int = -200
    vocab_size: int = 51200
    # Prefix attention length: BOS (1) + image patches (729) = 730.
    prefix_attn_len: int = 730
    # Moondream uses the same token id for EOS and BOS.
    eos_token_id: int = 0
    bos_token_id: int = 0

    def __post_init__(self):
        # Fill in default sub-configs, and also accept raw dicts
        # (e.g. ``ModelConfig(**json_config)``) in place of config objects.
        if self.text_config is None:
            self.text_config = TextConfig()
        elif isinstance(self.text_config, dict):
            self.text_config = TextConfig.from_dict(self.text_config)
        if self.vision_config is None:
            self.vision_config = VisionConfig()
        elif isinstance(self.vision_config, dict):
            self.vision_config = VisionConfig.from_dict(self.vision_config)

    @classmethod
    def from_dict(cls, params):
        """Build a ``ModelConfig`` from a (HuggingFace-style) config dict.

        Missing or ``None`` ``text_config``/``vision_config`` entries are
        treated as empty. If ``text_config`` is empty, the text fields are
        read from the root level of ``params`` instead. Root-level keys that
        are not ``ModelConfig`` fields are ignored.
        """
        # ``or {}`` also covers keys that are present but null.
        text_config_dict = params.get("text_config") or {}
        vision_config_dict = params.get("vision_config") or {}

        if not text_config_dict:
            text_fields = inspect.signature(TextConfig).parameters
            text_config_dict = {k: v for k, v in params.items() if k in text_fields}

        text_config = TextConfig.from_dict(text_config_dict)
        vision_config = VisionConfig.from_dict(vision_config_dict)

        # Compute the signature once instead of per key.
        model_fields = inspect.signature(cls).parameters
        return cls(
            text_config=text_config,
            vision_config=vision_config,
            **{
                k: v
                for k, v in params.items()
                if k in model_fields and k not in ("text_config", "vision_config")
            },
        )
|
|
@@ -0,0 +1,269 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Multi-crop image processing utilities for Moondream2.
|
|
3
|
+
|
|
4
|
+
Reference implementation: moondream2/image_crops.py
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import math
|
|
8
|
+
from typing import Tuple
|
|
9
|
+
|
|
10
|
+
import mlx.core as mx
|
|
11
|
+
import numpy as np
|
|
12
|
+
from PIL import Image
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def select_tiling(
    height: int, width: int, crop_size: int, max_crops: int
) -> Tuple[int, int]:
    """
    Choose an ``(h_tiles, w_tiles)`` grid of tiles to cover an image.

    Ported from HF reference: moondream2/image_crops.py:17-50

    Args:
        height: Image height in pixels (may be <= 0 once margins are removed).
        width: Image width in pixels (may be <= 0 once margins are removed).
        crop_size: Pixel extent a single tile covers along each side.
        max_crops: Tile budget that the grid aims to stay within.

    Returns:
        ``(h_tiles, w_tiles)``, each at least 1.
    """
    # An image whose shorter side fits in one tile is not split at all.
    if min(height, width) <= crop_size:
        return (1, 1)

    need_h = math.ceil(height / crop_size)
    need_w = math.ceil(width / crop_size)
    needed = need_h * need_w

    if needed > max_crops:
        # Covering the image would exceed the budget: scale both sides down
        # by a common factor so the product fits.
        shrink = math.sqrt(max_crops / needed)
        return (
            max(1, math.floor(need_h * shrink)),
            max(1, math.floor(need_w * shrink)),
        )

    # Aspect-ratio-matched grid of about max_crops tiles, but never fewer
    # tiles than needed to cover the image.
    h_tiles = max(math.floor(math.sqrt(max_crops * height / width)), need_h)
    w_tiles = max(math.floor(math.sqrt(max_crops * width / height)), need_w)

    if h_tiles * w_tiles > max_crops:
        # Over budget: shrink the larger side (height on ties) to fit.
        if w_tiles > h_tiles:
            w_tiles = math.floor(max_crops / h_tiles)
        else:
            h_tiles = math.floor(max_crops / w_tiles)

    return (max(1, h_tiles), max(1, w_tiles))
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def overlap_crop_image(
    image: np.ndarray,
    max_crops: int = 12,
    overlap_margin: int = 4,
    base_size: Tuple[int, int] = (378, 378),
    patch_size: int = 14,
) -> Tuple[np.ndarray, Tuple[int, int]]:
    """
    Split an image into one global crop plus a grid of overlapping local crops.

    Ported from HF reference: moondream2/image_crops.py:58-167

    Args:
        image: Image array indexed as [H, W, C] with values in [0, 255];
            it is cast to uint8 before resizing.
        max_crops: Budget of local crops passed to ``select_tiling`` (default 12)
        overlap_margin: Overlap between neighbouring crops, in patches (default 4)
        base_size: (height, width) of every output crop (default (378, 378))
        patch_size: Vision-encoder patch size in pixels (default 14)

    Returns:
        crops: uint8 array [1 + h_tiles * w_tiles, base_h, base_w, C]; index 0
            is the whole image resized to ``base_size``, followed by the local
            crops in row-major order
        tiling: (h_tiles, w_tiles) layout of the local crops
    """
    src_h, src_w = image.shape[:2]

    # Margin width in pixels, and its total over both sides.
    margin_px = patch_size * overlap_margin
    both_margins_px = 2 * margin_px

    # Stride between neighbouring crops: a crop's patch count minus the
    # margin on each side, converted back to pixels.
    window_px = (base_size[0] // patch_size - 2 * overlap_margin) * patch_size

    grid = select_tiling(
        src_h - both_margins_px,
        src_w - both_margins_px,
        window_px,
        max_crops,
    )
    rows, cols = grid

    out = np.zeros(
        (rows * cols + 1, base_size[0], base_size[1], image.shape[2]), dtype=np.uint8
    )

    # Size the image is resized to, so the crop windows tile it exactly.
    resized_h = rows * window_px + both_margins_px
    resized_w = cols * window_px + both_margins_px

    source = Image.fromarray(image.astype(np.uint8))

    # PIL expects (width, height).
    tiled = np.asarray(
        source.resize(
            (int(resized_w), int(resized_h)),
            resample=Image.Resampling.LANCZOS,
        )
    )

    # Global view: the whole original image squeezed into one crop.
    out[0] = np.asarray(
        source.resize(
            (int(base_size[1]), int(base_size[0])),
            resample=Image.Resampling.LANCZOS,
        )
    )

    for r in range(rows):
        top = r * window_px
        bottom = min(top + base_size[0], tiled.shape[0])
        for c in range(cols):
            left = c * window_px
            right = min(left + base_size[1], tiled.shape[1])
            piece = tiled[top:bottom, left:right]
            out[1 + r * cols + c, : piece.shape[0], : piece.shape[1]] = piece

    return out, grid
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def reconstruct_from_crops(
    local_features: mx.array,
    tiling: Tuple[int, int],
    overlap_margin: int = 4,
) -> mx.array:
    """
    Reconstruct a unified feature map from local crop features.

    Crops are laid out row-major over a ``(h_tiles, w_tiles)`` grid and
    adjacent crops overlap. Each crop keeps its outer margin only on the sides
    that lie on the border of the grid; ``overlap_margin`` patches are trimmed
    from every interior edge, so each output position comes from exactly one
    crop.

    Args:
        local_features: [n_local, P, P, D] features from local crops, ordered
            row-major over the tiling (e.g. P = 27 patches per side for a
            378px crop with 14px patches, D = 1152).
        tiling: (h_tiles, w_tiles) describing the crop layout.
        overlap_margin: Number of patches trimmed from each interior edge of
            a crop (default 4).

    Returns:
        Reconstructed feature map [H, W, D], same dtype as ``local_features``:
            H = h_tiles * (P - 2*overlap_margin) + 2*overlap_margin
            W = w_tiles * (P - 2*overlap_margin) + 2*overlap_margin

    Raises:
        ValueError: If ``tiling`` contains a non-positive tile count.
    """
    h_tiles, w_tiles = (int(t) for t in tiling)
    if h_tiles <= 0 or w_tiles <= 0:
        raise ValueError(f"tiling must contain positive tile counts, got {tiling}")

    if not isinstance(local_features, mx.array):
        local_features = mx.array(local_features)

    # Patch-grid size of each crop, read from the features instead of being
    # hard-coded, so other crop / patch sizes work as well.
    patches_h, patches_w = local_features.shape[1], local_features.shape[2]

    rows = []
    for i in range(h_tiles):
        # Keep the top margin only on the first tile row and the bottom margin
        # only on the last one; interior edges overlap a neighbouring crop.
        start_y = 0 if i == 0 else overlap_margin
        end_y = patches_h if i == h_tiles - 1 else patches_h - overlap_margin

        row_pieces = []
        for j in range(w_tiles):
            start_x = 0 if j == 0 else overlap_margin
            end_x = patches_w if j == w_tiles - 1 else patches_w - overlap_margin
            crop = local_features[i * w_tiles + j]
            row_pieces.append(crop[start_y:end_y, start_x:end_x])
        rows.append(mx.concatenate(row_pieces, axis=1))

    # Stitch entirely in MLX: converting to NumPy fails for bfloat16 features
    # and forces an unnecessary host round-trip.
    return mx.concatenate(rows, axis=0)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def adaptive_avg_pool2d(
    x: mx.array,
    output_size: Tuple[int, int],
) -> mx.array:
    """
    Adaptive average pooling that pools input to a fixed output size.

    Uses the same window bounds as PyTorch's ``adaptive_avg_pool2d``: output
    cell ``(i, j)`` averages input rows ``[floor(i*H/out_h),
    ceil((i+1)*H/out_h))`` and the analogous columns. The windows therefore
    always cover the whole input, also when the input size is not a multiple
    of, or is smaller than, the output size.

    Args:
        x: Channel-last input tensor [H, W, C].
        output_size: Target (H_out, W_out).

    Returns:
        Pooled float32 tensor [H_out, W_out, C]. If the input already has the
        target spatial size it is returned unchanged (original dtype).

    Raises:
        ValueError: If either output dimension is not positive.
    """
    H, W, _ = x.shape  # channel-last input is required
    out_h, out_w = output_size

    if H == out_h and W == out_w:
        return x
    if out_h <= 0 or out_w <= 0:
        raise ValueError(f"output_size must be positive, got {output_size}")

    def _windows(in_size, out_size):
        # Half-open [floor(k*in/out), ceil((k+1)*in/out)) for each output k.
        return [
            ((k * in_size) // out_size, -(-((k + 1) * in_size) // out_size))
            for k in range(out_size)
        ]

    # NumPy has no bfloat16, so cast to float32 before leaving MLX.
    if isinstance(x, mx.array):
        x_np = np.array(x.astype(mx.float32))
    else:
        x_np = np.asarray(x, dtype=np.float32)

    # The mean over a rectangle equals the mean over its rows of the per-row
    # means, so pool one spatial axis at a time.
    pooled_rows = np.stack(
        [x_np[s:e].mean(axis=0) for s, e in _windows(H, out_h)]
    )  # [out_h, W, C]
    pooled = np.stack(
        [pooled_rows[:, s:e].mean(axis=1) for s, e in _windows(W, out_w)],
        axis=1,
    )  # [out_h, out_w, C]

    return mx.array(pooled.astype(np.float32, copy=False))
|