fount-vlm-nell-02 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
from typing import Dict, Optional
|
|
2
|
+
|
|
3
|
+
import mlx.core as mx
|
|
4
|
+
import mlx.nn as nn
|
|
5
|
+
|
|
6
|
+
from ..base import InputEmbeddingsFeatures, check_array_shape
|
|
7
|
+
from .config import ModelConfig
|
|
8
|
+
from .language import LanguageModel
|
|
9
|
+
from .vision import VisionModel
|
|
10
|
+
|
|
11
|
+
# Register the HunYuan-VL processor classes with the transformers Auto*
# factories as an import-time side effect. Any failure (missing transformers,
# registration error) propagates to the importer unchanged; the previous
# `except Exception as e: raise e` wrapper did nothing but add a traceback
# frame, so it has been removed.
from transformers import AutoImageProcessor, AutoProcessor

from .processing_hunyuan_vl import HunYuanVLImageProcessor, HunYuanVLProcessor

MODEL_TYPE = "hunyuan_vl"

AutoImageProcessor.register(
    MODEL_TYPE, slow_image_processor_class=HunYuanVLImageProcessor
)
AutoProcessor.register(MODEL_TYPE, HunYuanVLProcessor)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class Model(nn.Module):
    """HunYuan-VL model: a vision tower whose features feed a language model.

    Image placeholder tokens in ``input_ids`` are replaced by vision-tower
    features before the sequence is handed to the language model.
    """

    def __init__(self, config: ModelConfig):
        """Build the vision tower and the language model from ``config``."""
        super().__init__()
        self.config = config
        self.model_type = config.model_type
        self.vision_tower = VisionModel(config.vision_config)
        self.language_model = LanguageModel(config)

    def get_input_embeddings(
        self,
        input_ids: Optional[mx.array] = None,
        pixel_values: Optional[mx.array] = None,
        **kwargs,
    ) -> InputEmbeddingsFeatures:
        """Embed ``input_ids`` and splice vision features over image tokens.

        Args:
            input_ids: Token ids; positions equal to ``config.image_token_id``
                are treated as image placeholders.
            pixel_values: Image input for the vision tower. When ``None`` the
                plain text embeddings are returned.
            **kwargs: ``image_grid_thw`` (forwarded to the vision tower and
                used to derive position ids) and ``position_ids`` (used as-is
                when supplied) are consumed; anything else is ignored.

        Returns:
            ``InputEmbeddingsFeatures`` wrapping the merged embeddings.

        Raises:
            ValueError: If the number of image placeholder tokens differs
                from the number of vision tokens produced by the tower.

        Side effects:
            Sets ``self.language_model._position_ids`` (reset to ``None`` for
            text-only input).
        """
        image_grid_thw = kwargs.pop("image_grid_thw", None)
        position_ids_from_processor = kwargs.pop("position_ids", None)

        # Get text embeddings
        inputs_embeds = self.language_model.model.embed_tokens(input_ids)

        # If no image, return text embeddings
        if pixel_values is None:
            # Reset stored position_ids when no image
            self.language_model._position_ids = None
            return InputEmbeddingsFeatures(inputs_embeds=inputs_embeds)

        # Get vision features
        vision_features = self.vision_tower(pixel_values, image_grid_thw)

        # Find image token positions and replace with vision features
        image_token_id = self.config.image_token_id
        image_mask = input_ids == image_token_id

        # Get number of image tokens expected
        num_image_tokens = image_mask.sum().item()
        num_vision_tokens = vision_features.shape[1]

        if num_image_tokens != num_vision_tokens:
            raise ValueError(
                f"Number of image placeholders ({num_image_tokens}) does not match "
                f"number of vision tokens ({num_vision_tokens}). "
                f"Expected token count based on grid: {num_vision_tokens}"
            )

        B, L, _ = inputs_embeds.shape

        output_parts = []

        for b in range(B):
            # Pull the whole row mask to the host once; calling `.item()` per
            # token would force one device sync per sequence position.
            is_image = image_mask[b].tolist()
            text_embeds_b = inputs_embeds[b]  # (L, D)
            vis_feats_b = vision_features[b]  # (num_vis_tokens, D)

            # Build the sequence for this row from contiguous runs of
            # same-kind tokens, consuming vision features in order.
            vis_idx = 0
            seq_parts = []
            pos = 0
            while pos < L:
                run_end = pos + 1
                while run_end < L and is_image[run_end] == is_image[pos]:
                    run_end += 1
                run_len = run_end - pos

                if is_image[pos]:
                    # Use vision features
                    seq_parts.append(vis_feats_b[vis_idx : vis_idx + run_len])
                    vis_idx += run_len
                else:
                    # Use text embeddings
                    seq_parts.append(text_embeds_b[pos:run_end])

                pos = run_end

            # Concatenate all parts for this row
            batch_embeds = mx.concatenate(seq_parts, axis=0)  # (L, D)
            output_parts.append(batch_embeds[None, :, :])  # (1, L, D)

        # Stack rows
        inputs_embeds = mx.concatenate(output_parts, axis=0)  # (B, L, D)

        # Pre-calculate position_ids for chunked prefill
        if position_ids_from_processor is not None:
            self.language_model._position_ids = position_ids_from_processor
        elif image_grid_thw is not None:
            # Only the first row of input_ids is used to derive positions.
            position_ids = self.language_model.get_xdrope_input_positions(
                input_tokens=input_ids[0].tolist(),
                image_grid_thw=image_grid_thw,
                image_token_id=self.config.image_token_id,
                spatial_merge_size=self.config.vision_config.spatial_merge_size,
            )[None, ...]
            self.language_model._position_ids = position_ids

        return InputEmbeddingsFeatures(inputs_embeds=inputs_embeds)

    @property
    def layers(self):
        """Decoder layers of the wrapped language model."""
        return self.language_model.model.layers

    @property
    def head_dim(self):
        """Per-head dimension taken from the text config."""
        return self.config.text_config.head_dim

    @property
    def n_kv_heads(self):
        """Number of key/value heads taken from the text config."""
        return self.config.text_config.num_key_value_heads

    def __call__(
        self,
        input_ids: mx.array,
        pixel_values: Optional[mx.array] = None,
        mask: Optional[mx.array] = None,
        cache=None,
        **kwargs,
    ):
        """Run the full model: merge embeddings, then call the language model.

        Args:
            input_ids: Token ids.
            pixel_values: Optional image input for the vision tower.
            mask: Optional attention mask forwarded to the language model.
            cache: Optional KV cache forwarded to the language model.
            **kwargs: Forwarded to ``get_input_embeddings``; ``image_grid_thw``
                is additionally passed on to the language model.

        Returns:
            Whatever ``self.language_model`` returns.
        """
        # Read (not pop) so get_input_embeddings still receives the value via
        # **kwargs. Previously this name was never bound in this scope and the
        # language-model call below raised NameError.
        image_grid_thw = kwargs.get("image_grid_thw", None)

        # Get embeddings (with vision features merged if image provided)
        input_embeddings_features = self.get_input_embeddings(
            input_ids=input_ids,
            pixel_values=pixel_values,
            **kwargs,
        )

        # Forward through language model
        return self.language_model(
            input_ids=input_ids,
            inputs_embeds=input_embeddings_features.inputs_embeds,
            mask=mask,
            cache=cache,
            image_grid_thw=image_grid_thw,
        )

    def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
        """Rename checkpoint keys to this module's layout and fix conv weights.

        Args:
            weights: Mapping from checkpoint parameter name to array.

        Returns:
            A new dict with ``model.*`` keys prefixed by ``language_model.``,
            ``vit.*`` keys moved under ``vision_tower.``, and selected conv
            weights transposed to channels-last when needed.
        """
        sanitized = {}

        for key, value in weights.items():
            new_key = key

            # Language model mappings
            if key.startswith("model."):
                new_key = "language_model." + key

            # Vision tower mappings
            elif key.startswith("vit."):
                new_key = key.replace("vit.", "vision_tower.", 1)

            # Handle Conv2d weight transposition for MLX
            # PyTorch Conv2d: [out_channels, in_channels, kH, kW]
            # MLX Conv2d: [out_channels, kH, kW, in_channels]
            if (
                "patch_embedding.weight" in new_key
                or "proj.0.weight" in new_key
                or "proj.2.weight" in new_key
            ):
                # check_array_shape presumably reports whether the weight is
                # already in MLX layout -- TODO confirm against ..base.
                if not check_array_shape(value):
                    value = value.transpose(0, 2, 3, 1)

            sanitized[new_key] = value

        return sanitized
