nexaai 1.0.19rc7__cp310-cp310-macosx_14_0_universal2.whl → 1.0.19rc8__cp310-cp310-macosx_14_0_universal2.whl
This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of nexaai might be problematic.
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/libnexa_bridge.dylib +0 -0
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/METADATA +1 -1
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/RECORD +7 -196
- nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +0 -12
- nexaai/binds/nexa_mlx/py-lib/asr/interface.py +0 -122
- nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/common/utils.py +0 -25
- nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/cv/generate.py +0 -195
- nexaai/binds/nexa_mlx/py-lib/cv/interface.py +0 -151
- nexaai/binds/nexa_mlx/py-lib/cv/main.py +0 -81
- nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +0 -1736
- nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +0 -333
- nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +0 -617
- nexaai/binds/nexa_mlx/py-lib/embedding/main.py +0 -173
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +0 -399
- nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +0 -1
- nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +0 -244
- nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +0 -82
- nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +0 -281
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +0 -306
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +0 -116
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +0 -65
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +0 -386
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +0 -105
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +0 -100
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +0 -460
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +0 -274
- nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/llm/generate.py +0 -149
- nexaai/binds/nexa_mlx/py-lib/llm/interface.py +0 -764
- nexaai/binds/nexa_mlx/py-lib/llm/main.py +0 -68
- nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +0 -174
- nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +0 -287
- nexaai/binds/nexa_mlx/py-lib/rerank/main.py +0 -127
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +0 -330
- nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +0 -1
- nexaai/binds/nexa_mlx/py-lib/sd/interface.py +0 -362
- nexaai/binds/nexa_mlx/py-lib/sd/main.py +0 -286
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +0 -306
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +0 -116
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +0 -65
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +0 -385
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +0 -105
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +0 -100
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +0 -460
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +0 -274
- nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +0 -12
- nexaai/binds/nexa_mlx/py-lib/tts/interface.py +0 -276
- nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +0 -3
- nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +0 -572
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +0 -294
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +0 -276
- nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +0 -504
- nexaai/binds/nexa_mlx/py-lib/vlm/main.py +0 -320
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +0 -68
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +0 -193
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +0 -186
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +0 -233
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +0 -503
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +0 -202
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +0 -10
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +0 -264
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +0 -472
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +0 -591
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +0 -526
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +0 -356
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +0 -366
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +0 -488
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +0 -591
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +0 -213
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +0 -315
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +0 -238
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +0 -1038
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +0 -139
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +0 -322
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +0 -629
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +0 -1022
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +0 -294
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +0 -191
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +0 -267
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +0 -175
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +0 -192
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +0 -233
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +0 -140
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +0 -220
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +0 -393
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +0 -293
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +0 -307
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +0 -143
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +0 -509
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +0 -522
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +0 -386
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +0 -138
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +0 -560
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +0 -240
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +0 -153
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +0 -259
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +0 -236
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +0 -256
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +0 -283
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +0 -416
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +0 -172
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +0 -499
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +0 -133
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +0 -465
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +0 -10
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +0 -385
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +0 -557
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +0 -526
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +0 -282
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +0 -242
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +0 -21
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +0 -71
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +0 -324
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +0 -229
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +0 -161
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +0 -320
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +0 -108
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +0 -490
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +0 -168
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +0 -414
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +0 -104
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +0 -490
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +0 -167
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +0 -312
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +0 -117
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +0 -531
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +0 -701
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +0 -255
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +0 -407
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +0 -476
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +0 -1223
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +0 -117
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +0 -531
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +0 -701
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +0 -255
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +0 -407
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +0 -476
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +0 -1309
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +0 -210
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +0 -62
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +0 -209
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +0 -215
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +0 -474
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +0 -39
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +0 -344
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +0 -70
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +0 -296
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +0 -928
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/WHEEL +0 -0
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/top_level.txt +0 -0
@@ -1,490 +0,0 @@
-from typing import Optional
-
-import mlx.core as mx
-import mlx.nn as nn
-import numpy as np
-
-from ..base import (
-    LanguageModelOutput,
-    create_attention_mask,
-    scaled_dot_product_attention,
-)
-from ..cache import KVCache
-from .config import ModelConfig, TextConfig
-
-
-class Qwen2RotaryEmbedding:
-    def __init__(self, dim, max_position_embeddings=2048, base=10000):
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-
-        inv_freq = 1.0 / (
-            self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim)
-        )
-        self.inv_freq = inv_freq
-
-        self._set_cos_sin_cache(seq_len=max_position_embeddings)
-
-    def _set_cos_sin_cache(self, seq_len):
-        self.max_seq_len_cached = seq_len
-        t = mx.arange(self.max_seq_len_cached).astype(mx.float32)
-
-        freqs = mx.outer(t, self.inv_freq)
-        emb = mx.concatenate((freqs, freqs), axis=-1)
-        self.cos_cached = mx.cos(emb)
-        self.sin_cached = mx.sin(emb)
-
-    def __call__(self, x, seq_len=None):
-
-        if seq_len > self.max_seq_len_cached:
-            self._set_cos_sin_cache(seq_len=seq_len)
-
-        return (
-            self.cos_cached[:seq_len].astype(x.dtype),
-            self.sin_cached[:seq_len].astype(x.dtype),
-        )
-
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return mx.concatenate([-x2, x1], axis=-1)
-
-
-def apply_multimodal_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section):
-    """
-    Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors.
-    Args:
-        q (mx.array): The query tensor.
-        k (mx.array): The key tensor.
-        cos (mx.array): The cosine part of the rotary embedding.
-        sin (mx.array): The sine part of the rotary embedding.
-        mrope_section (List[int]): Multimodal rope section for channel dimension of temporal, height and width.
-        unsqueeze_dim (int, optional): Dimension to unsqueeze. Defaults to 1.
-    Returns:
-        tuple(mx.array): The rotated query and key tensors.
-    """
-
-    mrope_section = np.cumsum(mrope_section * 2)[:-1].tolist()
-    cos = cos[position_ids]
-    sin = sin[position_ids]
-
-    cos = mx.concatenate(
-        [m[i % 3] for i, m in enumerate(mx.split(cos, mrope_section, axis=-1))], axis=-1
-    )[
-        :, None, :, :
-    ]  # unsqueeze dim 1
-    sin = mx.concatenate(
-        [m[i % 3] for i, m in enumerate(mx.split(sin, mrope_section, axis=-1))], axis=-1
-    )[:, None, :, :]
-
-    # Apply rotary embedding
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-
-    return q_embed, k_embed
-
-
-class Attention(nn.Module):
-    def __init__(self, args: TextConfig):
-        super().__init__()
-
-        dim = args.hidden_size
-        self.n_heads = n_heads = args.num_attention_heads
-        assert args.num_key_value_heads is not None
-        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
-
-        self.head_dim = head_dim = args.hidden_size // n_heads
-        self.scale = head_dim**-0.5
-
-        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
-        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
-        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
-        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
-
-        self.rope_scaling = args.rope_scaling
-
-        self.rotary_emb = Qwen2RotaryEmbedding(
-            head_dim,
-            max_position_embeddings=args.max_position_embeddings,
-            base=args.rope_theta,
-        )
-
-    def __call__(
-        self,
-        x: mx.array,
-        mask: Optional[mx.array] = None,
-        cache: Optional[KVCache] = None,
-        position_ids: Optional[mx.array] = None,
-    ) -> mx.array:
-        B, L, D = x.shape
-
-        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
-
-        # Prepare the queries, keys and values for the attention computation
-        queries = queries.reshape(B, L, self.n_heads, self.head_dim).transpose(
-            0, 2, 1, 3
-        )
-        keys = keys.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
-        values = values.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(
-            0, 2, 1, 3
-        )
-
-        kv_seq_len = keys.shape[-2]
-
-        if position_ids is None:
-            kv_seq_len += cache.offset + 1
-            position_ids = mx.arange(cache.offset, cache.offset + L)
-            position_ids = mx.expand_dims(position_ids, axis=0)
-            position_ids = mx.tile(position_ids, (3, 1, 1))
-        else:
-            kv_seq_len += cache.offset + 1 if cache is not None else 0
-
-        cos, sin = self.rotary_emb(values, kv_seq_len)
-
-        if mask is not None and isinstance(mask, mx.array):
-            mask = mask[..., : keys.shape[-2]]
-        queries, keys = apply_multimodal_rotary_pos_emb(
-            queries, keys, cos, sin, position_ids, self.rope_scaling["mrope_section"]
-        )
-
-        if cache is not None:
-            keys, values = cache.update_and_fetch(keys, values)
-
-        output = scaled_dot_product_attention(
-            queries, keys, values, cache, scale=self.scale, mask=mask
-        )
-        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.o_proj(output)
-
-
-class MLP(nn.Module):
-    def __init__(self, dim, hidden_dim):
-        super().__init__()
-        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
-        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
-        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
-
-    def __call__(self, x) -> mx.array:
-        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
-
-
-class Qwen2VLDecoderLayer(nn.Module):
-    def __init__(self, args: TextConfig):
-        super().__init__()
-        self.num_attention_heads = args.num_attention_heads
-        self.hidden_size = args.hidden_size
-        self.self_attn = Attention(args)
-        self.mlp = MLP(args.hidden_size, args.intermediate_size)
-        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-        self.post_attention_layernorm = nn.RMSNorm(
-            args.hidden_size, eps=args.rms_norm_eps
-        )
-        self.args = args
-
-    def __call__(
-        self,
-        x: mx.array,
-        mask: Optional[mx.array] = None,
-        cache: Optional[KVCache] = None,
-        position_ids: Optional[mx.array] = None,
-    ) -> mx.array:
-        r = self.self_attn(self.input_layernorm(x), mask, cache, position_ids)
-        h = x + r
-        r = self.mlp(self.post_attention_layernorm(h))
-        out = h + r
-        return out
-
-
-class Qwen2Model(nn.Module):
-    def __init__(self, args: TextConfig):
-        super().__init__()
-        self.args = args
-        self.vocab_size = args.vocab_size
-        self.num_hidden_layers = args.num_hidden_layers
-        assert self.vocab_size > 0
-        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
-        self.layers = [
-            Qwen2VLDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
-        ]
-        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-
-    def __call__(
-        self,
-        inputs: mx.array,
-        inputs_embeds: Optional[mx.array] = None,
-        mask: Optional[mx.array] = None,
-        cache=None,
-        position_ids: Optional[mx.array] = None,
-    ):
-        if inputs_embeds is None:
-            h = self.embed_tokens(inputs)
-        else:
-            h = inputs_embeds
-
-        if cache is None:
-            cache = [None] * len(self.layers)
-
-        if mask is None:
-            mask = create_attention_mask(h, cache)
-
-        for layer, c in zip(self.layers, cache):
-            h = layer(h, mask, c, position_ids)
-
-        return self.norm(h)
-
-
-class LanguageModel(nn.Module):
-    def __init__(self, args: TextConfig, config: ModelConfig):
-        super().__init__()
-        self.args = args
-        self.config = config
-        self.model_type = args.model_type
-        self.model = Qwen2Model(args)
-        self.rope_deltas = None
-
-        if not args.tie_word_embeddings:
-            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
-
-    def get_rope_index(
-        self,
-        input_ids: mx.array,
-        image_grid_thw: Optional[mx.array] = None,
-        video_grid_thw: Optional[mx.array] = None,
-        attention_mask: Optional[mx.array] = None,
-    ):
-        # Calculate RoPE index for image/video tokens
-        batch_size, seq_length = input_ids.shape
-        position_ids = mx.arange(seq_length, dtype=mx.int32)
-        position_ids = mx.broadcast_to(position_ids[None, :], (batch_size, seq_length))
-        spatial_merge_size = self.config.vision_config.spatial_merge_size
-        image_token_id = self.config.image_token_id
-        video_token_id = self.config.video_token_id
-        vision_start_token_id = self.config.vision_start_token_id
-        mrope_position_deltas = []
-        if input_ids is not None and (
-            image_grid_thw is not None or video_grid_thw is not None
-        ):
-            total_input_ids = input_ids
-            if attention_mask is None:
-                attention_mask = mx.ones_like(input_ids)
-            position_ids = mx.ones(
-                (3, input_ids.shape[0], input_ids.shape[1]), dtype=input_ids.dtype
-            )
-            image_index, video_index = 0, 0
-            for i, input_ids in enumerate(total_input_ids):
-                input_ids = mx.where(
-                    attention_mask[i] == 1, input_ids, mx.zeros_like(input_ids)
-                )
-                image_nums, video_nums = 0, 0
-                vision_start_indices = mx.sum(
-                    mx.where(
-                        input_ids == vision_start_token_id,
-                        mx.arange(input_ids.shape[0]),
-                        mx.zeros_like(input_ids),
-                    )
-                )
-                vision_tokens = input_ids[vision_start_indices + 1]
-                image_nums = (vision_tokens == image_token_id).sum().item()
-                video_nums = (vision_tokens == video_token_id).sum().item()
-                input_tokens = input_ids.tolist()
-                llm_pos_ids_list: list = []
-                st = 0
-                remain_images, remain_videos = image_nums, video_nums
-                for _ in range(image_nums + video_nums):
-                    if image_token_id in input_tokens and remain_images > 0:
-                        ed_image = input_tokens.index(image_token_id, st)
-                    else:
-                        ed_image = len(input_tokens) + 1
-                    if video_token_id in input_tokens and remain_videos > 0:
-                        ed_video = input_tokens.index(video_token_id, st)
-                    else:
-                        ed_video = len(input_tokens) + 1
-                    if ed_image < ed_video:
-                        t, h, w = (
-                            image_grid_thw[image_index][0],
-                            image_grid_thw[image_index][1],
-                            image_grid_thw[image_index][2],
-                        )
-                        image_index += 1
-                        remain_images -= 1
-                        ed = ed_image
-                    else:
-                        t, h, w = (
-                            video_grid_thw[video_index][0],
-                            video_grid_thw[video_index][1],
-                            video_grid_thw[video_index][2],
-                        )
-                        video_index += 1
-                        remain_videos -= 1
-                        ed = ed_video
-                    llm_grid_t, llm_grid_h, llm_grid_w = (
-                        t.item(),
-                        h.item() // spatial_merge_size,
-                        w.item() // spatial_merge_size,
-                    )
-                    text_len = ed - st
-                    st_idx = (
-                        llm_pos_ids_list[-1].max() + 1
-                        if len(llm_pos_ids_list) > 0
-                        else 0
-                    )
-                    index = mx.arange(text_len).reshape(1, text_len)
-                    index = mx.broadcast_to(index, (3, text_len))
-                    index = index + st_idx
-                    llm_pos_ids_list.append(index)
-                    t_index = mx.arange(llm_grid_t).reshape(
-                        llm_grid_t, 1
-                    )  # Equivalent to .view(-1, 1)
-                    t_index = mx.broadcast_to(
-                        t_index, (llm_grid_t, llm_grid_h * llm_grid_w)
-                    )  # Equivalent to expand()
-                    t_index = t_index.flatten()  # Flattens to 1D
-
-                    h_index = mx.arange(llm_grid_h).reshape(
-                        1, llm_grid_h, 1
-                    )  # Equivalent to .view(1, -1)
-                    h_index = mx.broadcast_to(
-                        h_index, (llm_grid_t, llm_grid_h, llm_grid_w)
-                    )  # Equivalent to expand()
-                    h_index = h_index.flatten()  # Flattens to 1D
-
-                    w_index = mx.arange(llm_grid_w).reshape(
-                        1, 1, llm_grid_w
-                    )  # Equivalent to .view(1, -1)
-                    w_index = mx.broadcast_to(
-                        w_index, (llm_grid_t, llm_grid_h, llm_grid_w)
-                    )  # Equivalent to expand()
-                    w_index = w_index.flatten()  # Flattens to 1D
-
-                    llm_pos_ids_list.append(
-                        mx.stack([t_index, h_index, w_index]) + text_len + st_idx
-                    )
-                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
-                if st < len(input_tokens):
-                    st_idx = (
-                        llm_pos_ids_list[-1].max() + 1
-                        if len(llm_pos_ids_list) > 0
-                        else 0
-                    )
-                    text_len = len(input_tokens) - st
-
-                    t_index = mx.arange(text_len).reshape(
-                        1, text_len
-                    )  # Equivalent to .view(-1, 1)
-                    t_index = mx.broadcast_to(
-                        t_index, (3, text_len)
-                    )  # Equivalent to expand(3, -1)
-
-                    llm_pos_ids_list.append(t_index + st_idx)
-
-                llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
-                mask = mx.array(attention_mask[i] == 1)
-                expanded_mask = mx.expand_dims(mask, axis=0)
-                expanded_mask = mx.broadcast_to(expanded_mask, (3, 1, mask.shape[0]))
-                expanded_positions = mx.expand_dims(llm_positions, axis=1)
-                new_positions = mx.where(
-                    expanded_mask, expanded_positions, position_ids[:, i : i + 1, :]
-                )
-                updated_position_ids = mx.concatenate(
-                    [
-                        position_ids[:, :i, :],
-                        new_positions,
-                        position_ids[:, i + 1 :, :],
-                    ],
-                    axis=1,
-                )
-                position_ids = updated_position_ids
-                mrope_position_deltas.append(
-                    llm_positions.max() + 1 - len(total_input_ids[i])
-                )
-            mrope_position_deltas = mx.array(mrope_position_deltas)[0]
-            return position_ids, mrope_position_deltas
-        else:
-            if attention_mask is not None:
-                position_ids = mx.cumsum(attention_mask.astype(mx.int64), axis=-1) - 1
-                position_ids = mx.where(
-                    attention_mask == 0, mx.ones_like(position_ids), position_ids
-                )
-                position_ids = mx.expand_dims(position_ids[0], axis=0)
-                position_ids = mx.tile(position_ids, (3, 1, 1))
-                max_position_ids = position_ids.max(0, keepdims=False)[0].max(
-                    -1, keepdims=True
-                )[0]
-                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
-            else:
-                position_ids = mx.arange(input_ids.shape[1]).reshape(1, -1)
-                position_ids = mx.broadcast_to(
-                    position_ids, (3, input_ids.shape[0], input_ids.shape[1])
-                )
-                mrope_position_deltas = mx.zeros(
-                    [input_ids.shape[0], 1],
-                    dtype=input_ids.dtype,
-                )
-            return position_ids, mrope_position_deltas
-
-    def __call__(
-        self,
-        inputs: mx.array,
-        inputs_embeds: Optional[mx.array] = None,
-        mask: Optional[mx.array] = None,
-        cache=None,
-        **kwargs,
-    ):
-
-        position_ids = kwargs.pop("position_ids", None)
-        pixel_values = kwargs.pop("pixel_values", None)
-        image_grid_thw = kwargs.pop("image_grid_thw", None)
-        video_grid_thw = kwargs.pop("video_grid_thw", None)
-        # reset rope_deltas when processing a new image/video
-        if pixel_values is not None:
-            self.rope_deltas = None
-
-        if position_ids is None and (mask is None or mask.ndim == 2):
-            # Calculate RoPE index once per generation in the pre-fill stage only
-            if (
-                (cache is not None and cache[0] is not None and cache[0].offset == 0)
-                or self.rope_deltas is None
-                or cache is None
-            ):
-                position_ids, rope_deltas = self.get_rope_index(
-                    inputs, image_grid_thw, video_grid_thw, mask
-                )
-                self.rope_deltas = rope_deltas
-            else:
-                # Use the prev pre-calculated rope-deltas to get the correct position ids
-                batch_size, seq_length = inputs.shape
-                delta = cache[-1].offset + self.rope_deltas if cache is not None else 0
-                delta = delta[None][None]
-                position_ids = mx.arange(seq_length).reshape(1, seq_length)
-                position_ids = mx.broadcast_to(position_ids, (batch_size, seq_length))
-                if cache is not None:
-                    # Repeat delta for each batch
-                    delta = mx.repeat(delta, batch_size // delta.shape[0], axis=0)
-                position_ids = mx.add(position_ids, delta).reshape(position_ids.shape)
-                position_ids = mx.broadcast_to(
-                    position_ids, (3, batch_size, seq_length)
-                )
-
-        out = self.model(
-            inputs, cache=cache, inputs_embeds=inputs_embeds, position_ids=position_ids
-        )
-        if self.args.tie_word_embeddings:
-            out = self.model.embed_tokens.as_linear(out)
-        else:
-            out = self.lm_head(out)
-        return LanguageModelOutput(logits=out)
-
-    @property
-    def layers(self):
-        return self.model.layers
-
-    @property
-    def head_dim(self):
-        return self.args.hidden_size // self.args.num_attention_heads
-
-    @property
-    def n_kv_heads(self):
-        return self.args.num_key_value_heads
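For orientation only: the removed module above implements Qwen2-VL's multimodal rotary position embedding (mRoPE) on top of the standard RoPE update q' = q*cos + rotate_half(q)*sin. Below is a minimal NumPy sketch of that plain, non-multimodal update; every name, shape, and value in it is illustrative and is not part of the nexaai package.

import numpy as np

def rotate_half(x):
    # Swap the two halves of the last dimension and negate the second half.
    half = x.shape[-1] // 2
    return np.concatenate([-x[..., half:], x[..., :half]], axis=-1)

def apply_rope(q, k, positions, dim, base=10000.0):
    # Standard RoPE: one inverse frequency per pair of channels, then
    # q' = q*cos + rotate_half(q)*sin (and the same for k).
    inv_freq = 1.0 / (base ** (np.arange(0, dim, 2) / dim))
    freqs = np.outer(positions, inv_freq)          # [seq_len, dim // 2]
    emb = np.concatenate([freqs, freqs], axis=-1)  # [seq_len, dim]
    cos, sin = np.cos(emb), np.sin(emb)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

q = np.random.randn(8, 64)  # [seq_len, head_dim]
k = np.random.randn(8, 64)
q_rot, k_rot = apply_rope(q, k, np.arange(8), dim=64)
print(q_rot.shape, k_rot.shape)  # (8, 64) (8, 64)

The multimodal variant in the removed code differs only in that the channel dimension is split into temporal, height, and width sections (mrope_section) and each section takes its cos/sin from the corresponding row of a 3-row position-id array.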
@@ -1,167 +0,0 @@
-import glob
-import inspect
-import json
-from dataclasses import dataclass
-from pathlib import Path
-from typing import List, Optional
-
-import mlx.core as mx
-import mlx.nn as nn
-import numpy as np
-from huggingface_hub import snapshot_download
-
-from .config import ModelConfig, TextConfig, VisionConfig
-from .language import LanguageModel
-from .vision import VisionModel
-
-
-class Model(nn.Module):
-    def __init__(self, config: ModelConfig):
-        super().__init__()
-        self.config = config
-        self.vision_tower = VisionModel(config.vision_config)
-        self.language_model = LanguageModel(config.text_config, config)
-
-    def get_input_embeddings(
-        self,
-        input_ids: Optional[mx.array] = None,
-        pixel_values: Optional[mx.array] = None,
-        grid_thw: Optional[mx.array] = None,
-    ):
-
-        if pixel_values is None:
-            return self.language_model.model.embed_tokens(input_ids)
-
-        dtype = self.vision_tower.patch_embed.proj.weight.dtype
-        pixel_values = pixel_values.astype(dtype)
-
-        # Get the input embeddings from the language model
-        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
-
-        # Get the ouptut hidden states from the vision model
-        hidden_states = self.vision_tower(
-            pixel_values, grid_thw, output_hidden_states=False
-        )
-
-        # Insert special image tokens in the input_ids
-        final_inputs_embeds = self.merge_input_ids_with_image_features(
-            self.config.image_token_id,
-            self.config.video_token_id,
-            hidden_states,
-            inputs_embeds,
-            input_ids,
-        )
-        return final_inputs_embeds
-
-    @staticmethod
-    def merge_input_ids_with_image_features(
-        image_token_id,
-        video_token_id,
-        image_features,
-        inputs_embeds,
-        input_ids,
-    ):
-        """Merge image features into input embeddings at image token positions.
-
-        Args:
-            image_features: Vision features from the vision tower [num_features, hidden_dim]
-            inputs_embeds: Input embeddings [batch_size, seq_len, hidden_dim]
-            input_ids: Input token IDs [batch_size, seq_len]
-
-        Returns:
-            Updated input embeddings with image features inserted
-        """
-
-        # Positions of <image> tokens in input_ids
-        image_positions = input_ids == image_token_id
-        if mx.sum(image_positions) == 0:
-            image_positions = input_ids == video_token_id
-
-        # Get dimensions
-        batch_size, seq_len = input_ids.shape
-
-        # Process each batch item
-        batch_outputs = []
-        feature_start_idx = 0
-
-        for batch_idx in range(batch_size):
-            # Get mask for this batch
-            image_mask = image_positions[batch_idx]
-            num_positions = mx.sum(image_mask).item()
-
-            if num_positions > 0:
-                # Extract features for this batch
-                batch_features = image_features[
-                    feature_start_idx : feature_start_idx + num_positions
-                ]
-
-                # Validate we have the right number of features
-                if batch_features.shape[0] != num_positions:
-                    raise ValueError(
-                        f"Number of image token positions ({num_positions}) does not match "
-                        f"number of image features ({batch_features.shape[0]}) for batch {batch_idx}"
-                    )
-
-                # Create indices for gathering
-                cumsum = mx.cumsum(image_mask.astype(mx.int32))
-                feature_indices = mx.where(image_mask, cumsum - 1, 0)
-
-                # Gather features
-                gathered_features = batch_features[feature_indices]
-
-                # Combine with original embeddings
-                image_mask_expanded = mx.expand_dims(image_mask, axis=-1)
-                batch_output = mx.where(
-                    image_mask_expanded, gathered_features, inputs_embeds[batch_idx]
-                )
-
-                feature_start_idx += num_positions
-            else:
-                # No image tokens in this batch item
-                batch_output = inputs_embeds[batch_idx]
-
-            batch_outputs.append(batch_output)
-
-        # Stack all batch outputs
-        return mx.stack(batch_outputs, axis=0)
-
-    @property
-    def layers(self):
-        return self.language_model.model.layers
-
-    def __call__(
-        self,
-        input_ids: mx.array,
-        pixel_values: Optional[mx.array] = None,
-        mask: Optional[mx.array] = None,
-        cache=None,
-        **kwargs,
-    ):
-
-        image_grid_thw = kwargs.pop("image_grid_thw", None)
-        video_grid_thw = kwargs.pop("video_grid_thw", None)
-        grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw
-        input_embddings = self.get_input_embeddings(input_ids, pixel_values, grid_thw)
-        kwargs = {
-            "pixel_values": pixel_values,
-            "image_grid_thw": image_grid_thw,
-            "video_grid_thw": video_grid_thw,
-            **kwargs,
-        }
-        logits = self.language_model(
-            input_ids, input_embddings, mask=mask, cache=cache, **kwargs
-        )
-        return logits
-
-    def sanitize(self, weights):
-        def transform_key(key):
-            if "vision_tower" not in key:
-                key = key.replace("visual", "vision_tower")
-            if "language_model" not in key:
-                if "model" in key:
-                    key = key.replace("model", "language_model.model")
-                elif "lm_head" in key:
-                    key = key.replace("lm_head", "language_model.lm_head")
-            return key
-
-        return {transform_key(k): v for k, v in weights.items()}
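For orientation only: the removed merge_input_ids_with_image_features above overwrites the embedding rows at image (or video) placeholder tokens with the vision tower's output rows, in order. The following is a minimal NumPy sketch of that merge, not nexaai code; the placeholder id and all shapes are illustrative assumptions.

import numpy as np

IMAGE_TOKEN_ID = 151655  # illustrative placeholder id, not taken from the package config

def merge_image_features(input_ids, inputs_embeds, image_features):
    # input_ids: [seq_len], inputs_embeds: [seq_len, dim], image_features: [n, dim]
    mask = input_ids == IMAGE_TOKEN_ID
    assert mask.sum() == image_features.shape[0], "one feature row per image token"
    merged = inputs_embeds.copy()
    merged[mask] = image_features  # rows at placeholder positions take vision features
    return merged

ids = np.array([1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 42])
embeds = np.zeros((4, 3))
feats = np.ones((2, 3))
print(merge_image_features(ids, embeds, feats))  # rows 1 and 2 become ones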