fount-vlm-nell-02 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
|
@@ -0,0 +1,208 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import mlx.core as mx
|
|
4
|
+
import mlx.nn as nn
|
|
5
|
+
|
|
6
|
+
from ..base import InputEmbeddingsFeatures
|
|
7
|
+
from .config import ModelConfig
|
|
8
|
+
from .language import LanguageModel
|
|
9
|
+
from .vision import VisionModel
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP mapping vision hidden states into the text hidden space.

    Layout: ``linear_1`` (vision width -> text width), GELU, then
    ``linear_2`` (text width -> text width). Both linear layers use a bias.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        vision_width = config.vision_config.hidden_size
        text_width = config.text_config.hidden_size
        self.linear_1 = nn.Linear(vision_width, text_width, bias=True)
        self.gelu = nn.GELU()
        self.linear_2 = nn.Linear(text_width, text_width, bias=True)

    def __call__(self, x: mx.array) -> mx.array:
        return self.linear_2(self.gelu(self.linear_1(x)))
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class Model(nn.Module):
    """Pixtral vision-language model.

    Images go through ``vision_tower`` and ``multi_modal_projector``. The
    projected features replace the embeddings at image-token positions, and
    the resulting sequence is fed to ``language_model``.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.vision_tower = VisionModel(config.vision_config)
        self.language_model = LanguageModel(config.text_config)
        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.vision_feature_layer = config.vision_feature_layer
        self.vision_feature_select_strategy = config.vision_feature_select_strategy

    def get_input_embeddings(
        self,
        input_ids: Optional[mx.array] = None,
        pixel_values: Optional[mx.array] = None,
        **kwargs,
    ):
        """Embed ``input_ids`` and splice in image features when images are given.

        Args:
            input_ids: Token ids, indexed as ``(batch, seq_len)`` by the merge step.
            pixel_values: Image batch (assumed ``(N, C, H, W)`` -- it is
                transposed with ``(0, 2, 3, 1)`` before the vision tower), or a
                list of per-image arrays that are stacked along a new axis 0.
                A 3-D array is treated as a single image.
            **kwargs: Ignored.

        Returns:
            ``InputEmbeddingsFeatures`` whose ``inputs_embeds`` has image
            features written at every ``config.image_token_index`` position.
        """
        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
        if pixel_values is None:
            return InputEmbeddingsFeatures(inputs_embeds=inputs_embeds)

        if isinstance(pixel_values, list):
            pixel_values = mx.concatenate(
                [mx.array(pv)[None, ...] for pv in pixel_values], axis=0
            )
        if pixel_values.ndim == 3:
            pixel_values = pixel_values[None, ...]

        # Each image is individually run through conv2d and position encoding.
        # Reference code from transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/models/pixtral/modeling_pixtral.py#L479C9-L479C21
        # and mistral_inference: https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/vision_encoder.py#L85
        *_, hidden_states = self.vision_tower(
            pixel_values.transpose(0, 2, 3, 1),
            output_hidden_states=True,
        )
        # Pick the configured layer from the tuple of collected hidden states.
        selected_image_feature = hidden_states[self.vision_feature_layer]

        image_features = self.multi_modal_projector(selected_image_feature)

        final_inputs_embeds = self.merge_input_ids_with_image_features(
            self.config.image_token_index, image_features, inputs_embeds, input_ids
        )
        return InputEmbeddingsFeatures(inputs_embeds=final_inputs_embeds)

    @staticmethod
    def merge_input_ids_with_image_features(
        image_token_index, image_features, inputs_embeds, input_ids
    ):
        """Merge image features into input embeddings at image token positions.

        Features are consumed in order across the batch. Batch item ``b``
        takes as many consecutive feature rows as it has image tokens.

        Args:
            image_token_index: Token ID for the image placeholder.
            image_features: Projected vision features, ``[num_features, hidden_dim]``
                or ``[1, num_features, hidden_dim]``.
            inputs_embeds: Input embeddings ``[batch_size, seq_len, hidden_dim]``.
            input_ids: Input token IDs ``[batch_size, seq_len]``.

        Returns:
            Embeddings of the same shape as ``inputs_embeds`` with image
            features inserted at image-token positions.

        Raises:
            ValueError: If a batch item has more image tokens than remaining features.
        """
        # Remove the extra batch dimension from image_features if present.
        if image_features.ndim == 3 and image_features.shape[0] == 1:
            image_features = image_features.squeeze(0)  # [num_features, hidden_dim]

        image_positions = input_ids == image_token_index
        batch_size, _ = input_ids.shape

        batch_outputs = []
        feature_start_idx = 0

        for batch_idx in range(batch_size):
            image_mask = image_positions[batch_idx]
            num_positions = mx.sum(image_mask).item()

            if num_positions > 0:
                batch_features = image_features[
                    feature_start_idx : feature_start_idx + num_positions
                ]

                if batch_features.shape[0] != num_positions:
                    raise ValueError(
                        f"Number of image token positions ({num_positions}) does not match "
                        f"number of image features ({batch_features.shape[0]}) for batch {batch_idx}"
                    )

                # The k-th image token (0-based) receives feature row k. Non-image
                # positions gather row 0, but mx.where discards those rows.
                cumsum = mx.cumsum(image_mask.astype(mx.int32))
                feature_indices = mx.where(image_mask, cumsum - 1, 0)
                gathered_features = batch_features[feature_indices]

                image_mask_expanded = mx.expand_dims(image_mask, axis=-1)
                batch_output = mx.where(
                    image_mask_expanded, gathered_features, inputs_embeds[batch_idx]
                )

                feature_start_idx += num_positions
            else:
                # No image tokens in this batch item.
                batch_output = inputs_embeds[batch_idx]

            batch_outputs.append(batch_output)

        return mx.stack(batch_outputs, axis=0)

    @property
    def layers(self):
        return self.language_model.model.layers

    def __call__(
        self,
        input_ids: mx.array,
        pixel_values: mx.array,
        mask: mx.array,
        cache=None,
        **kwargs,
    ):
        """Compute language-model output for text plus optional images.

        ``mask`` is accepted for interface compatibility but is not used here.
        """
        input_embeddings_features = self.get_input_embeddings(
            input_ids, pixel_values, **kwargs
        )
        logits = self.language_model(
            input_ids,
            cache=cache,
            inputs_embeds=input_embeddings_features.inputs_embeds,
        )
        return logits

    def sanitize(self, weights):
        """Rename checkpoint keys to this module's parameter layout.

        Three key families are handled:

        - Keys containing ``vision_tower`` but not ``vision_model``, or keys
          containing ``model.vision_encoder``: if they belong to a
          transformer / patch_conv / ln_pre sub-module, they are moved under
          ``vision_tower.vision_model``.
        - ``model.language_model`` becomes ``language_model.model``, and a
          bare ``lm_head`` becomes ``language_model.lm_head``.
        - ``model.vision_projection`` becomes ``multi_modal_projector``.
        """
        # Sub-modules that live under ``vision_tower.vision_model`` in this model.
        vision_submodules = ("transformer", "patch_conv", "ln_pre")

        def transform_key(key):
            if "vision_tower" in key and "vision_model" not in key:
                # Replace at most once. The replacement keeps "vision_tower" in
                # the key, so repeating it would nest "vision_model" twice.
                if any(name in key for name in vision_submodules):
                    key = key.replace("vision_tower", "vision_tower.vision_model")

            elif "vision_encoder" in key and "vision_tower" not in key:
                if any(name in key for name in vision_submodules):
                    key = key.replace(
                        "model.vision_encoder", "vision_tower.vision_model"
                    )

            elif "model.language_model" in key and "language_model.model" not in key:
                key = key.replace("model.language_model", "language_model.model")

            elif "lm_head" in key and "language_model" not in key:
                key = key.replace("lm_head", "language_model.lm_head")

            elif "model.vision_projection" in key:
                key = key.replace("model.vision_projection", "multi_modal_projector")

            return key

        return {transform_key(k): v for k, v in weights.items()}
|
|
@@ -0,0 +1,293 @@
|
|
|
1
|
+
from typing import List, Optional
|
|
2
|
+
|
|
3
|
+
import mlx.core as mx
|
|
4
|
+
import mlx.nn as nn
|
|
5
|
+
|
|
6
|
+
from .config import VisionConfig
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def check_array_shape(arr):
    """Return True if a 4-D conv weight already looks MLX-formatted.

    MLX stores conv2d weights as ``[out_channels, kH, kW, in_channels]``.
    The heuristic accepts a 4-D shape whose two kernel dims are equal and
    not larger than ``out_channels``. Anything else, including non-4-D
    input, returns False.
    """
    dims = arr.shape
    if len(dims) != 4:
        return False

    out_channels, kernel_h, kernel_w, _ = dims
    return kernel_h == kernel_w and out_channels >= max(kernel_h, kernel_w)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def position_ids_in_meshgrid(patch_embeds_list, max_width):
    """Return flattened 2-D position ids for every patch of every grid.

    For each entry, ``shape[0]`` is taken as the row count and ``shape[1]``
    as the column count. Cell ``(row, col)`` gets id
    ``row * max_width + col``, enumerated in row-major order. The ids for
    all entries are concatenated into one 1-D array.
    """
    per_grid_ids = []
    for grid in patch_embeds_list:
        rows, cols = mx.meshgrid(
            mx.arange(grid.shape[0]), mx.arange(grid.shape[1]), indexing="ij"
        )
        per_grid_ids.append((rows * max_width + cols).flatten())
    return mx.concatenate(per_grid_ids)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def generate_block_attention_mask(patch_embeds_list, tensor):
    """Build an additive block-diagonal attention mask for packed sequences.

    ``patch_embeds_list`` gives the token count of each consecutive block
    along axis 1 of ``tensor``. Entries inside a block's own square are 0.
    Every other entry is ``-1e9``, so tokens only attend within their block.

    Returns:
        Array of shape ``(tensor.shape[0], 1, seq_len, seq_len)`` with
        ``tensor``'s dtype.
    """
    seq_len = tensor.shape[1]
    d_min = -1e9  # Large negative value; MLX has no finfo equivalent.

    mask = mx.full((seq_len, seq_len), vals=d_min)

    # Open each block's diagonal square, advancing a running offset.
    start = 0
    for count in patch_embeds_list:
        stop = start + int(count)
        mask[start:stop, start:stop] = 0
        start = stop

    batch = tensor.shape[0]
    mask = mx.broadcast_to(mask[None, None], (batch, 1, seq_len, seq_len))
    return mask.astype(tensor.dtype)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def rotate_half(x):
    """Rotate the last axis by half: ``[a, b] -> [-b, a]``.

    ``a`` and ``b`` are the first and second halves of the last dimension.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return mx.concatenate((-second, first), axis=-1)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Apply rotary position embeddings to queries and keys.

    ``cos`` and ``sin`` get a new axis at ``unsqueeze_dim`` so they
    broadcast against ``q`` and ``k``. Each input ``t`` becomes
    ``t * cos + rotate_half(t) * sin``.

    Returns:
        ``(q_rotated, k_rotated)``.
    """
    cos = mx.expand_dims(cos, axis=unsqueeze_dim)
    sin = mx.expand_dims(sin, axis=unsqueeze_dim)

    def _rotate(t):
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class Attention(nn.Module):
    """Multi-head self-attention with rotary position embeddings.

    Args:
        dims: Model width. Must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        query_input_dims, key_input_dims, value_input_dims, value_dims,
        value_output_dims: Accepted for signature compatibility but unused.
            Every projection is ``dims -> dims``.
        bias: Accepted for signature compatibility but ignored. The q/k/v/o
            projections are always created with ``bias=False``.

    Raises:
        ValueError: If ``dims`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        dims: int,
        num_heads: int,
        query_input_dims: Optional[int] = None,
        key_input_dims: Optional[int] = None,
        value_input_dims: Optional[int] = None,
        value_dims: Optional[int] = None,
        value_output_dims: Optional[int] = None,
        bias: bool = False,
    ):
        super().__init__()

        if (dims % num_heads) != 0:
            raise ValueError(
                "The input feature dimensions should be divisible by the "
                f"number of heads ({dims} % {num_heads}) != 0"
            )

        self.embed_dim = dims
        self.num_heads = num_heads
        self.head_dim = self.embed_dim // self.num_heads

        self.scale = self.head_dim**-0.5

        # The projections are bias-free regardless of ``bias``. Forwarding it
        # would add parameters and change which weights this module loads.
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)

    def __call__(self, queries, keys, values, position_embeddings, mask=None):
        """Attend ``queries`` over ``keys``/``values``.

        Args:
            queries: ``(B, L, dims)`` inputs for the query projection.
            keys: ``(B, S, dims)`` inputs for the key projection.
            values: ``(B, S, dims)`` inputs for the value projection.
            position_embeddings: ``(cos, sin)`` pair. A leading axis is added
                (``unsqueeze_dim=0``) so they broadcast over batch and heads.
            mask: Optional mask forwarded to ``scaled_dot_product_attention``.

        Returns:
            ``(B, L, dims)`` output after ``o_proj``.
        """
        queries = self.q_proj(queries)
        keys = self.k_proj(keys)
        values = self.v_proj(values)

        B, L, _ = queries.shape
        S = keys.shape[1]

        # (B, seq, dims) -> (B, num_heads, seq, head_dim)
        queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, S, self.num_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, S, self.num_heads, -1).transpose(0, 2, 1, 3)

        cos, sin = position_embeddings
        queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin, unsqueeze_dim=0)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )

        # (B, num_heads, L, head_dim) -> (B, L, dims)
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

        return self.o_proj(output)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class MLP(nn.Module):
    """Gated SiLU feed-forward block: ``down(silu(gate(x)) * up(x))``.

    All three projections are bias-free and map between
    ``config.hidden_size`` and ``config.intermediate_size``.
    """

    def __init__(self, config: VisionConfig):
        super().__init__()
        width, inner = config.hidden_size, config.intermediate_size
        self.gate_proj = nn.Linear(width, inner, bias=False)
        self.down_proj = nn.Linear(inner, width, bias=False)
        self.up_proj = nn.Linear(width, inner, bias=False)

    def __call__(self, x) -> mx.array:
        gated = nn.silu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class EncoderLayer(nn.Module):
    """Pre-norm transformer block.

    The attention sub-layer is RMSNorm, then attention, then a residual add.
    The feed-forward sub-layer is RMSNorm, then MLP, then a residual add.
    """

    def __init__(self, config: VisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.attention = Attention(
            config.hidden_size, config.num_attention_heads, bias=True
        )
        eps = config.rms_norm_eps
        self.attention_norm = nn.RMSNorm(self.embed_dim, eps=eps)
        self.feed_forward = MLP(config)
        self.ffn_norm = nn.RMSNorm(self.embed_dim, eps=eps)

    def __call__(
        self,
        x: mx.array,
        position_embeddings: mx.array,
        mask: Optional[mx.array] = None,
    ) -> mx.array:
        normed = self.attention_norm(x)
        hidden = x + self.attention(normed, normed, normed, position_embeddings, mask)
        return hidden + self.feed_forward(self.ffn_norm(hidden))
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
class Encoder(nn.Module):
    """Container for ``config.num_hidden_layers`` encoder layers.

    It defines no ``__call__``; callers iterate ``layers`` directly.
    """

    def __init__(self, config: VisionConfig):
        super().__init__()
        stack = []
        for _ in range(config.num_hidden_layers):
            stack.append(EncoderLayer(config))
        self.layers = stack
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class PixtralRotaryEmbedding:
    """2-D rotary angle table for a square patch grid.

    Base frequencies at even indices are scaled by the patch row, and those
    at odd indices by the patch column. ``inv_freq`` has one row per cell of
    a ``(image_size // patch_size)``-sided grid in row-major order. Each
    row's angles are duplicated so the row spans ``head_dim``.

    This is a plain class, not an ``nn.Module``, so ``inv_freq`` is an
    ordinary attribute rather than a module parameter.
    """

    def __init__(self, config):
        self.dim = config.head_dim
        self.base = config.rope_theta
        side = config.image_size // config.patch_size

        exponents = mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim
        freqs = 1.0 / (self.base**exponents)

        positions = mx.arange(side)
        row_angles = mx.outer(positions, freqs[::2]).astype(mx.float32)
        col_angles = mx.outer(positions, freqs[1::2]).astype(mx.float32)

        # For cell (r, c): row_angles[r] followed by col_angles[c].
        grid = mx.concatenate(
            [
                mx.tile(row_angles[:, None, :], (1, side, 1)),
                mx.tile(col_angles[None, :, :], (side, 1, 1)),
            ],
            axis=-1,
        ).reshape(-1, self.dim // 2)

        self.inv_freq = mx.concatenate((grid, grid), axis=-1)

    def __call__(self, x, position_ids):
        """Return ``(cos, sin)`` of the angles at ``position_ids``, cast to ``x.dtype``."""
        angles = self.inv_freq[position_ids]
        return mx.cos(angles).astype(x.dtype), mx.sin(angles).astype(x.dtype)
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
class PixtralVisionModel(nn.Module):
    """Pixtral vision encoder: patch conv, then RMSNorm, then a rotary transformer.

    All images in the input are packed into a single sequence with batch
    size 1. A block-diagonal mask keeps attention within each image.
    """

    def __init__(self, config: VisionConfig):
        super().__init__()
        self.config = config
        self.patch_conv = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=config.hidden_size,
            kernel_size=config.patch_size,
            stride=config.patch_size,
            bias=False,
        )
        self.ln_pre = nn.RMSNorm(config.hidden_size)
        self.transformer = Encoder(config)
        self.patch_positional_embedding = PixtralRotaryEmbedding(config)

    def __call__(
        self,
        x: mx.array,
        output_hidden_states: Optional[bool] = None,
    ) -> tuple[mx.array, Optional[tuple[mx.array, ...]]]:
        """Encode a batch of images.

        Args:
            x: Channels-last image batch ``(N, H, W, C)``, as MLX's ``Conv2d``
                expects. It is cast to the conv weight dtype if needed.
            output_hidden_states: If truthy, also collect the hidden state
                after ``ln_pre`` and after every encoder layer.

        Returns:
            ``(last_hidden_state, hidden_states)``. ``last_hidden_state`` has
            shape ``(1, total_patches, hidden_size)``. ``hidden_states`` is a
            tuple of the collected states, or None when
            ``output_hidden_states`` is falsy.
        """
        if x.dtype != self.patch_conv.weight.dtype:
            x = x.astype(self.patch_conv.weight.dtype)

        # One (H', W', hidden) grid of patch embeddings per image.
        patch_grids = self.patch_conv(x)
        # Pack every image's patches into a single sequence.
        patch_embeds = patch_grids.reshape(1, -1, patch_grids.shape[-1])

        patch_embeds = self.ln_pre(patch_embeds)

        position_ids = position_ids_in_meshgrid(
            patch_grids,
            max_width=self.config.image_size // self.config.patch_size,
        )

        position_embedding = self.patch_positional_embedding(patch_embeds, position_ids)

        # Block-diagonal mask: patches attend only within their own image.
        mask = generate_block_attention_mask(
            [p.shape[1] * p.shape[0] for p in patch_grids], patch_embeds
        )

        encoder_states = (patch_embeds,) if output_hidden_states else None

        for layer in self.transformer.layers:
            patch_embeds = layer(
                patch_embeds, mask=mask, position_embeddings=position_embedding
            )
            if output_hidden_states:
                encoder_states = encoder_states + (patch_embeds,)

        return patch_embeds, encoder_states
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
class VisionModel(nn.Module):
    """Thin wrapper that holds the Pixtral encoder as ``vision_model``.

    Raises:
        ValueError: If ``config.model_type`` is not ``"clip_vision_model"``
            or ``"pixtral"``.
    """

    def __init__(self, config: VisionConfig):
        super().__init__()

        self.model_type = config.model_type
        if self.model_type not in ["clip_vision_model", "pixtral"]:
            raise ValueError(f"Unsupported model type: {self.model_type}")

        self.vision_model = PixtralVisionModel(config)

    def __call__(
        self, x: mx.array, output_hidden_states: Optional[bool] = None
    ) -> tuple[mx.array, Optional[tuple[mx.array, ...]]]:
        """Run the encoder and return ``(last_hidden_state, hidden_states)``."""
        return self.vision_model(x, output_hidden_states)

    def sanitize(self, weights):
        """Drop ``position_ids`` entries and convert ``patch_conv`` weights to MLX layout.

        PyTorch conv2d weights are ``[out_channels, in_channels, kH, kW]``,
        while MLX conv2d expects ``[out_channels, kH, kW, in_channels]``.
        Weights that already pass ``check_array_shape`` are kept unchanged.
        All other entries pass through untouched.
        """
        sanitized_weights = {}
        for k, v in weights.items():
            if "position_ids" in k:
                # Remove unused position_ids.
                continue
            if "patch_conv.weight" in k and not check_array_shape(v):
                v = v.transpose(0, 2, 3, 1)
            sanitized_weights[k] = v

        return sanitized_weights
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import Dict, List, Optional, Union
|
|
4
|
+
|
|
5
|
+
from ..base import BaseModelConfig
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class VisionConfig(BaseModelConfig):
    """Vision-encoder settings for Qwen2.5-VL, with the defaults declared below.

    ``patch_size`` was previously declared twice with the same default. The
    duplicate is removed; the field order and the default are unchanged.
    """

    model_type: str = "qwen2_5_vl"
    depth: int = 32
    hidden_size: int = 1280
    intermediate_size: int = 3420
    out_hidden_size: int = 1536
    num_heads: int = 16
    image_size: int = 384
    patch_size: int = 14
    vocab_size: int = 32000
    mlp_ratio: float = 4.0
    in_channels: int = 3
    layer_norm_eps: float = 1e-6
    spatial_patch_size: int = 14
    spatial_merge_size: int = 2
    tokens_per_second: int = 2
    temporal_patch_size: int = 2
    window_size: int = 112
    # Indices of blocks using full attention. Presumably the other blocks use
    # windowed attention (see ``window_size``); confirm against the vision module.
    fullatt_block_indexes: list[int] = field(default_factory=lambda: [7, 15, 23, 31])
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class TextConfig(BaseModelConfig):
    """Language-model settings for Qwen2.5-VL.

    After construction, ``num_key_value_heads`` defaults to
    ``num_attention_heads``, and a non-empty ``rope_scaling`` is validated.
    """

    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    num_key_value_heads: Optional[int] = None
    max_position_embeddings: Optional[int] = 128000
    rope_theta: float = 1000000.0
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
    tie_word_embeddings: bool = True

    def __post_init__(self):
        """Fill derived defaults and validate ``rope_scaling``.

        Raises:
            ValueError: If ``rope_scaling`` lacks ``mrope_section`` or ``type``,
                or if its ``type`` is neither ``"mrope"`` nor ``"default"``.
        """
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

        if self.rope_scaling:
            required_keys = {"mrope_section", "type"}
            if not required_keys.issubset(self.rope_scaling):
                raise ValueError(f"rope_scaling must contain keys {required_keys}")

            if self.rope_scaling["type"] not in ["mrope", "default"]:
                raise ValueError("rope_scaling type must be 'mrope' or 'default'")
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@dataclass
class ModelConfig(BaseModelConfig):
    """Top-level Qwen2.5-VL configuration: sub-configs plus special token ids."""

    text_config: TextConfig
    vision_config: VisionConfig
    model_type: str
    ignore_index: int = -100
    image_token_id: int = 151655
    video_token_id: int = 151656
    vision_start_token_id: int = 151652
    vision_end_token_id: int = 151653
    vision_token_id: int = 151654
    vision_feature_select_strategy: str = "default"
    vision_feature_layer: int = -2
    vocab_size: int = 32000
    eos_token_id: Optional[List[int]] = None

    @classmethod
    def from_dict(cls, params):
        """Build a ModelConfig from a flat, root-level config dict.

        Text-model settings are read from the root of ``params``. Every root
        entry except ``vision_config`` is copied into ``params["text_config"]``.
        Keys that are not constructor parameters are ignored.

        NOTE: ``params`` is mutated in place, because a ``text_config`` entry
        is written back. This is kept as-is since code outside this class may
        read ``params["text_config"]`` afterwards; confirm before changing.
        """
        excluded_keys = {"vision_config"}
        params["text_config"] = {
            k: v for k, v in params.items() if k not in excluded_keys
        }

        # Computed once rather than once per key.
        accepted = inspect.signature(cls).parameters
        return cls(**{k: v for k, v in params.items() if k in accepted})
|