|
|
@@ -0,0 +1,509 @@
|
|
|
1
|
+
from typing import List, Optional, Tuple
|
|
2
|
+
|
|
3
|
+
import mlx.core as mx
|
|
4
|
+
import mlx.nn as nn
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from ..base import (
|
|
8
|
+
LanguageModelOutput,
|
|
9
|
+
create_attention_mask,
|
|
10
|
+
scaled_dot_product_attention,
|
|
11
|
+
)
|
|
12
|
+
from ..cache import KVCache
|
|
13
|
+
from .config import ModelConfig, TextConfig
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class HunyuanRotaryEmbedding:
    """Lazily cached cos/sin tables for rotary position embeddings.

    The rotary base is ``config.rope_theta``, optionally rescaled by the
    ``alpha`` entry of ``config.rope_scaling`` when the scaling type is
    ``"xdrope"`` or ``"dynamic"``.
    """

    def __init__(self, config: TextConfig):
        """Derive inverse frequencies from ``config`` and reset the cache."""
        super().__init__()
        self.config = config
        self.dim = config.head_dim
        self.max_position_embeddings = config.max_position_embeddings
        self.base = config.rope_theta

        # Handle xdrope/dynamic scaling. rope_scaling may be None; calling
        # `.get` on it directly raised AttributeError before the None check
        # could run, so read through an empty-dict fallback instead.
        rope_scaling = config.rope_scaling if config.rope_scaling is not None else {}
        self.xdrope_section = rope_scaling.get("xdrope_section")
        self.rope_type = rope_scaling.get("type")
        alpha = rope_scaling.get("alpha")

        if self.rope_type in ["xdrope", "dynamic"] and alpha:
            self.base = self.base * (alpha ** (self.dim / (self.dim - 2)))

        inv_freq = 1.0 / (
            self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim)
        )
        self._inv_freq = inv_freq
        self._cos_cached = None
        self._sin_cached = None
        self._cached_seq_len = 0

    def _update_cache(self, seq_len: int, dtype: mx.Dtype):
        """Rebuild the cos/sin tables when they are too short or wrong dtype.

        The cache only ever grows: a rebuild triggered by a dtype change keeps
        at least the previously cached length.
        """
        stale = (
            self._cos_cached is None
            or seq_len > self._cached_seq_len
            or self._cos_cached.dtype != dtype
        )
        if stale:
            seq_len = max(seq_len, self._cached_seq_len)
            self._cached_seq_len = seq_len
            t = mx.arange(seq_len, dtype=mx.float32)
            freqs = mx.outer(t, self._inv_freq)
            # Duplicate the frequencies so the table spans the full head dim.
            emb = mx.concatenate([freqs, freqs], axis=-1)
            self._cos_cached = mx.cos(emb).astype(dtype)
            self._sin_cached = mx.sin(emb).astype(dtype)

    def __call__(self, x: mx.array, seq_len: int) -> Tuple[mx.array, mx.array]:
        """Return ``(cos, sin)`` tables for positions ``0..seq_len-1``.

        Args:
            x: Only its dtype is used, to pick the dtype of the tables.
            seq_len: Number of positions required.

        Returns:
            Tuple of cos and sin arrays, each sliced to ``seq_len`` rows.
        """
        self._update_cache(seq_len, x.dtype)
        return self._cos_cached[:seq_len], self._sin_cached[:seq_len]
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def rotate_half(x: mx.array) -> mx.array:
    """Swap the two halves of the last axis, negating the second half.

    For ``x = [a, b]`` split at the midpoint of the last axis, the result is
    ``[-b, a]``.
    """
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return mx.concatenate([-back, front], axis=-1)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def apply_rotary_pos_emb_xdrope(
    q: mx.array,
    k: mx.array,
    cos: mx.array,
    sin: mx.array,
    position_ids: mx.array,
    xdrope_section: list,
    output_size: tuple,
) -> Tuple[mx.array, mx.array]:
    """Apply XD rotary position embedding to queries and keys.

    Args:
        q: Query array.
        k: Key array.
        cos: Cosine table indexed by position.
        sin: Sine table indexed by position.
        position_ids: Position indices used to gather rows from the tables;
            presumably (batch, n_axes, seq) -- TODO confirm against caller.
        xdrope_section: Width of each positional axis' slice of the rotary
            dimension; doubled internally to cover the full table width.
        output_size: Tuple whose element 0 is the batch size and element 2
            the sequence length.

    Returns:
        ``(q_rotated, k_rotated)`` in the original dtype of ``q``.
    """
    batch, seq = output_size[0], output_size[2]
    n_axes = len(xdrope_section)

    def _gather(table: mx.array) -> mx.array:
        # Look up per-axis rows, then bring the layout to (batch, seq, axis, dim).
        picked = table[position_ids, ...]
        picked = picked.transpose(0, 2, 1, 3)
        return picked.reshape(batch, seq, n_axes, -1)

    cos = _gather(cos)
    sin = _gather(sin)

    doubled_sections = xdrope_section * 2

    # for xd concat
    assert sum(doubled_sections) == cos.shape[-1], "Illegal partition for xd rope"

    # mx.split takes cut positions rather than chunk sizes: accumulate widths.
    boundaries = []
    running_total = 0
    for width in doubled_sections[:-1]:
        running_total += width
        boundaries.append(running_total)

    def _interleave(table: mx.array) -> mx.array:
        # Chunk i takes its values from positional axis (i mod n_axes).
        chunks = mx.split(table, boundaries, axis=-1)
        chosen = [chunk[:, :, idx % n_axes, :] for idx, chunk in enumerate(chunks)]
        merged = mx.concatenate(chosen, axis=-1)
        # for head repeat
        return merged.reshape(batch, 1, seq, -1)

    cos = _interleave(cos)
    sin = _interleave(sin)

    # Rotate in float32, then cast back to the input dtype.
    result_dtype = q.dtype
    q32 = q.astype(mx.float32)
    k32 = k.astype(mx.float32)
    cos32 = cos.astype(mx.float32)
    sin32 = sin.astype(mx.float32)

    q_rot = (q32 * cos32) + (rotate_half(q32) * sin32)
    k_rot = (k32 * cos32) + (rotate_half(k32) * sin32)

    return q_rot.astype(result_dtype), k_rot.astype(result_dtype)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def apply_rotary_pos_emb(
    q: mx.array, k: mx.array, cos: mx.array, sin: mx.array, unsqueeze_dim: int = 1
) -> Tuple[mx.array, mx.array]:
    """Apply standard rotary position embedding to queries and keys.

    Args:
        q: Queries with shape (batch, n_heads, seq_len, head_dim)
        k: Keys with shape (batch, n_heads, seq_len, head_dim)
        cos: Cosine values with shape (seq_len, head_dim)
        sin: Sine values with shape (seq_len, head_dim)
        unsqueeze_dim: Accepted for signature compatibility; not used.

    Returns:
        ``(q_embed, k_embed)`` with the rotation applied.
    """
    # Add leading batch and head axes so the tables broadcast over q and k.
    cos_b = cos[None, None, :, :]
    sin_b = sin[None, None, :, :]

    def _rotate(t: mx.array) -> mx.array:
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class Attention(nn.Module):
    """Multi-head attention with rotary embeddings and optional QK norm.

    Uses XD-RoPE on the first forward pass (no cache, or cache offset 0) when
    ``rope_scaling`` provides an ``xdrope_section`` and position ids are
    given; otherwise standard RoPE.
    """

    def __init__(self, config: TextConfig):
        """Create the projections, optional QK norms and rotary embedding."""
        super().__init__()
        self.config = config

        self.hidden_size = config.hidden_size
        self.n_heads = config.num_attention_heads
        self.n_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        self.scale = self.head_dim**-0.5

        self.q_proj = nn.Linear(
            self.hidden_size, self.n_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            self.hidden_size,
            self.n_kv_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            self.hidden_size,
            self.n_kv_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            self.n_heads * self.head_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )

        # The norm layers only exist when use_qk_norm is set; __call__ checks
        # the same flag before touching them.
        if config.use_qk_norm:
            self.query_layernorm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.key_layernorm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)

        self.rotary_emb = HunyuanRotaryEmbedding(config=config)

        self.xdrope_section = None
        if config.rope_scaling is not None:
            self.xdrope_section = config.rope_scaling.get("xdrope_section")

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[KVCache] = None,
        position_ids: Optional[mx.array] = None,
    ) -> mx.array:
        """Compute self-attention over ``x``.

        Args:
            x: Hidden states of shape (B, L, hidden).
            mask: Optional attention mask; array masks are trimmed to the
                key length after the cache update.
            cache: Optional KV cache; its ``offset`` gives the number of
                positions already stored.
            position_ids: Position ids for XD-RoPE. When ``None``, standard
                RoPE is used even on the first pass.

        Returns:
            Attention output of shape (B, L, hidden) after ``o_proj``.
        """
        B, L, _ = x.shape

        # Project Q, K, V
        queries = self.q_proj(x)
        keys = self.k_proj(x)
        values = self.v_proj(x)

        # Reshape to (B, n_heads, L, head_dim)
        queries = queries.reshape(B, L, self.n_heads, self.head_dim).transpose(
            0, 2, 1, 3
        )
        keys = keys.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(
            0, 2, 1, 3
        )

        kv_seq_len = L
        offset = 0
        if cache is not None:
            offset = cache.offset
            kv_seq_len += offset

        # `values` is passed only so the tables match its dtype.
        cos, sin = self.rotary_emb(values, seq_len=kv_seq_len)

        # Apply rotary embeddings
        is_prefill = cache is None or offset == 0
        use_xdrope = (
            self.xdrope_section is not None
            # Without explicit position ids the XD gather cannot run (indexing
            # the tables with None only adds an axis and the 4-axis transpose
            # that follows fails). Standard RoPE is what XD-RoPE computes when
            # every axis uses plain sequential positions, so fall back to it.
            and position_ids is not None
            and is_prefill
        )
        if use_xdrope:
            # XD RoPE for prefill (first forward pass)
            output_size = (B, self.n_heads, L, L)
            queries, keys = apply_rotary_pos_emb_xdrope(
                queries,
                keys,
                cos,
                sin,
                position_ids,
                self.xdrope_section,
                output_size,
            )
        else:
            # Standard RoPE for decode (subsequent tokens)
            if cache is not None and offset > 0:
                # Keep only the rows for the L new positions.
                cos = cos[-L:]
                sin = sin[-L:]
            queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin)

        # Apply QK normalization if configured
        if self.config.use_qk_norm:
            queries = self.query_layernorm(queries)
            keys = self.key_layernorm(keys)

        # Update cache
        if cache is not None:
            keys, values = cache.update_and_fetch(keys, values)

        # Apply mask
        if mask is not None and isinstance(mask, mx.array):
            mask = mask[..., : keys.shape[-2]]

        output = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
