fount-vlm-nell-02 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/jina_vlm/language.py
@@ -0,0 +1,272 @@
+"""Language model decoder for Jina VLM in MLX."""
+
+from typing import List, Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx_lm.models.base import create_attention_mask, scaled_dot_product_attention
+from mlx_lm.models.cache import KVCache
+
+from ..base import LanguageModelOutput
+from .config import TextConfig
+
+
+class RMSNorm(nn.Module):
+    """RMS Layer Normalization."""
+
+    def __init__(self, dims: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = mx.ones((dims,))
+
+    def __call__(self, x: mx.array) -> mx.array:
+        rms = mx.sqrt(mx.mean(x * x, axis=-1, keepdims=True) + self.eps)
+        return self.weight * (x / rms)
+
+
+class RoPE(nn.Module):
+    """Rotary Positional Embeddings."""
+
+    def __init__(self, dims: int, theta: float = 1000000.0):
+        super().__init__()
+        self.dims = dims
+        self.theta = theta
+        inv_freq = 1.0 / (theta ** (mx.arange(0, dims, 2).astype(mx.float32) / dims))
+        self._inv_freq = inv_freq
+
+    def __call__(self, x: mx.array, offset: int = 0) -> mx.array:
+        seq_len = x.shape[2]
+        positions = mx.arange(offset, offset + seq_len).astype(mx.float32)
+        freqs = positions[:, None] * self._inv_freq[None, :]
+        emb = mx.concatenate([freqs, freqs], axis=-1)
+        cos = mx.cos(emb)[None, None, :, :]
+        sin = mx.sin(emb)[None, None, :, :]
+        x1 = x[..., : self.dims // 2]
+        x2 = x[..., self.dims // 2 :]
+        rotated = mx.concatenate([-x2, x1], axis=-1)
+        return (x * cos + rotated * sin).astype(x.dtype)
+
+
+class Attention(nn.Module):
+    """Multi-head attention with GQA and RoPE - matches weight naming: attn.qkv, attn.out"""
+
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.config = config
+        self.num_heads = config.num_attention_heads
+        self.num_kv_heads = config.num_key_value_heads
+        self.head_dim = config.head_dim
+        self.num_kv_groups = self.num_heads // self.num_kv_heads
+        self.scale = self.head_dim**-0.5
+
+        # Fused QKV projection - named to match weights
+        qkv_size = (
+            config.num_attention_heads + 2 * config.num_key_value_heads
+        ) * config.head_dim
+        self.qkv = nn.Linear(config.hidden_size, qkv_size, bias=False)
+        self.out = nn.Linear(
+            config.num_attention_heads * config.head_dim, config.hidden_size, bias=False
+        )
+
+        # QK normalization - named to match weights
+        if config.use_qk_norm:
+            self.q_norm = RMSNorm(config.head_dim, eps=config.rms_norm_eps)
+            self.k_norm = RMSNorm(config.head_dim, eps=config.rms_norm_eps)
+        else:
+            self.q_norm = None
+            self.k_norm = None
+
+        self.rope = RoPE(config.head_dim, theta=config.rope_theta)
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[KVCache] = None,
+    ) -> mx.array:
+        B, L, _ = x.shape
+
+        # Compute fused QKV
+        qkv = self.qkv(x)
+        q_size = self.num_heads * self.head_dim
+        kv_size = self.num_kv_heads * self.head_dim
+
+        q = qkv[..., :q_size]
+        k = qkv[..., q_size : q_size + kv_size]
+        v = qkv[..., q_size + kv_size :]
+
+        q = q.reshape(B, L, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
+        k = k.reshape(B, L, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
+        v = v.reshape(B, L, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
+
+        if self.q_norm is not None:
+            q = self.q_norm(q)
+            k = self.k_norm(k)
+
+        if cache is not None:
+            q = self.rope(q, offset=cache.offset)
+            k = self.rope(k, offset=cache.offset)
+            k, v = cache.update_and_fetch(k, v)
+        else:
+            q = self.rope(q)
+            k = self.rope(k)
+
+        output = scaled_dot_product_attention(
+            q, k, v, cache=cache, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.out(output)
+
+
+class MLP(nn.Module):
+    """MLP with SwiGLU - matches weight naming: ffn.gate_up, ffn.down"""
+
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        # Fused gate and up projection - named to match weights
+        self.gate_up = nn.Linear(
+            config.hidden_size, 2 * config.intermediate_size, bias=False
+        )
+        self.down = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        gate_up = self.gate_up(x)
+        # Jina VLM convention: first half is value, second half is gate (activated)
+        up, gate = mx.split(gate_up, 2, axis=-1)
+        return self.down(nn.silu(gate) * up)
+
+
+class TransformerBlock(nn.Module):
+    """Transformer block - matches weight naming: attn_norm, ffn_norm"""
+
+    def __init__(self, config: TextConfig, layer_idx: int = 0):
+        super().__init__()
+        # Named to match weights
+        self.attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.attn = Attention(config)
+        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.ffn = MLP(config)
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[KVCache] = None,
+    ) -> mx.array:
+        h = self.attn(self.attn_norm(x), mask=mask, cache=cache)
+        x = x + h
+        x = x + self.ffn(self.ffn_norm(x))
+        return x
+
+
+class ExtendedEmbedding(nn.Module):
+    """Embedding with additional tokens - matches weight naming: embedding, new_embedding"""
+
+    def __init__(self, vocab_size: int, additional_size: int, dims: int):
+        super().__init__()
+        self.vocab_size = vocab_size
+        self.additional_size = additional_size
+        # Named to match weights
+        self.embedding = mx.zeros((vocab_size, dims))
+        self.new_embedding = mx.zeros((additional_size, dims))
+
+    def __call__(self, x: mx.array) -> mx.array:
+        full_embedding = mx.concatenate([self.embedding, self.new_embedding], axis=0)
+        return full_embedding[x]
+
+
+class TextModel(nn.Module):
+    """Text decoder model - matches weight naming structure"""
+
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.config = config
+
+        # Named to match weights: language_model.embedding
+        if config.additional_vocab_size > 0:
+            self.embedding = ExtendedEmbedding(
+                config.vocab_size, config.additional_vocab_size, config.hidden_size
+            )
+        else:
+            self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
+
+        self.layers = [
+            TransformerBlock(config, layer_idx=i)
+            for i in range(config.num_hidden_layers)
+        ]
+
+        # Named to match weights: language_model.ln_f
+        self.ln_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def __call__(
+        self,
+        input_ids: mx.array,
+        inputs_embeds: Optional[mx.array] = None,
+        mask: Optional[mx.array] = None,
+        cache: Optional[List[KVCache]] = None,
+    ) -> mx.array:
+        if inputs_embeds is None:
+            x = self.embedding(input_ids)
+        else:
+            x = inputs_embeds
+
+        for i, layer in enumerate(self.layers):
+            layer_cache = cache[i] if cache is not None else None
+            x = layer(x, mask=mask, cache=layer_cache)
+
+        return self.ln_f(x)
+
+
+class LanguageModel(nn.Module):
+    """Language model wrapper - the TextModel is accessed as language_model in weights"""
+
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.config = config
+        self.model_type = config.model_type
+        # This will be loaded under "language_model" prefix
+        self.embedding = None  # Handled by sanitize
+        self.layers = None  # Handled by sanitize
+        self.ln_f = None  # Handled by sanitize
+
+        # Build the actual model components directly here
+        # They'll be found via language_model.embedding, language_model.layers, etc.
+        if config.additional_vocab_size > 0:
+            self.embedding = ExtendedEmbedding(
+                config.vocab_size, config.additional_vocab_size, config.hidden_size
+            )
+        else:
+            self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
+
+        self.layers = [
+            TransformerBlock(config, layer_idx=i)
+            for i in range(config.num_hidden_layers)
+        ]
+        self.ln_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        inputs_embeds: Optional[mx.array] = None,
+        mask: Optional[mx.array] = None,
+        cache: Optional[List[KVCache]] = None,
+        **kwargs,
+    ) -> LanguageModelOutput:
+        if inputs_embeds is None:
+            x = self.embedding(inputs)
+        else:
+            x = inputs_embeds
+
+        # Initialize cache if needed
+        if cache is None:
+            cache = [None] * len(self.layers)
+
+        # Create causal attention mask
+        mask = create_attention_mask(x, cache)
+
+        for i, layer in enumerate(self.layers):
+            x = layer(x, mask=mask, cache=cache[i])
+
+        hidden_states = self.ln_f(x)
+        logits = self.lm_head(hidden_states)
+        return LanguageModelOutput(logits=logits)
mlx_vlm/models/jina_vlm/processing_jinavlm.py
@@ -0,0 +1,266 @@
+"""Processor for Jina VLM in MLX-VLM."""
+
+from typing import Dict, List, Literal, Optional, Union
+
+import mlx.core as mx
+import numpy as np
+from PIL import Image
+from transformers.processing_utils import ProcessorMixin
+
+from .image_processor import ImageProcessor
+
+
+class JinaVLMProcessor(ProcessorMixin):
+    """Processor for Jina VLM that combines tokenizer and image processor."""
+
+    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
+    attributes = ["tokenizer"]
+
+    def __init__(
+        self,
+        tokenizer,
+        image_token: str = "<|image|>",
+        chat_template: Optional[str] = None,
+        **kwargs,
+    ):
+        self.tokenizer = tokenizer
+        self.image_token = image_token
+        self._image_proc = ImageProcessor()  # Internal, not exposed as image_processor
+
+        # Get image token ID
+        self.image_token_id = self.tokenizer.convert_tokens_to_ids(image_token)
+
+        super().__init__(tokenizer, **kwargs)
+
+        # Set chat template AFTER super().__init__ - always set the default if not already set
+        default_chat_template = (
+            "{% for message in messages %}"
+            "{% if message['role'] == 'system' %}{{ '<|system|>\n' + message['content'] + '\n' }}"
+            "{% elif message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + '\n' }}"
+            "{% elif message['role'] == 'assistant' %}{{ '<|assistant|>\n' + message['content'] + '\n' }}"
+            "{% endif %}{% endfor %}"
+            "{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% endif %}"
+        )
+        if chat_template is not None:
+            self.tokenizer.chat_template = chat_template
+        elif not self.tokenizer.chat_template:
+            self.tokenizer.chat_template = default_chat_template
+
+    @property
+    def chat_template(self):
+        return self.tokenizer.chat_template
+
+    @chat_template.setter
+    def chat_template(self, value):
+        self.tokenizer.chat_template = value
+
+    @property
+    def pad_token(self):
+        return self.tokenizer.pad_token
+
+    @property
+    def pad_token_id(self):
+        return self.tokenizer.pad_token_id
+
+    @property
+    def eos_token(self):
+        return self.tokenizer.eos_token
+
+    @property
+    def eos_token_id(self):
+        return self.tokenizer.eos_token_id
+
+    @property
+    def bos_token(self):
+        return self.tokenizer.bos_token
+
+    @property
+    def bos_token_id(self):
+        return self.tokenizer.bos_token_id
+
+    def encode(self, text: str, add_special_tokens: bool = False) -> List[int]:
+        return self.tokenizer.encode(text, add_special_tokens=add_special_tokens)
+
+    def decode(self, token_ids: List[int], **kwargs) -> str:
+        return self.tokenizer.decode(token_ids, **kwargs)
+
+    def batch_decode(self, token_ids, **kwargs) -> List[str]:
+        return self.tokenizer.batch_decode(token_ids, **kwargs)
+
+    def process_one(
+        self,
+        prompt: str,
+        images: Optional[List[Image.Image]] = None,
+        inference_mode: bool = True,
+    ) -> Dict:
+        """Process a single prompt with images."""
+        if images is None:
+            images = []
+
+        # Process images
+        if images:
+            image_outputs = self._image_proc.preprocess(images)
+            pixel_values_list = image_outputs["pixel_values"]
+            image_tokens = image_outputs["image_tokens"]
+            image_input_idx_list = image_outputs["image_input_idx"]
+            image_masks_list = image_outputs["image_masks"]
+        else:
+            pixel_values_list = None
+            image_tokens = []
+            image_input_idx_list = None
+            image_masks_list = None
+
+        # Split prompt by image token
+        text_splits = prompt.split(self.image_token)
+
+        # Build input_ids with image tokens interleaved
+        input_ids = []
+        current_image_idx = 0
+        updated_image_input_idx = []
+
+        for i, text_part in enumerate(text_splits):
+            # Encode text part
+            if text_part:
+                text_tokens = self.encode(text_part, add_special_tokens=False)
+                input_ids.extend(text_tokens)
+
+            # Add image tokens if not the last split and we have images
+            if i < len(text_splits) - 1 and current_image_idx < len(image_tokens):
+                # Get image tokens for this image
+                img_tokens = image_tokens[current_image_idx]
+                # Offset image_input_idx by current position
+                if image_input_idx_list is not None and current_image_idx < len(
+                    image_input_idx_list
+                ):
+                    offset_idx = image_input_idx_list[current_image_idx] + len(
+                        input_ids
+                    )
+                    updated_image_input_idx.append(offset_idx)
+                input_ids.extend(img_tokens.tolist())
+                current_image_idx += 1
+
+        input_ids = mx.array(input_ids)
+
+        result = {
+            "input_ids": input_ids[None, :],  # Add batch dimension
+            "attention_mask": mx.ones_like(input_ids)[None, :],
+        }
+
+        if pixel_values_list is not None and len(pixel_values_list) > 0:
+            # Stack pixel values: (n_crops, n_patches, patch_dim)
+            result["pixel_values"] = mx.array(np.stack(pixel_values_list))
+            # Stack image_input_idx: (n_images, tokens_per_image)
+            result["image_input_idx"] = mx.array(np.stack(updated_image_input_idx))
+            # Stack image_masks: (n_crops, n_patches)
+            result["image_masks"] = mx.array(np.stack(image_masks_list))
+
+        return result
+
+    def __call__(
+        self,
+        text: Optional[Union[str, List[str]]] = None,
+        images: Optional[Union[Image.Image, List[Image.Image]]] = None,
+        inference_mode: bool = True,
+        return_tensors: Literal["np", "mx", "pt"] = "mx",
+        **kwargs,
+    ) -> Dict:
+        """Process text and images for Jina VLM.
+
+        When called with just text (like a tokenizer), returns tokenizer output.
+        When called with text and images, returns full processed inputs.
+
+        Args:
+            text: Input text or list of texts
+            images: Input image or list of images
+            inference_mode: Whether in inference mode
+            return_tensors: Type of tensors to return
+
+        Returns:
+            Dictionary containing processed inputs
+        """
+        # If called with just text (like a tokenizer), delegate to tokenizer
+        if text is not None and images is None:
+            return self.tokenizer(text, **kwargs)
+
+        if text is None:
+            raise ValueError("Text must be provided")
+
+        # Normalize inputs
+        if isinstance(text, str):
+            texts = [text]
+        else:
+            texts = text
+
+        if images is None:
+            images_list = [None] * len(texts)
+        elif isinstance(images, Image.Image):
+            images_list = [[images]]
+        elif isinstance(images, list) and len(images) > 0:
+            if isinstance(images[0], Image.Image):
+                # Single list of images for single prompt
+                images_list = [images]
+            else:
+                images_list = images
+        else:
+            images_list = [None] * len(texts)
+
+        # Process each text-image pair
+        batch_results = []
+        for prompt, imgs in zip(texts, images_list):
+            result = self.process_one(prompt, imgs, inference_mode)
+            batch_results.append(result)
+
+        # Collate results
+        if len(batch_results) == 1:
+            return batch_results[0]
+        else:
+            return self._collate_batch(batch_results)
+
+    def _collate_batch(self, batch_results: List[Dict]) -> Dict:
+        """Collate multiple results into a batch."""
+        # Get max sequence length
+        max_len = max(r["input_ids"].shape[1] for r in batch_results)
+
+        padded_input_ids = []
+        padded_attention_mask = []
+
+        for r in batch_results:
+            seq_len = r["input_ids"].shape[1]
+            pad_len = max_len - seq_len
+
+            if pad_len > 0:
+                input_ids = mx.concatenate(
+                    [mx.full((1, pad_len), self.pad_token_id), r["input_ids"]], axis=1
+                )
+                attention_mask = mx.concatenate(
+                    [mx.zeros((1, pad_len)), r["attention_mask"]], axis=1
+                )
+            else:
+                input_ids = r["input_ids"]
+                attention_mask = r["attention_mask"]
+
+            padded_input_ids.append(input_ids)
+            padded_attention_mask.append(attention_mask)
+
+        result = {
+            "input_ids": mx.concatenate(padded_input_ids, axis=0),
+            "attention_mask": mx.concatenate(padded_attention_mask, axis=0),
+        }
+
+        # Combine pixel values if present
+        all_pixel_values = []
+        all_image_input_idx = []
+        all_image_masks = []
+
+        for r in batch_results:
+            if "pixel_values" in r:
+                all_pixel_values.append(r["pixel_values"])
+                all_image_input_idx.append(r["image_input_idx"])
+                all_image_masks.append(r["image_masks"])
+
+        if all_pixel_values:
+            result["pixel_values"] = mx.concatenate(all_pixel_values, axis=0)
+            result["image_input_idx"] = mx.concatenate(all_image_input_idx, axis=0)
+            result["image_masks"] = mx.concatenate(all_image_masks, axis=0)
+
+        return result