fount_vlm_nell_02-0.3.11-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0

mlx_vlm/models/moondream2/moondream2.py
@@ -0,0 +1,522 @@
+import re
+from functools import partial, reduce
+from typing import List, Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+from PIL import Image
+from transformers.image_transforms import (
+    convert_to_rgb,
+    resize,
+    to_channel_dimension_format,
+)
+from transformers.image_utils import PILImageResampling, to_numpy_array
+
+from ..base import BaseImageProcessor, InputEmbeddingsFeatures
+from .config import ModelConfig, VisionConfig
+from .image_crops import adaptive_avg_pool2d, overlap_crop_image, reconstruct_from_crops
+from .language import LanguageModel
+from .vision import VisionModel
+
+
+class ImageProcessor(BaseImageProcessor):
+    """Moondream image processor with multi-crop support."""
+
+    def __init__(self, max_crops: int = 12, overlap_margin: int = 4):
+        super().__init__(
+            image_mean=(0.5, 0.5, 0.5),
+            image_std=(0.5, 0.5, 0.5),
+            size=(378, 378),
+            resample=PILImageResampling.BICUBIC,
+            rescale_factor=1 / 255,
+        )
+        self.max_crops = max_crops
+        self.overlap_margin = overlap_margin
+
+    def preprocess(
+        self, images
+    ) -> Tuple[List[np.ndarray], List[int], List[Tuple[int, int]]]:
+        """
+        Preprocess images with multi-crop support.
+
+        Args:
+            images: Single PIL Image or list of PIL Images
+
+        Returns:
+            crops_list: List of [n_crops, C, H, W] arrays per image
+            crop_counts: Number of crops per image
+            tilings: (h_tiles, w_tiles) per image
+        """
+        if isinstance(images, Image.Image):
+            images = [images]
+        else:
+            assert isinstance(images, list)
+
+        crops_list = []
+        crop_counts = []
+        tilings = []
+
+        for image in images:
+            # Convert to RGB numpy array
+            image = convert_to_rgb(image)
+            image_np = to_numpy_array(image)
+
+            # Get multi-crop decomposition
+            crops, tiling = overlap_crop_image(
+                image_np,
+                max_crops=self.max_crops,
+                overlap_margin=self.overlap_margin,
+                base_size=self.size,
+                patch_size=14,
+            )
+            # crops is [n_crops, H, W, C] in range [0, 255]
+
+            # Normalize each crop: (pixel/255 - 0.5) / 0.5 = [-1, 1]
+            crops = crops.astype(np.float32) * self.rescale_factor  # [0, 1]
+            crops = (crops - 0.5) / 0.5  # [-1, 1]
+
+            # Convert to channel-first format: [n_crops, H, W, C] -> [n_crops, C, H, W]
+            crops = np.transpose(crops, (0, 3, 1, 2))
+
+            crops_list.append(crops)
+            crop_counts.append(crops.shape[0])
+            tilings.append(tiling)
+
+        return crops_list, crop_counts, tilings
+
+
+class VisionProjection(nn.Module):
+    """
+    2-layer MLP projector from vision to language space.
+
+    Projects concatenated [global, reconstructed] features (2304D) to language
+    model dimension (2048D). The input is the concatenation of:
+    - Global features: [B, 729, 1152] from full image
+    - Reconstructed features: [B, 729, 1152] pooled from local crops
+
+    Reference: moondream2/vision.py:77-89
+    """
+
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+        # Input is concatenation of global and reconstructed: 1152 * 2 = 2304
+        vision_dim = config.vision_config.hidden_size * 2  # 2304
+        inner_dim = config.proj_inner_dim  # 8192
+        output_dim = config.text_config.hidden_size  # 2048
+
+        self.fc1 = nn.Linear(vision_dim, inner_dim, bias=True)
+        self.fc2 = nn.Linear(inner_dim, output_dim, bias=True)
+        self.activation = nn.GELU(approx="precise")
+
+    def __call__(
+        self, global_features: mx.array, reconstructed_features: mx.array
+    ) -> mx.array:
+        """
+        Project concatenated vision features to language model dimension.
+
+        Args:
+            global_features: [B, 729, 1152] features from global crop
+            reconstructed_features: [B, 729, 1152] features reconstructed from local crops
+
+        Returns:
+            [B, 729, 2048] projected features
+        """
+        # Concatenate along feature dimension: [B, 729, 2304]
+        x = mx.concatenate([global_features, reconstructed_features], axis=-1)
+        x = self.activation(self.fc1(x))
+        x = self.fc2(x)
+        return x
+
+
+class Model(nn.Module):
+    """Moondream 2 model for visual question answering."""
+
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+        self.model_type = config.model_type
+        self.config = config
+
+        self.vision_encoder = VisionModel(config.vision_config)
+        self.vision_projection = VisionProjection(config)
+        self.language_model = LanguageModel(config.text_config)
+
+    def get_input_embeddings(
+        self,
+        input_ids: Optional[mx.array] = None,
+        pixel_values: Optional[mx.array] = None,
+        crop_counts: Optional[List[int]] = None,
+        tilings: Optional[List[Tuple[int, int]]] = None,
+        **kwargs,
+    ):
+        """
+        Get input embeddings with multi-crop image features.
+
+        Full pipeline:
+        1. Encode ALL crops through vision_encoder: [total_crops, 729, 1152]
+        2. For each image:
+           a. global_features = features[0]  # [729, 1152]
+           b. local_features = features[1:].reshape(n_local, 27, 27, 1152)
+           c. reconstructed = reconstruct_from_crops(local_features, tiling)
+           d. reconstructed = adaptive_avg_pool2d(reconstructed, (27, 27))
+           e. reconstructed = reconstructed.reshape(729, 1152)
+           f. projected = vision_projection(global, reconstructed)  # [729, 2048]
+        3. Insert projected features into embeddings
+
+        Args:
+            input_ids: Token IDs [B, seq_len]
+            pixel_values: Concatenated crops [total_crops, C, H, W]
+            crop_counts: Number of crops per image (list of ints)
+            tilings: (h_tiles, w_tiles) per image (list of tuples)
+        """
+        # #region agent log
+        import json
+        log_file = "/Users/zekieldee/Desktop/code/mlx-vlm/.cursor/debug.log"
+        def log_embed(location, message, data, hypothesis_id):
+            try:
+                with open(log_file, "a") as f:
+                    f.write(json.dumps({"sessionId": "debug-session", "runId": "inference", "hypothesisId": hypothesis_id, "location": location, "message": message, "data": data, "timestamp": __import__("time").time_ns() // 1000000}) + "\n")
+            except: pass
+
+        log_embed("moondream2.py:get_input_embeddings_entry", "Entry to get_input_embeddings", {
+            "input_ids_shape": str(input_ids.shape) if input_ids is not None else None,
+            "pixel_values_shape": str(pixel_values.shape) if pixel_values is not None else None,
+            "crop_counts": crop_counts,
+            "tilings": tilings,
+            "input_ids_sample": input_ids[0, :20].tolist() if input_ids is not None else None
+        }, "H7,H9")
+        # #endregion
+
+        if pixel_values is None:
+            return InputEmbeddingsFeatures(
+                inputs_embeds=self.language_model.model.embed_tokens(input_ids)
+            )
+
+        # Get text embeddings
+        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+        # #region agent log
+        log_embed("moondream2.py:text_embeddings", "Text embeddings from embed_tokens", {
+            "shape": str(inputs_embeds.shape),
+            "dtype": str(inputs_embeds.dtype),
+            "mean": float(mx.mean(inputs_embeds)),
+            "std": float(mx.std(inputs_embeds)),
+            "min": float(mx.min(inputs_embeds)),
+            "max": float(mx.max(inputs_embeds))
+        }, "H9")
+        # #endregion
+
+        # Encode ALL crops through vision encoder at once
+        # pixel_values is [total_crops, C, H, W]
+        all_features = self.vision_encoder(pixel_values)  # [total_crops, 729, 1152]
+
+        # #region agent log
+        log_embed("moondream2.py:vision_features", "Vision encoder output", {
+            "shape": str(all_features.shape),
+            "dtype": str(all_features.dtype),
+            "mean": float(mx.mean(all_features)),
+            "std": float(mx.std(all_features)),
+            "min": float(mx.min(all_features)),
+            "max": float(mx.max(all_features))
+        }, "H6,H8")
+        # #endregion
+
+        # Process each image's crops
+        batch_size = len(crop_counts) if crop_counts is not None else 1
+        projected_features_list = []
+
+        crop_offset = 0
+        for b in range(batch_size):
+            n_crops = crop_counts[b] if crop_counts is not None else all_features.shape[0]
+            tiling = tilings[b] if tilings is not None else (1, 1)
+
+            # Extract features for this image
+            img_features = all_features[crop_offset : crop_offset + n_crops]  # [n_crops, 729, 1152]
+            crop_offset += n_crops
+
+            # Global features from first crop
+            global_features = img_features[0]  # [729, 1152]
+
+            # Local crop features
+            n_local = n_crops - 1
+            if n_local > 0:
+                local_features = img_features[1:]  # [n_local, 729, 1152]
+
+                # Reshape to spatial grid: [n_local, 729, 1152] -> [n_local, 27, 27, 1152]
+                local_features = local_features.reshape(n_local, 27, 27, -1)
+
+                # Reconstruct unified feature map from local crops
+                reconstructed = reconstruct_from_crops(
+                    local_features, tiling, overlap_margin=self.config.vision_config.overlap_margin
+                )  # [H, W, 1152]
+
+                # Pool back to 27x27 to match global features
+                reconstructed = adaptive_avg_pool2d(reconstructed, (27, 27))  # [27, 27, 1152]
+
+                # Flatten to [729, 1152]
+                reconstructed_flat = reconstructed.reshape(729, -1)
+            else:
+                # No local crops, duplicate global for reconstruction
+                reconstructed_flat = global_features
+
+            # Add batch dimension and project
+            global_batch = global_features[None, :, :]  # [1, 729, 1152]
+            reconstructed_batch = reconstructed_flat[None, :, :]  # [1, 729, 1152]
+
+            # #region agent log
+            log_embed("moondream2.py:before_projection", "Features before projection", {
+                "global_shape": str(global_batch.shape),
+                "global_mean": float(mx.mean(global_batch)),
+                "global_std": float(mx.std(global_batch)),
+                "reconstructed_shape": str(reconstructed_batch.shape),
+                "reconstructed_mean": float(mx.mean(reconstructed_batch)),
+                "reconstructed_std": float(mx.std(reconstructed_batch)),
+                "n_local_crops": n_local
+            }, "H6,H8")
+            # #endregion
+
+            # Project concatenated features: [1, 729, 2304] -> [1, 729, 2048]
+            projected = self.vision_projection(global_batch, reconstructed_batch)
+
+            # #region agent log
+            log_embed("moondream2.py:after_projection", "Projected vision features", {
+                "shape": str(projected.shape),
+                "dtype": str(projected.dtype),
+                "mean": float(mx.mean(projected)),
+                "std": float(mx.std(projected)),
+                "min": float(mx.min(projected)),
+                "max": float(mx.max(projected))
+            }, "H6")
+            # #endregion
+
+            projected_features_list.append(projected)
+
+        # Concatenate all projected features
+        image_features = mx.concatenate(projected_features_list, axis=0)  # [B, 729, 2048]
+
+        # #region agent log
+        log_embed("moondream2.py:concatenated_image_features", "Concatenated image features", {
+            "shape": str(image_features.shape),
+            "dtype": str(image_features.dtype),
+            "mean": float(mx.mean(image_features)),
+            "std": float(mx.std(image_features)),
+            "min": float(mx.min(image_features)),
+            "max": float(mx.max(image_features))
+        }, "H6")
+        # #endregion
+
+        # Replace 729-token image placeholder in input_ids with vision features
+        # prepare_inputs() creates input_ids as: [BOS, <img_token>*729, <text_tokens>]
+        patch_count = image_features.shape[1]  # expected 729
+        if inputs_embeds.shape[1] >= 1 + patch_count:
+            # #region agent log
+            log_embed("moondream2.py:before_replacement", "Before replacing image tokens", {
+                "inputs_embeds_shape": str(inputs_embeds.shape),
+                "patch_count": patch_count,
+                "replacement_range": f"[1:{1 + patch_count}]",
+                "image_features_dtype": str(image_features.dtype),
+                "inputs_embeds_dtype": str(inputs_embeds.dtype),
+                "embeds_before_mean": float(mx.mean(inputs_embeds[:, 1 : 1 + patch_count, :])),
+                "embeds_before_std": float(mx.std(inputs_embeds[:, 1 : 1 + patch_count, :]))
+            }, "H7,H9")
+            # #endregion
+
+            # Replace positions [1 : 1+patch_count] (right after BOS)
+            inputs_embeds[:, 1 : 1 + patch_count, :] = image_features.astype(
+                inputs_embeds.dtype
+            )
+
+            # #region agent log
+            log_embed("moondream2.py:after_replacement", "After replacing image tokens", {
+                "embeds_after_mean": float(mx.mean(inputs_embeds[:, 1 : 1 + patch_count, :])),
+                "embeds_after_std": float(mx.std(inputs_embeds[:, 1 : 1 + patch_count, :])),
+                "text_tokens_mean": float(mx.mean(inputs_embeds[:, 1 + patch_count:, :])) if inputs_embeds.shape[1] > 1 + patch_count else None,
+                "text_tokens_std": float(mx.std(inputs_embeds[:, 1 + patch_count:, :])) if inputs_embeds.shape[1] > 1 + patch_count else None
+            }, "H7,H9")
+            # #endregion
+
+            final_embeddings = inputs_embeds
+        else:
+            # Fallback: original behavior (prepend image embeddings)
+            batch_size = inputs_embeds.shape[0]
+            final_embeddings = []
+            for b in range(batch_size):
+                bos_embed = inputs_embeds[b : b + 1, :1, :]
+                text_embed = inputs_embeds[b : b + 1, 1:, :]
+                img_embed = image_features[b : b + 1]
+                combined = mx.concatenate([bos_embed, img_embed, text_embed], axis=1)
+                final_embeddings.append(combined)
+            final_embeddings = mx.concatenate(final_embeddings, axis=0)

+        # #region agent log
+        log_embed("moondream2.py:final_embeddings", "Final embeddings output", {
+            "shape": str(final_embeddings.shape),
+            "dtype": str(final_embeddings.dtype),
+            "mean": float(mx.mean(final_embeddings)),
+            "std": float(mx.std(final_embeddings)),
+            "min": float(mx.min(final_embeddings)),
+            "max": float(mx.max(final_embeddings)),
+            "has_nan": bool(mx.any(mx.isnan(final_embeddings))),
+            "has_inf": bool(mx.any(mx.isinf(final_embeddings)))
+        }, "H9")
+        # #endregion
+
+        return InputEmbeddingsFeatures(inputs_embeds=final_embeddings)
+
+    @property
+    def layers(self):
+        return self.language_model.model.layers
+
+    def __call__(
+        self,
+        input_ids: mx.array,
+        pixel_values: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        crop_counts: Optional[List[int]] = None,
+        tilings: Optional[List[Tuple[int, int]]] = None,
+        **kwargs,
+    ):
+        input_embeddings_features = self.get_input_embeddings(
+            input_ids, pixel_values, crop_counts=crop_counts, tilings=tilings, **kwargs
+        )
+
+        logits = self.language_model(
+            inputs=input_ids,
+            cache=cache,
+            inputs_embeds=input_embeddings_features.inputs_embeds,
+            mask=mask,
+        )
+        return logits
+
+    def sanitize(self, weights):
+        """
+        Map HuggingFace weights to MLX model structure.
+
+        HF Weight Structure (from moondream2/vision.py + text.py):
+        - model.vision.patch_emb.* -> vision_encoder.patch_emb.*
+        - model.vision.pos_emb -> vision_encoder.position_embedding
+        - model.vision.blocks.{i}.ln1.* -> vision_encoder.encoder.layers.{i}.ln1.*
+        - model.vision.blocks.{i}.attn.* -> vision_encoder.encoder.layers.{i}.attn.*
+        - model.vision.blocks.{i}.ln2.* -> vision_encoder.encoder.layers.{i}.ln2.*
+        - model.vision.blocks.{i}.mlp.* -> vision_encoder.encoder.layers.{i}.mlp.*
+        - model.vision.post_ln.* -> vision_encoder.post_layernorm.*
+        - model.vision.proj_mlp.* -> vision_projection.*
+        - model.text.wte -> language_model.model.embed_tokens.weight
+        - model.text.blocks.{i}.ln.* -> language_model.model.layers.{i}.input_layernorm.*
+        - model.text.blocks.{i}.attn.qkv.* -> language_model.model.layers.{i}.self_attn.qkv_proj.*
+        - model.text.blocks.{i}.attn.proj.* -> language_model.model.layers.{i}.self_attn.o_proj.*
+        - model.text.blocks.{i}.mlp.* -> language_model.model.layers.{i}.mlp.*
+        - model.text.post_ln.* -> language_model.model.norm.*
+        - model.text.lm_head.* -> language_model.lm_head.*
+        - model.region.* -> (skip, not needed for VQA)
+        """
+        # #region agent log
+        import json
+        log_file = "/Users/zekieldee/Desktop/code/mlx-vlm/.cursor/debug.log"
+        def log(location, message, data, hypothesis_id):
+            try:
+                with open(log_file, "a") as f:
+                    f.write(json.dumps({"sessionId": "debug-session", "runId": "sanitize", "hypothesisId": hypothesis_id, "location": location, "message": message, "data": data, "timestamp": __import__("time").time_ns() // 1000000}) + "\n")
+            except: pass
+        # #endregion
+
+        new_weights = {}
+        n_skipped_region = 0
+        n_changed = 0
+        n_unchanged = 0
+
+        # #region agent log
+        original_keys = sorted(weights.keys())
+        log("moondream2.py:sanitize_entry", "Sanitize entry - input weights", {"n_weights": len(weights), "sample_keys": original_keys[:10], "all_keys": original_keys}, "H1")
+        # #endregion
+
+        for k, v in weights.items():
+            # Skip region model weights (not needed for VQA)
+            if k.startswith("model.region."):
+                n_skipped_region += 1
+                continue
+
+            new_key = k
+
+            # Vision encoder: patch embedding
+            if k.startswith("model.vision.patch_emb."):
+                new_key = k.replace(
+                    "model.vision.patch_emb.",
+                    "vision_encoder.patch_emb.",
+                )
+
+            # Vision encoder: positional embedding
+            elif k == "model.vision.pos_emb":
+                new_key = "vision_encoder.position_embedding"
+
+            # Vision encoder: blocks
+            elif k.startswith("model.vision.blocks."):
+                # Extract block number and rest
+                match = re.match(r"model\.vision\.blocks\.(\d+)\.(.+)", k)
+                if match:
+                    block_num = match.group(1)
+                    suffix = match.group(2)
+                    # Keep the structure: ln1, attn, ln2, mlp
+                    new_key = f"vision_encoder.encoder.layers.{block_num}.{suffix}"
+
+            # Vision encoder: post layer norm
+            elif k.startswith("model.vision.post_ln."):
+                new_key = k.replace(
+                    "model.vision.post_ln.",
+                    "vision_encoder.post_layernorm.",
+                )
+
+            # Vision projection MLP
+            elif k.startswith("model.vision.proj_mlp."):
+                new_key = k.replace(
+                    "model.vision.proj_mlp.",
+                    "vision_projection.",
+                )
+
+            # Text model: embedding
+            elif k == "model.text.wte":
+                new_key = "language_model.model.embed_tokens.weight"
+
+            # Text model: transformer blocks
+            elif k.startswith("model.text.blocks."):
+                # Extract block number and rest
+                match = re.match(r"model\.text\.blocks\.(\d+)\.(.+)", k)
+                if match:
+                    block_num = match.group(1)
+                    suffix = match.group(2)
+
+                    # Map the suffixes
+                    if suffix.startswith("ln."):
+                        new_suffix = suffix.replace("ln.", "input_layernorm.")
+                    elif suffix.startswith("attn.qkv."):
+                        new_suffix = suffix.replace("attn.qkv.", "self_attn.qkv_proj.")
+                    elif suffix.startswith("attn.proj."):
+                        new_suffix = suffix.replace("attn.proj.", "self_attn.o_proj.")
+                    elif suffix.startswith("mlp."):
+                        new_suffix = suffix
+                    else:
+                        new_suffix = suffix
+
+                    new_key = f"language_model.model.layers.{block_num}.{new_suffix}"
+
+            # Text model: final layer norm
+            elif k.startswith("model.text.post_ln."):
+                new_key = k.replace("model.text.post_ln.", "language_model.model.norm.")
+
+            # Text model: lm head
+            elif k.startswith("model.text.lm_head."):
+                new_key = k.replace("model.text.lm_head.", "language_model.lm_head.")
+
+            if new_key == k:
+                n_unchanged += 1
+            else:
+                n_changed += 1
+            new_weights[new_key] = v
+
+        # #region agent log
+        sanitized_keys = sorted(new_weights.keys())
+        log("moondream2.py:sanitize_exit", "Sanitize exit - output weights", {"n_weights": len(new_weights), "n_changed": n_changed, "n_unchanged": n_unchanged, "n_skipped": n_skipped_region, "sample_keys": sanitized_keys[:10], "all_keys": sanitized_keys}, "H1")
+        # #endregion
+
+        return new_weights
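The pipeline above relies on a strict contract between `ImageProcessor.preprocess()` and `Model.get_input_embeddings()`: the crops of all images are stacked along the first axis, and `crop_counts`/`tilings` are used to slice them back apart, with crop 0 of each image being the global view. A minimal sketch of that bookkeeping (NumPy only; the placeholder arrays, crop counts, and tilings are illustrative, not taken from the package):

```python
import numpy as np

# Pretend preprocess() returned crops for two images, [n_crops, 3, 378, 378] each,
# where crop 0 is the global view and the rest are overlapping local tiles.
crops_list = [
    np.zeros((3, 3, 378, 378), dtype=np.float32),  # image 0: 1 global + 2 local crops
    np.zeros((5, 3, 378, 378), dtype=np.float32),  # image 1: 1 global + 4 local crops
]
crop_counts = [c.shape[0] for c in crops_list]     # [3, 5]
tilings = [(1, 2), (2, 2)]                         # (h_tiles, w_tiles) per image

# The model consumes one big stack of crops ...
pixel_values = np.concatenate(crops_list, axis=0)  # [8, 3, 378, 378]

# ... and the crop_offset loop in get_input_embeddings() slices it back apart.
offset = 0
for n_crops, tiling in zip(crop_counts, tilings):
    img_crops = pixel_values[offset : offset + n_crops]
    offset += n_crops
    global_crop, local_crops = img_crops[0], img_crops[1:]
    # Each 378x378 crop becomes 27 x 27 = 729 vision tokens downstream.
    print(global_crop.shape, local_crops.shape, "tiling:", tiling)
```

After projection, the resulting 729 image-token embeddings are written into positions [1 : 730] of the text embedding sequence, directly after BOS, matching the `[BOS, <img_token>*729, <text_tokens>]` layout noted in the code.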
mlx_vlm/models/moondream2/processing_moondream.py
@@ -0,0 +1,144 @@
+"""
+Moondream2 processor for mlx-vlm.
+"""
+
+from typing import List, Optional, Union
+
+from PIL import Image
+from transformers import AutoTokenizer
+from transformers.processing_utils import ProcessorMixin
+
+
+class MoondreamProcessor(ProcessorMixin):
+    """
+    Processor for Moondream2 model.
+
+    Wraps the tokenizer and provides compatibility with mlx-vlm's generation flow.
+    Image processing is handled separately by the model's ImageProcessor class.
+    """
+
+    tokenizer_class = "AutoTokenizer"
+    attributes = ["tokenizer"]
+
+    def __init__(self, tokenizer, chat_template: Optional[str] = None, **kwargs):
+        self.tokenizer = tokenizer
+
+        # Set up chat template for moondream
+        if chat_template is None:
+            # Moondream uses a simple format: <image>\n\nQuestion: {question}\n\nAnswer:
+            chat_template = (
+                "{% for message in messages %}"
+                "{% if message['role'] == 'user' %}"
+                "{{ message['content'] }}\n\n"
+                "{% elif message['role'] == 'assistant' %}"
+                "{{ message['content'] }}"
+                "{% endif %}"
+                "{% endfor %}"
+                "{% if add_generation_prompt %}Answer: {% endif %}"
+            )
+
+        self.chat_template = chat_template
+        tokenizer.chat_template = chat_template
+        super().__init__(tokenizer, **kwargs)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        """Load processor from pretrained model path."""
+        # Convert Path to string if needed
+        if hasattr(pretrained_model_name_or_path, "__fspath__"):
+            pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+
+        # Pop kwargs that are not valid for AutoTokenizer
+        trust_remote_code = kwargs.pop("trust_remote_code", True)
+
+        # Moondream2 uses a custom tokenizer (starmie-v1), not the GPT-2
+        # tokenizer files shipped in the model repo.
+        tokenizer = AutoTokenizer.from_pretrained(
+            "moondream/starmie-v1",
+            trust_remote_code=trust_remote_code,
+        )
+
+        # starmie-v1 doesn't define special token roles; set them to
+        # <|endoftext|> (ID 0) which moondream uses as BOS/EOS/PAD.
+        tokenizer.eos_token = "<|endoftext|>"
+        tokenizer.bos_token = "<|endoftext|>"
+        tokenizer.pad_token = "<|endoftext|>"
+
+        return cls(tokenizer=tokenizer)
+
+    def __call__(
+        self,
+        text: Optional[Union[str, List[str]]] = None,
+        images: Optional[Union[Image.Image, List[Image.Image]]] = None,
+        **kwargs,
+    ):
+        """
+        Process text and images for the model.
+
+        Note: Image processing is handled by the model's ImageProcessor;
+        this processor mainly handles tokenization.
+        """
+        if text is None:
+            raise ValueError("Text input is required")
+
+        # Tokenize text
+        encoding = self.tokenizer(text, **kwargs)
+
+        return encoding
+
+    def batch_decode(self, *args, **kwargs):
+        """Decode token ids to text."""
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        """Decode token ids to text."""
+        return self.tokenizer.decode(*args, **kwargs)
+
+    def apply_chat_template(self, messages, add_generation_prompt=True, **kwargs):
+        """Apply chat template to messages."""
+        return self.tokenizer.apply_chat_template(
+            messages,
+            chat_template=self.chat_template,
+            add_generation_prompt=add_generation_prompt,
+            tokenize=kwargs.get("tokenize", False),
+            **{k: v for k, v in kwargs.items() if k != "tokenize"}
+        )
+
+    # Token properties delegated to tokenizer
+    @property
+    def pad_token(self):
+        return self.tokenizer.pad_token
+
+    @pad_token.setter
+    def pad_token(self, value):
+        self.tokenizer.pad_token = value
+
+    @property
+    def pad_token_id(self):
+        return self.tokenizer.pad_token_id
+
+    @pad_token_id.setter
+    def pad_token_id(self, value):
+        self.tokenizer.pad_token_id = value
+
+    @property
+    def eos_token(self):
+        return self.tokenizer.eos_token
+
+    @property
+    def eos_token_id(self):
+        return self.tokenizer.eos_token_id
+
+    @property
+    def bos_token(self):
+        return self.tokenizer.bos_token
+
+    @property
+    def bos_token_id(self):
+        return self.tokenizer.bos_token_id
+
+
+# Install the AutoProcessor patch for moondream1 model type
+from ..base import install_auto_processor_patch
+
+install_auto_processor_patch("moondream1", MoondreamProcessor)
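For reference, the default chat template above yields a plain `Question: ... Answer:` style prompt. A small sketch rendering the same template with jinja2 directly (the question text is made up; in practice the processor applies it through the tokenizer's `apply_chat_template`):

```python
from jinja2 import Template

# Same template string as in MoondreamProcessor.__init__ above.
chat_template = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "{{ message['content'] }}\n\n"
    "{% elif message['role'] == 'assistant' %}"
    "{{ message['content'] }}"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}Answer: {% endif %}"
)

messages = [{"role": "user", "content": "<image>\n\nQuestion: What color is the car?"}]
prompt = Template(chat_template).render(messages=messages, add_generation_prompt=True)
print(repr(prompt))
# -> '<image>\n\nQuestion: What color is the car?\n\nAnswer: '
```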