class MLP(nn.Module):
    """Gated feed-forward block: ``down(silu(gate(x)) * up(x))``."""

    def __init__(self, config: TextConfig):
        """Create the gate, up and down projections from ``config``."""
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        hidden, inner = self.hidden_size, self.intermediate_size
        use_bias = config.mlp_bias

        self.gate_proj = nn.Linear(hidden, inner, bias=use_bias)
        self.up_proj = nn.Linear(hidden, inner, bias=use_bias)
        self.down_proj = nn.Linear(inner, hidden, bias=use_bias)

    def __call__(self, x: mx.array) -> mx.array:
        """Apply the gated projection to ``x`` and project back down."""
        gate = nn.silu(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
class DecoderLayer(nn.Module):
    """Pre-norm transformer block: self-attention then MLP, each residual."""

    def __init__(self, config: TextConfig):
        """Create the attention, MLP and the two RMSNorm layers."""
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Attention(config)
        self.mlp = MLP(config)

        width, eps = config.hidden_size, config.rms_norm_eps
        self.input_layernorm = nn.RMSNorm(width, eps=eps)
        self.post_attention_layernorm = nn.RMSNorm(width, eps=eps)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[KVCache] = None,
        position_ids: Optional[mx.array] = None,
    ) -> mx.array:
        """Run one decoder block.

        Args:
            x: Input hidden states.
            mask: Optional attention mask passed to self-attention.
            cache: Optional KV cache passed to self-attention.
            position_ids: Optional position ids passed to self-attention.

        Returns:
            Hidden states after both residual sub-layers.
        """
        # Attention sub-layer: normalise, attend, add back the input.
        normed_input = self.input_layernorm(x)
        attn_out = self.self_attn(normed_input, mask, cache, position_ids)
        hidden = x + attn_out

        # Feed-forward sub-layer on the post-attention state.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden))
        return hidden + mlp_out
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
class HunyuanModel(nn.Module):
    """Decoder stack: token embedding, decoder layers and a final RMSNorm."""

    def __init__(self, config: TextConfig):
        """Create the embedding table, decoder layers and output norm."""
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.num_hidden_layers = config.num_hidden_layers

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = [DecoderLayer(config) for _ in range(config.num_hidden_layers)]
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def __call__(
        self,
        input_ids: Optional[mx.array] = None,
        inputs_embeds: Optional[mx.array] = None,
        mask: Optional[mx.array] = None,
        cache=None,
        position_ids: Optional[mx.array] = None,
    ) -> mx.array:
        """Run the decoder stack.

        Args:
            input_ids: Token ids; embedded only when ``inputs_embeds`` is None.
            inputs_embeds: Precomputed embeddings, used as-is when given.
            mask: Attention mask; built with ``create_attention_mask`` when
                None.
            cache: Per-layer caches; a list of ``None`` is used when omitted.
            position_ids: Optional position ids forwarded to every layer.

        Returns:
            Normalised hidden states from the last layer.
        """
        hidden = self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds

        layer_caches = cache
        if layer_caches is None:
            # One placeholder per layer so the zip below pairs up evenly.
            layer_caches = [None] * len(self.layers)

        attn_mask = mask
        if attn_mask is None:
            attn_mask = create_attention_mask(hidden, layer_caches)

        for block, block_cache in zip(self.layers, layer_caches):
            hidden = block(hidden, attn_mask, block_cache, position_ids)

        return self.norm(hidden)
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
class LanguageModel(nn.Module):
|
|
337
|
+
def __init__(self, config: ModelConfig = None):
|
|
338
|
+
super().__init__()
|
|
339
|
+
self.args = config.text_config
|
|
340
|
+
self.config = config
|
|
341
|
+
self.model_type = self.args.model_type
|
|
342
|
+
self.model = HunyuanModel(self.args)
|
|
343
|
+
self._position_ids = None
|
|
344
|
+
|
|
345
|
+
if not self.args.tie_word_embeddings:
|
|
346
|
+
self.lm_head = nn.Linear(
|
|
347
|
+
self.args.hidden_size, self.args.vocab_size, bias=False
|
|
348
|
+
)
|
|
349
|
+
|
|
350
|
+
def get_xdrope_input_positions(
    self,
    input_tokens: List[int],
    image_grid_thw: Optional[mx.array],
    image_token_id: int,
    spatial_merge_size: int,
) -> mx.array:
    """Compute XD-RoPE position IDs for image-text interleaved inputs.

    Every axis starts as the sequential index ``0 .. seq_len - 1``. For each
    position where ``input_tokens`` equals ``image_token_id``, the span that
    begins one token later is overwritten on the w / h / t axes with the
    column index, row index and image index of the image grid. The span is
    ``(grid_w + 1) * grid_h`` tokens long and is clipped to the sequence end.

    Args:
        input_tokens: Token ids of a single sequence.
        image_grid_thw: Per-image ``(t, h, w)`` grid sizes, or None for
            text-only input. It is flattened and regrouped into rows of
            three, so one row per image is expected.
        image_token_id: Token id whose occurrences mark where an image span
            starts.
        spatial_merge_size: Divisor applied to ``h`` and ``w`` to obtain the
            grid size in LLM tokens.

    Returns:
        Array of shape ``(xd_num, seq_len)``, where ``xd_num`` is the length
        of ``rope_scaling["xdrope_section"]``. Axis order is ``(p, t, h, w)``
        for 4 sections and ``(t, h, w)`` for 3; any other count repeats the
        sequential index.

    Raises:
        ValueError: If image markers are present but ``image_grid_thw`` is
            empty or its size is not a multiple of three.
    """
    xd_num = len(self.args.rope_scaling["xdrope_section"])

    seq_len = len(input_tokens)
    marker_positions = np.where(np.array(input_tokens) == image_token_id)[
        0
    ].tolist()

    p_index = np.arange(seq_len)
    w_index = np.arange(seq_len)
    h_index = np.arange(seq_len)
    t_index = np.arange(seq_len)

    if image_grid_thw is not None and marker_positions:
        # One (t, h, w) row per image. The previous code unpacked the whole
        # flattened array into three names, which raised for more than one
        # image and could never give images different grids.
        grids = np.array(image_grid_thw.flatten().tolist()).reshape(-1, 3)
        if grids.shape[0] == 0:
            raise ValueError(
                "image_grid_thw is empty but image tokens are present"
            )
        last_row = grids.shape[0] - 1

        for image_index, marker in enumerate(marker_positions):
            # +1: start writing at the token right after the matched one.
            pos = int(marker) + 1

            # Reuse the last row when there are more matches than grid rows,
            # so single-grid inputs behave exactly as before.
            _, grid_h, grid_w = (
                int(v) for v in grids[min(image_index, last_row)]
            )

            llm_grid_h = grid_h // spatial_merge_size
            llm_grid_w = grid_w // spatial_merge_size

            # Each row contributes llm_grid_w + 1 positions (the extra slot
            # is presumably a per-row separator token; confirm against the
            # processor).
            token_num = (llm_grid_w + 1) * llm_grid_h

            # Clip so the span never runs past the end of the sequence.
            end_pos = min(pos + token_num, seq_len)
            actual_token_num = end_pos - pos
            if actual_token_num <= 0:
                continue

            # w axis: 0..grid_w repeated once per row.
            w_index[pos:end_pos] = np.tile(
                np.arange(llm_grid_w + 1), llm_grid_h
            )[:actual_token_num]

            # h axis: each row index repeated (grid_w + 1) times.
            h_index[pos:end_pos] = np.repeat(
                np.arange(llm_grid_h), llm_grid_w + 1
            )[:actual_token_num]

            # t axis: index of the image within the sequence.
            t_index[pos:end_pos] = image_index

    # Axis selection depends on the number of XD-RoPE sections.
    if xd_num == 4:
        axes = [p_index, t_index, h_index, w_index]
    elif xd_num == 3:
        axes = [t_index, h_index, w_index]
    else:
        # Fallback: sequential positions on every axis.
        axes = [p_index] * xd_num

    return mx.stack([mx.array(axis) for axis in axes])
