nexaai 1.0.19rc7-cp310-cp310-macosx_14_0_universal2.whl → 1.0.19rc8-cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nexaai might be problematic.
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/libnexa_bridge.dylib +0 -0
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/METADATA +1 -1
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/RECORD +7 -196
- nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +0 -12
- nexaai/binds/nexa_mlx/py-lib/asr/interface.py +0 -122
- nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/common/utils.py +0 -25
- nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/cv/generate.py +0 -195
- nexaai/binds/nexa_mlx/py-lib/cv/interface.py +0 -151
- nexaai/binds/nexa_mlx/py-lib/cv/main.py +0 -81
- nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +0 -1736
- nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +0 -333
- nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +0 -617
- nexaai/binds/nexa_mlx/py-lib/embedding/main.py +0 -173
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +0 -399
- nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +0 -1
- nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +0 -244
- nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +0 -82
- nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +0 -281
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +0 -306
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +0 -116
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +0 -65
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +0 -386
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +0 -105
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +0 -100
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +0 -460
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +0 -274
- nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/llm/generate.py +0 -149
- nexaai/binds/nexa_mlx/py-lib/llm/interface.py +0 -764
- nexaai/binds/nexa_mlx/py-lib/llm/main.py +0 -68
- nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +0 -174
- nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +0 -287
- nexaai/binds/nexa_mlx/py-lib/rerank/main.py +0 -127
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +0 -330
- nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +0 -1
- nexaai/binds/nexa_mlx/py-lib/sd/interface.py +0 -362
- nexaai/binds/nexa_mlx/py-lib/sd/main.py +0 -286
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +0 -306
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +0 -116
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +0 -65
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +0 -385
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +0 -105
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +0 -100
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +0 -460
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +0 -274
- nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +0 -12
- nexaai/binds/nexa_mlx/py-lib/tts/interface.py +0 -276
- nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +0 -3
- nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +0 -572
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +0 -294
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +0 -276
- nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +0 -504
- nexaai/binds/nexa_mlx/py-lib/vlm/main.py +0 -320
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +0 -68
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +0 -193
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +0 -186
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +0 -233
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +0 -503
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +0 -202
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +0 -10
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +0 -264
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +0 -472
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +0 -591
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +0 -526
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +0 -356
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +0 -366
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +0 -488
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +0 -591
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +0 -213
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +0 -315
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +0 -238
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +0 -1038
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +0 -139
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +0 -322
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +0 -629
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +0 -1022
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +0 -294
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +0 -191
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +0 -267
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +0 -175
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +0 -192
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +0 -233
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +0 -140
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +0 -220
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +0 -393
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +0 -293
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +0 -307
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +0 -143
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +0 -509
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +0 -522
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +0 -386
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +0 -138
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +0 -560
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +0 -240
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +0 -153
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +0 -259
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +0 -236
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +0 -256
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +0 -283
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +0 -416
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +0 -172
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +0 -499
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +0 -133
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +0 -465
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +0 -10
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +0 -385
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +0 -557
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +0 -526
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +0 -282
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +0 -242
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +0 -21
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +0 -71
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +0 -324
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +0 -229
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +0 -161
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +0 -320
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +0 -108
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +0 -490
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +0 -168
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +0 -414
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +0 -104
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +0 -490
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +0 -167
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +0 -312
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +0 -117
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +0 -531
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +0 -701
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +0 -255
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +0 -407
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +0 -476
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +0 -1223
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +0 -117
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +0 -531
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +0 -701
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +0 -255
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +0 -407
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +0 -476
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +0 -1309
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +0 -210
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +0 -62
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +0 -209
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +0 -215
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +0 -474
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +0 -39
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +0 -344
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +0 -70
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +0 -296
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +0 -928
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/WHEEL +0 -0
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/top_level.txt +0 -0
nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py
@@ -1,307 +0,0 @@
-import mlx.core as mx
-
-
-def nearest_interpolate(x, size=None, scale_factor=None):
-    """
-    Nearest neighbor interpolation that exactly matches PyTorch's behavior.
-    """
-    # Get input dimensions
-    batch_size, channels, in_h, in_w = x.shape
-
-    # Calculate output dimensions
-    if size is not None:
-        out_h, out_w = size
-    elif scale_factor is not None:
-        if isinstance(scale_factor, (int, float)):
-            scale_h = scale_w = scale_factor
-        else:
-            scale_h, scale_w = scale_factor
-        out_h, out_w = int(in_h * scale_h), int(in_w * scale_w)
-    else:
-        raise ValueError("Either size or scale_factor must be specified")
-
-    # Create dimensions tensor
-    dims = mx.array([batch_size, channels, in_h, in_w, out_h, out_w], dtype=mx.int32)
-
-    # Reshape input tensor to 1D for kernel processing
-    x_flat = x.reshape(-1)
-    input_dtype = x.dtype
-    if input_dtype != mx.float32:
-        x_flat = x_flat.astype(mx.float32)
-
-    # Metal kernel source that matches PyTorch's coordinate calculation
-    source = """
-        uint x_out = thread_position_in_grid.x;
-        uint y_out = thread_position_in_grid.y;
-        uint bc_idx = thread_position_in_grid.z;
-
-        int batch_size = dims[0];
-        int channels = dims[1];
-        int in_h = dims[2];
-        int in_w = dims[3];
-        int out_h = dims[4];
-        int out_w = dims[5];
-
-        if (x_out >= (uint)out_w || y_out >= (uint)out_h || bc_idx >= (uint)(batch_size * channels))
-            return;
-
-        int c = bc_idx % channels;
-        int b = bc_idx / channels;
-
-        // PyTorch's coordinate calculation for nearest neighbor
-        // This matches: torch.nn.functional.interpolate(..., mode='nearest')
-        float scale_h = float(in_h) / float(out_h);
-        float scale_w = float(in_w) / float(out_w);
-
-        // PyTorch uses floor for nearest neighbor coordinate mapping
-        int y_in = int(floor(float(y_out) * scale_h));
-        int x_in = int(floor(float(x_out) * scale_w));
-
-        // Clamp to bounds
-        y_in = max(0, min(y_in, in_h - 1));
-        x_in = max(0, min(x_in, in_w - 1));
-
-        int input_offset = ((b * channels + c) * in_h + y_in) * in_w + x_in;
-        int output_offset = ((b * channels + c) * out_h + y_out) * out_w + x_out;
-
-        output[output_offset] = input[input_offset];
-    """
-
-    # Create and run kernel
-    kernel = mx.fast.metal_kernel(
-        name="nearest_interpolation",
-        input_names=["input", "dims"],
-        output_names=["output"],
-        source=source,
-    )
-
-    threadgroup = get_optimal_threadgroup(out_w, out_h)
-    outputs = kernel(
-        inputs=[x_flat, dims],
-        grid=(out_w, out_h, batch_size * channels),
-        threadgroup=threadgroup,
-        output_shapes=[(batch_size * channels * out_h * out_w,)],
-        output_dtypes=[mx.float32],
-    )
-
-    result = outputs[0].reshape(batch_size, channels, out_h, out_w)
-    if input_dtype != mx.float32:
-        result = result.astype(input_dtype)
-
-    return result
-
-
-def bicubic_interpolate(x, size=None, scale_factor=None, align_corners=False):
-    """
-    Bicubic interpolation using MLX's built-in interpolate function.
-
-    Args:
-        x: MLX tensor of shape [B, C, H, W]
-        size: Tuple of (out_h, out_w) or None
-        scale_factor: Float or tuple of (scale_h, scale_w) or None
-        align_corners: Whether to align corners
-
-    Returns:
-        Interpolated MLX tensor
-    """
-    # Get input dimensions
-    batch_size, channels, in_h, in_w = x.shape
-
-    # Calculate output dimensions
-    if size is not None:
-        out_h, out_w = size
-        scale_h, scale_w = out_h / in_h, out_w / in_w
-    elif scale_factor is not None:
-        if isinstance(scale_factor, (int, float)):
-            scale_h = scale_w = scale_factor
-        else:
-            scale_h, scale_w = scale_factor
-        out_h, out_w = int(in_h * scale_h), int(in_w * scale_w)
-    else:
-        raise ValueError("Either size or scale_factor must be specified")
-
-    # Create scale and align_corners parameters tensor
-    params = mx.array(
-        [scale_h, scale_w, 1.0 if align_corners else 0.0], dtype=mx.float32
-    )
-
-    # Create dimensions tensor
-    dims = mx.array([batch_size, channels, in_h, in_w, out_h, out_w], dtype=mx.int32)
-
-    # Reshape input tensor to 1D for kernel processing
-    x_flat = x.reshape(-1)
-
-    # Convert to float32 for processing if needed
-    input_dtype = x.dtype
-    if input_dtype != mx.float32:
-        x_flat = x_flat.astype(mx.float32)
-
-    # Metal kernel source code
-    source = """
-        // Get thread position
-        uint x_out = thread_position_in_grid.x;
-        uint y_out = thread_position_in_grid.y;
-        uint bc_idx = thread_position_in_grid.z;
-
-        // Extract dimensions from dims
-        int batch_size = dims[0];
-        int channels = dims[1];
-        int in_h = dims[2];
-        int in_w = dims[3];
-        int out_h = dims[4];
-        int out_w = dims[5];
-
-        // Extract scales and flags
-        float scale_h = params[0];
-        float scale_w = params[1];
-        bool align_corners = params[2] > 0.5;
-
-        // Check bounds
-        if (x_out >= (uint)out_w || y_out >= (uint)out_h || bc_idx >= (uint)(batch_size * channels))
-            return;
-
-        // Calculate batch and channel indices
-        int c = bc_idx % channels;
-        int b = bc_idx / channels;
-
-        // Calculate input coordinates based on output position
-        float x_in, y_in;
-
-        if (align_corners && out_w > 1 && out_h > 1) {
-            x_in = float(x_out) * (in_w - 1) / (out_w - 1);
-            y_in = float(y_out) * (in_h - 1) / (out_h - 1);
-        } else {
-            // Fix the alignment calculation to ensure consistent mapping across thread boundaries
-            x_in = ((float(x_out) + 0.5f) / float(out_w)) * float(in_w) - 0.5f;
-            y_in = ((float(y_out) + 0.5f) / float(out_h)) * float(in_h) - 0.5f;
-        }
-
-        // Get integer and fractional parts
-        int x0 = int(floor(x_in));
-        int y0 = int(floor(y_in));
-        float x_frac = x_in - x0;
-        float y_frac = y_in - y0;
-
-        // Improved cubic kernel function for better continuity
-        auto cubic_kernel = [](float x) -> float {
-            float absx = fabs(x);
-            float absx2 = absx * absx;
-            float absx3 = absx2 * absx;
-
-            // Use a=-0.5 for smoother interpolation
-            const float a = -0.5f;
-
-            if (absx <= 1.0f) {
-                return (a+2.0f)*absx3 - (a+3.0f)*absx2 + 1.0f;
-            } else if (absx < 2.0f) {
-                return a*absx3 - 5.0f*a*absx2 + 8.0f*a*absx - 4.0f*a;
-            }
-            return 0.0f;
-        };
-
-        // Perform bicubic interpolation with improved boundary handling
-        float result = 0.0f;
-        float weight_sum = 0.0f; // Track weight sum for normalization
-
-        for (int i = -1; i <= 2; i++) {
-            int y_pos = y0 + i;
-            // Clamp y coordinate to valid range
-            y_pos = max(0, min(y_pos, in_h - 1));
-            float wy = cubic_kernel(y_frac - i);
-
-            for (int j = -1; j <= 2; j++) {
-                int x_pos = x0 + j;
-                // Clamp x coordinate to valid range
-                x_pos = max(0, min(x_pos, in_w - 1));
-                float wx = cubic_kernel(x_frac - j);
-                float weight = wy * wx;
-
-                // Calculate input tensor offset
-                int input_offset = ((b * channels + c) * in_h + y_pos) * in_w + x_pos;
-
-                // Add weighted contribution
-                result += input[input_offset] * weight;
-                weight_sum += weight;
-            }
-        }
-
-        // Normalize by weight sum to ensure consistent intensity
-        if (weight_sum > 0.0f) {
-            result /= weight_sum;
-        }
-
-        // Calculate output tensor offset
-        int output_offset = ((b * channels + c) * out_h + y_out) * out_w + x_out;
-
-        // Assign the result to output
-        output[output_offset] = (float)result;
-    """
-
-    # Create the kernel
-    kernel = mx.fast.metal_kernel(
-        name="bicubic_interpolation",
-        input_names=["input", "dims", "params"],
-        output_names=["output"],
-        source=source,
-    )
-
-    # Run the kernel
-    threadgroup = get_optimal_threadgroup(out_w, out_h)
-    outputs = kernel(
-        inputs=[x_flat, dims, params],
-        grid=(out_w, out_h, batch_size * channels),
-        threadgroup=threadgroup,
-        output_shapes=[(batch_size * channels * out_h * out_w,)],
-        output_dtypes=[mx.float32],  # Always use float32 for kernel output
-    )
-
-    # Reshape output back to 4D tensor and convert back to original dtype
-    result = outputs[0].reshape(batch_size, channels, out_h, out_w)
-    if input_dtype != mx.float32:
-        result = result.astype(input_dtype)
-
-    return result
-
-
-def get_optimal_threadgroup(out_w, out_h):
-    # Calculate optimal threadgroup dimensions based on output dimensions
-
-    # Maximum threadgroup size for most Metal GPUs
-    # This could be made more dynamic with Metal API queries if needed
-    MAX_THREADS_PER_GROUP = 1024
-    MAX_THREADS_PER_DIM = 1024
-
-    # Start with a reasonable default size for 2D workloads
-    default_threadgroup = (32, 32, 1)
-
-    try:
-        # Don't create threadgroups larger than the work dimensions
-        max_width = min(MAX_THREADS_PER_DIM, out_w)
-        max_height = min(MAX_THREADS_PER_DIM, out_h)
-
-        # Find largest power of 2 that fits within our dimensions
-        width = 2 ** (max_width.bit_length() - 1)
-        if width > max_width:
-            width = width // 2
-
-        height = 2 ** (max_height.bit_length() - 1)
-        if height > max_height:
-            height = height // 2
-
-        # Ensure we don't exceed maximum threads per threadgroup
-        while width * height > MAX_THREADS_PER_GROUP:
-            # Reduce the larger dimension first
-            if width >= height:
-                width = width // 2
-            else:
-                height = height // 2
-
-        # Ensure minimum size for efficiency
-        width = max(8, width)
-        height = max(8, height)
-
-        return (width, height, 1)
-
-    except Exception:
-        # Return safe defaults if calculation fails
-        return default_threadgroup
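For context, a minimal sketch of how the two removed interpolation helpers would be invoked (the import path `kernels` and the tensor shapes are hypothetical; assumes MLX on Apple silicon, since both functions dispatch custom Metal kernels):

    import mlx.core as mx

    # Hypothetical import path for the removed module shown above.
    from kernels import nearest_interpolate, bicubic_interpolate

    x = mx.random.normal(shape=(1, 3, 16, 16))  # [B, C, H, W], as the docstrings state

    up = nearest_interpolate(x, scale_factor=2)  # -> (1, 3, 32, 32)
    down = bicubic_interpolate(x, size=(8, 8))   # -> (1, 3, 8, 8)
    print(up.shape, down.shape)

Both helpers launch one thread per output pixel on a (out_w, out_h, batch * channels) grid, and get_optimal_threadgroup caps the threadgroup at 1024 threads with a floor of 8 per dimension.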
nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py
@@ -1,143 +0,0 @@
-import glob
-import inspect
-import json
-import re
-from dataclasses import dataclass
-from pathlib import Path
-from typing import List, Optional
-
-import mlx.core as mx
-import mlx.nn as nn
-import numpy as np
-from huggingface_hub import snapshot_download
-from transformers import AutoConfig
-
-from .language import LanguageModel, TextConfig
-from .vision import VisionConfig, VisionModel
-
-
-@dataclass
-class ModelConfig:
-    text_config: TextConfig
-    vision_config: VisionConfig
-    model_type: str
-    ignore_index: int = -100
-    vocab_size: int = 128259
-    scale_factor: int = 2
-    media_placeholder_token_id: int = 163606
-    image_token_index: Optional[int] = None
-    eos_token_id: Optional[List[int]] = None
-
-    def __post_init__(self):
-        if self.image_token_index is None:
-            self.image_token_index = self.media_placeholder_token_id
-
-    @classmethod
-    def from_dict(cls, params):
-        return cls(
-            **{
-                k: v
-                for k, v in params.items()
-                if k in inspect.signature(cls).parameters
-            }
-        )
-
-
-class KimiVLMultiModalProjector(nn.Module):
-
-    def __init__(self, config: ModelConfig):
-        super().__init__()
-
-        self.hidden_size = (
-            config.vision_config.hidden_size
-            * config.vision_config.merge_kernel_size[0]
-            * config.vision_config.merge_kernel_size[1]
-        )
-
-        self.pre_norm = nn.LayerNorm(config.vision_config.hidden_size, eps=1e-05)
-        self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
-        self.act = nn.GELU()
-        self.linear_2 = nn.Linear(
-            self.hidden_size, config.text_config.hidden_size, bias=True
-        )
-
-    def __call__(self, image_features: list[mx.array]) -> mx.array:
-        image_features = mx.concatenate(image_features, axis=0)
-        h = self.pre_norm(image_features).reshape(-1, self.hidden_size)
-        h = self.linear_1(h)
-        h = self.act(h)
-        h = self.linear_2(h)
-        return h
-
-
-class Model(nn.Module):
-    def __init__(self, config: ModelConfig):
-        super().__init__()
-        self.model_type = config.model_type
-        self.config = config
-
-        self.vision_tower = VisionModel(config.vision_config)
-        self.language_model = LanguageModel(config.text_config)
-        self.multi_modal_projector = KimiVLMultiModalProjector(config)
-
-    def get_input_embeddings(
-        self,
-        input_ids: Optional[mx.array] = None,
-        pixel_values: Optional[mx.array] = None,
-        grid_thw: Optional[mx.array] = None,
-    ):
-        if pixel_values is None:
-            return self.language_model.embed_tokens(input_ids)
-
-        inputs_embeds = self.language_model.embed_tokens(input_ids)
-
-        hidden_state = self.vision_tower(
-            pixel_values.transpose(0, 2, 3, 1),
-            output_hidden_states=True,
-            grid_thw=grid_thw,
-        )
-
-        image_features = self.multi_modal_projector(hidden_state)
-
-        final_inputs_embeds = self._prepare_inputs_for_multimodal(
-            image_features, inputs_embeds, input_ids
-        )
-        return final_inputs_embeds
-
-    def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_ids):
-        image_token_index = self.config.image_token_index
-
-        # Positions of <image> tokens in input_ids, assuming batch size is 1
-        image_positions = np.where(input_ids == image_token_index)[1].tolist()
-
-        inputs_embeds[:, image_positions, :] = image_features
-
-        return inputs_embeds
-
-    @property
-    def layers(self):
-        return self.language_model.model.layers
-
-    def __call__(
-        self,
-        input_ids: mx.array,
-        pixel_values: mx.array,
-        cache=None,
-        **kwargs,
-    ):
-        image_grid_thw = kwargs.pop("image_grid_hws", None)
-        video_grid_thw = kwargs.pop("video_grid_hws", None)
-        grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw
-        input_embeddings = self.get_input_embeddings(
-            input_ids, pixel_values, grid_thw=grid_thw
-        )
-        logits = self.language_model(
-            inputs=input_ids, cache=cache, inputs_embeds=input_embeddings
-        )
-        return logits
-
-    def sanitize(self, weights):
-        return {
-            k.replace("encoder.", "") if "vision_tower" in k else k: v
-            for k, v in weights.items()
-        }
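For context, a standalone sketch of the `from_dict` filtering idiom that `ModelConfig` above relies on: keys the dataclass does not declare are silently dropped, so extra fields in a checkpoint config dict do not raise (`TinyConfig` and the sample dict are hypothetical stand-ins):

    import inspect
    from dataclasses import dataclass

    @dataclass
    class TinyConfig:  # stand-in for ModelConfig
        model_type: str
        vocab_size: int = 128259

    params = {"model_type": "kimi_vl", "vocab_size": 163840, "unknown_key": 1}
    # Same filter as ModelConfig.from_dict: keep only declared parameters.
    cfg = TinyConfig(
        **{k: v for k, v in params.items() if k in inspect.signature(TinyConfig).parameters}
    )
    print(cfg)  # TinyConfig(model_type='kimi_vl', vocab_size=163840)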