fount_vlm_nell_02-0.3.11-py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/deepseekocr_2/vision.py
@@ -0,0 +1,439 @@
from typing import Optional, Tuple

import mlx.core as mx
import mlx.nn as nn

from .config import Qwen2EncoderConfig, VisionConfig


def check_array_shape(arr):
    shape = arr.shape

    # Check if the shape has 4 dimensions
    if len(shape) != 4:
        return False

    out_channels, kH, KW, _ = shape

    # Check if out_channels is the largest, and kH and KW are the same
    if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
        return True
    else:
        return False


class Qwen2RMSNorm(nn.Module):
    """RMSNorm for Qwen2 encoder."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = mx.ones((hidden_size,))
        self.variance_epsilon = eps

    def __call__(self, hidden_states: mx.array) -> mx.array:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.astype(mx.float32)
        variance = mx.mean(hidden_states**2, axis=-1, keepdims=True)
        hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.astype(input_dtype)


class Qwen2RotaryEmbedding(nn.Module):
    """Rotary position embeddings for Qwen2."""

    def __init__(
        self, dim: int, max_position_embeddings: int = 2048, base: float = 1000000.0
    ):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Note: inv_freq is computed on-the-fly, not stored as a parameter

    def __call__(
        self, x: mx.array, position_ids: mx.array
    ) -> Tuple[mx.array, mx.array]:
        # Compute inv_freq on the fly (not stored as parameter)
        inv_freq = 1.0 / (
            self.base ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim)
        )

        # position_ids: [batch_size, seq_len]
        # inv_freq: [head_dim // 2]
        # We want freqs of shape [batch_size, seq_len, head_dim // 2]

        # Outer product: position_ids[:, :, None] * inv_freq[None, None, :]
        position_ids_float = position_ids[:, :, None].astype(mx.float32)  # [B, S, 1]
        inv_freq_expanded = inv_freq[None, None, :]  # [1, 1, D//2]
        freqs = position_ids_float * inv_freq_expanded  # [B, S, D//2]

        emb = mx.concatenate([freqs, freqs], axis=-1)  # [B, S, D]
        cos = mx.cos(emb)
        sin = mx.sin(emb)
        return cos.astype(x.dtype), sin.astype(x.dtype)


def rotate_half(x: mx.array) -> mx.array:
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return mx.concatenate([-x2, x1], axis=-1)


def apply_rotary_pos_emb(
    q: mx.array, k: mx.array, cos: mx.array, sin: mx.array
) -> Tuple[mx.array, mx.array]:
    """Apply rotary position embedding to query and key tensors."""
    cos = cos[:, None, :, :]
    sin = sin[:, None, :, :]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class Qwen2MLP(nn.Module):
    """MLP for Qwen2 encoder."""

    def __init__(self, config: Qwen2EncoderConfig):
        super().__init__()
        self.hidden_size = config.dim
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class Qwen2Attention(nn.Module):
    """Multi-head attention for Qwen2 encoder with GQA support."""

    def __init__(self, config: Qwen2EncoderConfig, layer_idx: int = 0):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.hidden_size = config.dim
        self.num_heads = config.heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.kv_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=True
        )
        self.k_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
        )
        self.v_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=False
        )

        self.rotary_emb = Qwen2RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=2048,
            base=config.rope_theta,
        )
        self.scale = self.head_dim**-0.5

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[mx.array] = None,
        position_ids: Optional[mx.array] = None,
    ) -> mx.array:
        bsz, q_len, _ = hidden_states.shape

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.reshape(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(0, 2, 1, 3)
        key_states = key_states.reshape(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(0, 2, 1, 3)
        value_states = value_states.reshape(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(0, 2, 1, 3)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        # Repeat KV heads for GQA
        if self.num_key_value_groups > 1:
            key_states = mx.repeat(key_states, self.num_key_value_groups, axis=1)
            value_states = mx.repeat(value_states, self.num_key_value_groups, axis=1)

        attn_output = mx.fast.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            scale=self.scale,
            mask=attention_mask,
        )

        attn_output = attn_output.transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        return attn_output


class Qwen2DecoderLayer(nn.Module):
    """Transformer layer for Qwen2 encoder."""

    def __init__(self, config: Qwen2EncoderConfig, layer_idx: int = 0):
        super().__init__()
        self.hidden_size = config.dim
        self.self_attn = Qwen2Attention(config, layer_idx)
        self.mlp = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(config.dim, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(
            config.dim, eps=config.rms_norm_eps
        )

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[mx.array] = None,
        position_ids: Optional[mx.array] = None,
    ) -> mx.array:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class Qwen2Decoder2Encoder(nn.Module):
    """Qwen2-based decoder used as encoder for vision features.

    Takes SAM features and processes them through Qwen2 transformer layers
    using learnable queries to produce fixed-size output.
    """

    def __init__(self, config: Qwen2EncoderConfig):
        super().__init__()
        self.config = config

        # Learnable queries for cross-attention
        # query_1024: (256, dim) - for 1024x1024 images (SAM outputs 16x16=256 features)
        # query_768: (144, dim) - for 768x768 images (SAM outputs 12x12=144 features)
        # Initialized with zeros, will be loaded from weights
        self.query_1024 = mx.zeros((256, config.dim))
        self.query_768 = mx.zeros((144, config.dim))

        # Transformer layers
        self.layers = [
            Qwen2DecoderLayer(config, layer_idx=i) for i in range(config.layers)
        ]

        # Final layer norm
        self.norm = Qwen2RMSNorm(config.dim, eps=config.rms_norm_eps)

    def __call__(self, sam_features: mx.array) -> mx.array:
        """Process SAM features through Qwen2 encoder.

        Args:
            sam_features: SAM encoder output of shape (B, H, W, C) where C=896

        Returns:
            Encoded features of shape (B, 256, C)
        """
        batch_size = sam_features.shape[0]

        # Flatten spatial dimensions: (B, H, W, C) -> (B, H*W, C)
        sam_features_flat = sam_features.reshape(batch_size, -1, self.config.dim)
        num_image_tokens = sam_features_flat.shape[1]

        # Select appropriate query based on number of image tokens
        # 256 tokens -> use query_1024 (for 1024x1024 images, SAM outputs 16x16)
        # 144 tokens -> use query_768 (for 768x768 images, SAM outputs 12x12)
        if num_image_tokens == 256:
            query_embed = self.query_1024
            num_queries = 256
        elif num_image_tokens == 144:
            query_embed = self.query_768
            num_queries = 144
        else:
            # Default to query_1024 for unexpected sizes
            query_embed = self.query_1024
            num_queries = 256

        # Expand queries for batch: (num_queries, C) -> (B, num_queries, C)
        queries = mx.broadcast_to(
            query_embed[None, :, :], (batch_size, num_queries, self.config.dim)
        )

        # Concatenate: image tokens + query tokens
        # Shape: (B, num_image_tokens + num_queries, C)
        hidden_states = mx.concatenate([sam_features_flat, queries], axis=1)
        seq_len = hidden_states.shape[1]

        # Create mixed attention mask:
        # - Image tokens can attend to all image tokens (bidirectional)
        # - Image tokens CANNOT attend to query tokens (blocked)
        # - Query tokens can attend to all image tokens
        # - Query tokens use causal attention within queries (can attend to self + previous)
        # Shape: (1, 1, seq_len, seq_len) - will be broadcast across batch and heads
        mask_dtype = hidden_states.dtype

        # Start with all positions blocked (large negative value)
        mask = mx.full((seq_len, seq_len), -1e9, dtype=mx.float32)

        # 1. Image tokens can attend to all image tokens (bidirectional)
        # mask[0:num_image_tokens, 0:num_image_tokens] = 0
        image_to_image = mx.zeros(
            (num_image_tokens, num_image_tokens), dtype=mx.float32
        )
        mask = mx.concatenate(
            [
                mx.concatenate(
                    [image_to_image, mask[:num_image_tokens, num_image_tokens:]], axis=1
                ),
                mask[num_image_tokens:, :],
            ],
            axis=0,
        )

        # 2. Query tokens can attend to all image tokens
        # mask[num_image_tokens:, 0:num_image_tokens] = 0
        query_to_image = mx.zeros((num_queries, num_image_tokens), dtype=mx.float32)
        mask = mx.concatenate(
            [
                mask[:num_image_tokens, :],
                mx.concatenate(
                    [query_to_image, mask[num_image_tokens:, num_image_tokens:]], axis=1
                ),
            ],
            axis=0,
        )

        # 3. Query tokens use causal attention (can attend to self + previous queries)
        # Create lower triangular mask for query-query region
        query_causal = mx.tril(mx.zeros((num_queries, num_queries), dtype=mx.float32))
        query_causal = query_causal + mx.triu(
            mx.full((num_queries, num_queries), -1e9, dtype=mx.float32), k=1
        )

        # Update query-query region in mask
        mask = mx.concatenate(
            [
                mask[:, :num_image_tokens],
                mx.concatenate(
                    [mask[:num_image_tokens, num_image_tokens:], query_causal], axis=0
                ),
            ],
            axis=1,
        )

        # Cast to input dtype and reshape for attention: (1, 1, seq_len, seq_len)
        attention_mask = mask.astype(mask_dtype)[None, None, :, :]

        # Create position IDs
        position_ids = mx.broadcast_to(
            mx.arange(seq_len)[None, :], (batch_size, seq_len)
        )

        # Process through transformer layers
        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )

        # Apply final layer norm
        hidden_states = self.norm(hidden_states)

        # Return only the query tokens (last num_queries tokens)
        return hidden_states[:, -num_queries:, :]


class VisionModel(nn.Module):
    """Vision model for DeepSeek-OCR-2 using Qwen2 encoder."""

    def __init__(self, config: VisionConfig):
        super().__init__()
        self.model_type = config.model_type
        self.config = config

        if self.model_type != "vision":
            raise ValueError(f"Unsupported model type: {self.model_type}")

        # Get Qwen2 config from params
        qwen2_params = config.params.get("qwen2", {})
        qwen2_config = Qwen2EncoderConfig(
            dim=qwen2_params.get("dim", 896),
            layers=qwen2_params.get("layers", 24),
            heads=qwen2_params.get("heads", 14),
            kv_heads=qwen2_params.get("kv_heads", 2),
            intermediate_size=qwen2_params.get("intermediate_size", 4864),
            rms_norm_eps=qwen2_params.get("rms_norm_eps", 1e-6),
            rope_theta=qwen2_params.get("rope_theta", 1000000.0),
        )

        self.qwen2_encoder = Qwen2Decoder2Encoder(qwen2_config)

    def __call__(self, x: mx.array, sam_features: mx.array) -> mx.array:
        """Process vision input through Qwen2 encoder.

        Args:
            x: Original image tensor (not used, kept for API compatibility)
            sam_features: SAM encoder output of shape (B, H, W, C)

        Returns:
            Encoded features of shape (B, 256, C)
        """
        return self.qwen2_encoder(sam_features)

    def sanitize(self, weights):
        sanitized_weights = {}
        weight_keys = {
            "neck.0.weight",
            "neck.2.weight",
            "neck_hd.0.weight",
            "neck_hd.2.weight",
            "sam_model.net_2.weight",
            "sam_model.net_3.weight",
            "downsamples.0.weight",
            "downsamples.1.weight",
            "patch_embed.proj.weight",
            "embeddings.patch_embedding.weight",
        }
        for k, v in weights.items():
            if "position_ids" in k:
                # Remove unused position_ids
                continue

            elif ".".join(k.split(".")[-3:]) in weight_keys:
                # PyTorch conv2d weight tensors have shape:
                # [out_channels, in_channels, kH, KW]
                # MLX conv2d expects the weight to be of shape:
                # [out_channels, kH, KW, in_channels]
                if check_array_shape(v):
                    sanitized_weights[k] = v
                else:
                    sanitized_weights[k] = v.transpose(0, 2, 3, 1)

            else:
                sanitized_weights[k] = v

        return sanitized_weights
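The most involved piece of this new file is the mixed attention mask that Qwen2Decoder2Encoder.__call__ assembles out of repeated mx.concatenate slices. The sketch below is not part of the package; it is a toy reproduction of the same four-block pattern, assuming only that mlx is installed (toy_mixed_mask is a hypothetical helper name): image tokens attend bidirectionally to other image tokens, image tokens are blocked from query tokens, query tokens see all image tokens, and query tokens are causal among themselves.

import mlx.core as mx


def toy_mixed_mask(num_image_tokens: int, num_queries: int) -> mx.array:
    # Additive mask: 0 = attend, -1e9 = blocked (same convention as the code above).
    I, Q = num_image_tokens, num_queries
    neg = -1e9
    # Image rows: attend to every image token, blocked from all query tokens.
    top = mx.concatenate([mx.zeros((I, I)), mx.full((I, Q), neg)], axis=1)
    # Query rows: attend to every image token, causal over the query block.
    causal = mx.triu(mx.full((Q, Q), neg), k=1)  # 0 on and below the diagonal
    bottom = mx.concatenate([mx.zeros((Q, I)), causal], axis=1)
    return mx.concatenate([top, bottom], axis=0)


# (num_image_tokens, num_queries) = (4, 2) prints a 6x6 mask showing all four blocks.
print(toy_mixed_mask(4, 2))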
mlx_vlm/models/ernie4_5_moe_vl/__init__.py
@@ -0,0 +1,5 @@
from .config import ModelConfig, TextConfig, VisionConfig
from .ernie4_5_moe_vl import Model, VariableResolutionResamplerModel
from .language import LanguageModel
from .processor import Ernie4_5_VLProcessor, Ernie4_5_VLTokenizer, ImageProcessor
from .vision import VisionModel
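As a hedged usage note (not part of the diff): because this __init__ re-exports the public classes, downstream code can import the Ernie 4.5 VL components from the subpackage root once the wheel is installed, for example:

# Hypothetical import, assuming the wheel is installed as the mlx_vlm top-level package.
from mlx_vlm.models.ernie4_5_moe_vl import (
    Ernie4_5_VLProcessor,
    LanguageModel,
    Model,
    ModelConfig,
    VisionModel,
)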
mlx_vlm/models/ernie4_5_moe_vl/config.py
@@ -0,0 +1,139 @@
import inspect
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union

from ..base import BaseModelConfig


@dataclass
class VisionConfig(BaseModelConfig):
    """DFNRopeVisionTransformer configuration."""

    model_type: str = "DFNRope_vision_transformer"
    depth: int = 32
    embed_dim: int = 1280
    hidden_size: int = 3584  # This should match embed_dim for DFNRope
    hidden_act: str = "quick_gelu"
    mlp_ratio: float = 4.0
    num_heads: int = 16
    in_channels: int = 3
    patch_size: int = 14
    spatial_merge_size: int = 2
    layer_norm_eps: float = 1e-6

    def __post_init__(self):
        # hidden_size should equal embed_dim for this architecture
        if self.hidden_size != self.embed_dim:
            self.hidden_size = self.embed_dim


@dataclass
class TextConfig(BaseModelConfig):
    hidden_size: int = 3584
    intermediate_size: int = 18944
    model_type: str = "ernie"
    max_position_embeddings: int = 131072
    num_attention_heads: int = 28
    num_key_value_heads: int = 4
    num_hidden_layers: int = 56
    rms_norm_eps: float = 1e-6
    vocab_size: int = 151936
    rope_theta: float = 1000000.0
    use_bias: bool = False
    tie_word_embeddings: bool = False
    compression_ratio: float = 1.0
    # MoE config
    moe_num_experts: Union[int, List[int]] = 128
    moe_layer_start_index: Union[int, List[int]] = 3
    moe_layer_end_index: Optional[Union[int, List[int]]] = 53
    moe_intermediate_size: Union[int, List[int]] = 1408
    moe_capacity: List[float] = field(default_factory=lambda: [1.2, 2.0, 2.0])
    moe_k: int = 2
    moe_layer_interval: int = 1
    moe_use_aux_free: bool = True
    moe_num_shared_experts: int = 0
    moe_gate_act: str = "softmax"
    moe_norm_gate_logits: bool = True
    head_dim: Optional[int] = None
    # 3D RoPE config
    rope_3d: bool = True
    freq_allocation: int = 20
    mrope_section: List[int] = field(default_factory=lambda: [22, 22, 20])
    rope_scaling: Optional[Dict[str, Union[str, List[int]]]] = None
    rope_parameters: Optional[Dict[str, Union[str, float, List[int]]]] = None
    moe_norm_min: float = 1e-12

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads
        if self.head_dim is None:
            self.head_dim = self.hidden_size // self.num_attention_heads
        # Normalize rope_scaling keys
        if self.rope_scaling:
            if "type" not in self.rope_scaling and "rope_type" in self.rope_scaling:
                self.rope_scaling["type"] = self.rope_scaling.pop("rope_type")
            # Extract mrope_section from rope_scaling if present
            if "mrope_section" in self.rope_scaling:
                self.mrope_section = list(self.rope_scaling["mrope_section"])
        # Also check rope_parameters (HuggingFace format)
        if self.rope_parameters:
            if "mrope_section" in self.rope_parameters:
                self.mrope_section = list(self.rope_parameters["mrope_section"])


@dataclass
class ModelConfig(BaseModelConfig):
    text_config: TextConfig = None
    vision_config: VisionConfig = None
    model_type: str = "ernie4_5_moe_vl"
    ignore_index: int = -100
    # Token IDs (defaults will be overridden by from_dict / __post_init__)
    im_patch_id: int = 100295
    image_token_id: int = 100295
    image_start_token_id: int = 101304
    image_end_token_id: int = 101305
    video_token_id: int = 100295
    video_start_token_id: int = 101306
    video_end_token_id: int = 101307
    vision_start_token_id: int = 101304
    vision_end_token_id: int = 101305
    vision_token_id: int = 100295
    vocab_size: int = 103424
    eos_token_id: Optional[List[int]] = None
    # Vision-language integration
    pixel_hidden_size: int = 1280
    hidden_size: int = 2560
    # Resampler config
    spatial_conv_size: int = 2
    temporal_conv_size: int = 2
    use_temporal_conv: bool = True
    # 3D RoPE config
    rope_3d: bool = True
    freq_allocation: int = 20

    def __post_init__(self):
        # Derive image_token_id from im_patch_id if not explicitly set differently
        if self.image_token_id != self.im_patch_id:
            self.image_token_id = self.im_patch_id
        # vision_start/end should match image_start/end
        if self.vision_start_token_id != self.image_start_token_id:
            self.vision_start_token_id = self.image_start_token_id
        if self.vision_end_token_id != self.image_end_token_id:
            self.vision_end_token_id = self.image_end_token_id

    @classmethod
    def from_dict(cls, params):
        # Copy text config parameters from root level (like qwen2_vl does)
        # This ensures update_module_configs works correctly
        excluded_keys = {"vision_config"}
        params["text_config"] = dict(
            filter(lambda x: x[0] not in excluded_keys, params.items())
        )

        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )
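To make the from_dict behaviour concrete, here is a minimal sketch with hypothetical values (not part of the package): the text-model parameters live at the root of the checkpoint config, so from_dict mirrors every root-level key except vision_config into params["text_config"] and then keeps only the fields that ModelConfig itself declares.

# Hypothetical config dict; real checkpoints carry many more keys.
params = {
    "model_type": "ernie4_5_moe_vl",
    "hidden_size": 2560,
    "num_attention_heads": 28,
    "vocab_size": 103424,
    "vision_config": {"depth": 32, "embed_dim": 1280},
}

config = ModelConfig.from_dict(params)
# After the call, params["text_config"] holds every root-level key except
# "vision_config" (still a plain dict, to be expanded into TextConfig later),
# while keys that ModelConfig does not declare (num_attention_heads here)
# are dropped by the inspect.signature filter.
assert "vision_config" not in params["text_config"]
assert config.hidden_size == 2560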