|
|
425
|
+
|
|
426
|
+
def __call__(
    self,
    inputs: Optional[mx.array] = None,
    inputs_embeds: Optional[mx.array] = None,
    mask: Optional[mx.array] = None,
    cache=None,
    **kwargs,
) -> LanguageModelOutput:
    """Run the language model and return logits.

    XD-RoPE position ids are only resolved when there is no cache or the
    first cache entry reports an offset of 0; otherwise ``position_ids=None``
    is passed to the backbone. When resolving, the sources are tried in this
    order: ids already stored on the instance, a ``position_ids`` keyword
    argument, then ids computed from ``inputs``. The latter two are stored on
    the instance for later calls.

    Args:
        inputs: Token ids; axis 1 is read as the sequence length.
        inputs_embeds: Precomputed embeddings; when given, their axis 1 is
            used as the sequence length instead.
        mask: Attention mask forwarded to the backbone.
        cache: Per-layer cache list; only the first entry's ``offset`` is
            inspected here.
        **kwargs: ``position_ids`` and ``image_grid_thw`` are consumed; other
            keys are ignored.

    Returns:
        ``LanguageModelOutput`` holding the logits.
    """
    explicit_position_ids = kwargs.pop("position_ids", None)

    # Offset reported by the first layer's cache, normalised to a plain number.
    cache_offset = 0
    first_cache = cache[0] if cache is not None else None
    if first_cache is not None:
        raw_offset = first_cache.offset
        if isinstance(raw_offset, int):
            cache_offset = raw_offset
        elif isinstance(raw_offset, mx.array):
            scalar = raw_offset if raw_offset.ndim == 0 else raw_offset[0]
            cache_offset = scalar.item()
        else:
            cache_offset = int(raw_offset)

    # Sequence length comes from the embeddings when present, else the ids.
    length_source = inputs_embeds if inputs_embeds is not None else inputs
    seq_length = 0 if length_source is None else length_source.shape[1]

    position_ids = None
    needs_xdrope_ids = cache is None or cache_offset == 0
    if needs_xdrope_ids:
        window = slice(cache_offset, cache_offset + seq_length)
        if self._position_ids is not None:
            # Ids stored earlier take priority; slice the current window.
            position_ids = self._position_ids[:, :, window]
        elif explicit_position_ids is not None:
            # Ids supplied by the caller (e.g. from the processor).
            if not isinstance(explicit_position_ids, mx.array):
                explicit_position_ids = mx.array(explicit_position_ids)
            self._position_ids = explicit_position_ids
            position_ids = self._position_ids[:, :, window]
        elif inputs is not None:
            # Nothing stored or supplied: derive ids from the first sequence
            # and add a leading axis.
            computed = self.get_xdrope_input_positions(
                input_tokens=inputs[0].tolist(),
                image_grid_thw=kwargs.get("image_grid_thw", None),
                image_token_id=self.config.image_token_id,
                spatial_merge_size=self.config.vision_config.spatial_merge_size,
            )
            position_ids = computed[None, ...]
            self._position_ids = position_ids

    hidden = self.model(
        input_ids=inputs,
        inputs_embeds=inputs_embeds,
        mask=mask,
        cache=cache,
        position_ids=position_ids,
    )

    # Tied embeddings reuse the embedding matrix as the output projection.
    project = (
        self.model.embed_tokens.as_linear
        if self.args.tie_word_embeddings
        else self.lm_head
    )

    return LanguageModelOutput(logits=project(hidden))
|
|
498
|
+
|
|
499
|
+
@property
def layers(self):
    """Decoder layers of the wrapped backbone (``self.model.layers``)."""
    backbone = self.model
    return backbone.layers
|
|
502
|
+
|
|
503
|
+
@property
def head_dim(self):
    """Value of ``head_dim`` from the text config."""
    text_config = self.args
    return text_config.head_dim
|
|
506
|
+
|
|
507
|
+
@property
def n_kv_heads(self):
    """Value of ``num_key_value_heads`` from the text config."""
    text_config = self.args
    return text_config.num_key_value_heads
|