nexaai-1.0.19rc7-cp310-cp310-macosx_14_0_universal2.whl → nexaai-1.0.19rc9-cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nexaai might be problematic.
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl.py +14 -31
- nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +15 -32
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +7 -23
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +8 -24
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/METADATA +1 -1
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/RECORD +11 -200
- nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +0 -12
- nexaai/binds/nexa_mlx/py-lib/asr/interface.py +0 -122
- nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/common/utils.py +0 -25
- nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/cv/generate.py +0 -195
- nexaai/binds/nexa_mlx/py-lib/cv/interface.py +0 -151
- nexaai/binds/nexa_mlx/py-lib/cv/main.py +0 -81
- nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +0 -1736
- nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +0 -333
- nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +0 -617
- nexaai/binds/nexa_mlx/py-lib/embedding/main.py +0 -173
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +0 -399
- nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +0 -1
- nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +0 -244
- nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +0 -82
- nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +0 -281
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +0 -306
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +0 -116
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +0 -65
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +0 -386
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +0 -105
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +0 -100
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +0 -460
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +0 -274
- nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/llm/generate.py +0 -149
- nexaai/binds/nexa_mlx/py-lib/llm/interface.py +0 -764
- nexaai/binds/nexa_mlx/py-lib/llm/main.py +0 -68
- nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +0 -174
- nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +0 -287
- nexaai/binds/nexa_mlx/py-lib/rerank/main.py +0 -127
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +0 -330
- nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +0 -1
- nexaai/binds/nexa_mlx/py-lib/sd/interface.py +0 -362
- nexaai/binds/nexa_mlx/py-lib/sd/main.py +0 -286
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +0 -306
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +0 -116
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +0 -65
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +0 -385
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +0 -105
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +0 -100
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +0 -460
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +0 -274
- nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +0 -12
- nexaai/binds/nexa_mlx/py-lib/tts/interface.py +0 -276
- nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +0 -3
- nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +0 -572
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +0 -294
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +0 -276
- nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +0 -504
- nexaai/binds/nexa_mlx/py-lib/vlm/main.py +0 -320
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +0 -68
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +0 -193
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +0 -186
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +0 -233
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +0 -503
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +0 -202
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +0 -10
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +0 -264
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +0 -472
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +0 -591
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +0 -526
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +0 -356
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +0 -366
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +0 -488
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +0 -591
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +0 -213
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +0 -315
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +0 -238
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +0 -1038
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +0 -139
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +0 -322
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +0 -629
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +0 -1022
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +0 -294
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +0 -191
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +0 -267
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +0 -175
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +0 -192
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +0 -233
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +0 -140
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +0 -220
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +0 -393
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +0 -293
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +0 -307
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +0 -143
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +0 -509
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +0 -522
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +0 -386
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +0 -138
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +0 -560
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +0 -240
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +0 -153
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +0 -259
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +0 -236
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +0 -256
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +0 -283
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +0 -416
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +0 -172
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +0 -499
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +0 -133
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +0 -465
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +0 -10
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +0 -385
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +0 -557
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +0 -526
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +0 -282
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +0 -242
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +0 -21
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +0 -71
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +0 -324
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +0 -229
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +0 -161
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +0 -320
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +0 -108
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +0 -490
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +0 -168
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +0 -414
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +0 -104
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +0 -490
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +0 -167
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +0 -312
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +0 -117
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +0 -531
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +0 -701
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +0 -255
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +0 -407
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +0 -476
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +0 -1223
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +0 -117
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +0 -531
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +0 -701
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +0 -255
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +0 -407
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +0 -476
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +0 -1309
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +0 -210
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +0 -62
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +0 -209
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +0 -215
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +0 -474
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +0 -39
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +0 -344
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +0 -70
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +0 -296
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +0 -928
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/WHEEL +0 -0
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/top_level.txt +0 -0
--- nexaai/binds/nexa_mlx/py-lib/embedding/main.py
+++ /dev/null
@@ -1,173 +0,0 @@
-# Copyright © Nexa AI
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import sys
-import numpy as np
-from pathlib import Path
-
-# Add parent path for imports
-sys.path.insert(0, str(Path(__file__).parent.parent))
-
-# Import from interface (uses the factory pattern with routing)
-from .interface import create_embedder
-from .interface import EmbeddingConfig
-from huggingface_hub import snapshot_download
-
-
-def download_model_if_needed(model_id, local_dir):
-    """Download model from Hugging Face Hub if not present locally."""
-    if not os.path.exists(os.path.join(local_dir, "config.json")):
-        print(f"📥 Model not found locally. Downloading {model_id}...")
-        os.makedirs(local_dir, exist_ok=True)
-        try:
-            snapshot_download(
-                repo_id=model_id,
-                local_dir=local_dir,
-                resume_download=True,
-                local_dir_use_symlinks=False
-            )
-            print("✅ Model download completed!")
-        except Exception as e:
-            print(f"❌ Failed to download model: {e}")
-            raise
-
-
-def test_embedding_interface(model_path, is_local=False):
-    """Test embedding model functionality using the interface."""
-
-    print("=" * 70)
-    print("TESTING EMBEDDING MODEL VIA INTERFACE")
-    print("=" * 70)
-
-    # Handle model path - download if it's a HF model ID
-    if not is_local and "/" in model_path:
-        # It's a HuggingFace model ID
-        local_dir = f"./modelfiles/{model_path.replace('/', '_')}"
-        download_model_if_needed(model_path, local_dir)
-        model_path = local_dir
-
-    # Create embedder using factory function (will auto-detect model type)
-    print(f"\n🔍 Creating embedder for: {model_path}")
-    embedder = create_embedder(model_path=model_path)
-    print(f"✅ Created embedder type: {type(embedder).__name__}")
-
-    # Load the model
-    print("\n📚 Loading embedding model...")
-    success = embedder.load_model(model_path)
-
-    if not success:
-        print("❌ Failed to load model!")
-        return
-
-    print("✅ Model loaded successfully!")
-    print(f"📏 Embedding dimension: {embedder.embedding_dim()}")
-
-    # Test texts
-    test_texts = [
-        "Hello, how are you?",
-        "What is machine learning?",
-        "The weather is nice today.",
-        "Python is a programming language.",
-        "Artificial intelligence is changing the world."
-    ]
-
-    # Configure embedding with different settings
-    configs = [
-        EmbeddingConfig(batch_size=2, normalize=True, normalize_method="l2"),
-        EmbeddingConfig(batch_size=3, normalize=False),
-    ]
-
-    for config_idx, config in enumerate(configs):
-        print(f"\n{'='*50}")
-        print(f"TEST {config_idx + 1}: Config - Batch: {config.batch_size}, "
-              f"Normalize: {config.normalize}, Method: {config.normalize_method}")
-        print('='*50)
-
-        # Generate embeddings
-        embeddings = embedder.embed(test_texts, config)
-
-        # Display results
-        print(f"\n📊 Generated {len(embeddings)} embeddings")
-
-        for i, (text, embedding) in enumerate(zip(test_texts[:3], embeddings[:3])):
-            print(f"\n  Text {i+1}: '{text}'")
-            print(f"  Dimension: {len(embedding)}")
-            print(f"  First 5 values: {[f'{v:.4f}' for v in embedding[:5]]}")
-
-            # Calculate magnitude
-            magnitude = np.linalg.norm(embedding)
-            print(f"  Magnitude: {magnitude:.6f}")
-
-    # Compute similarity matrix for normalized embeddings
-    print("\n" + "="*50)
-    print("SIMILARITY MATRIX (L2 Normalized)")
-    print("="*50)
-
-    config = EmbeddingConfig(batch_size=len(test_texts), normalize=True, normalize_method="l2")
-    embeddings = embedder.embed(test_texts, config)
-
-    # Convert to numpy for easier computation
-    embeddings_np = np.array(embeddings)
-    similarity_matrix = np.dot(embeddings_np, embeddings_np.T)
-
-    print("\nTexts:")
-    for i, text in enumerate(test_texts):
-        print(f"  [{i}] {text[:30]}...")
-
-    print("\nSimilarity Matrix:")
-    print("     ", end="")
-    for i in range(len(test_texts)):
-        print(f"  [{i}]  ", end="")
-    print()
-
-    for i in range(len(test_texts)):
-        print(f"  [{i}]", end="")
-        for j in range(len(test_texts)):
-            print(f"  {similarity_matrix[i, j]:5.2f}", end="")
-        print()
-
-    # Find most similar pairs
-    print("\n🔍 Most Similar Pairs (excluding self-similarity):")
-    similarities = []
-    for i in range(len(test_texts)):
-        for j in range(i+1, len(test_texts)):
-            similarities.append((similarity_matrix[i, j], i, j))
-
-    similarities.sort(reverse=True)
-    for sim, i, j in similarities[:3]:
-        print(f"  • Texts [{i}] and [{j}]: {sim:.4f}")
-
-    # Cleanup
-    embedder.close()
-    print("\n✅ Interface test completed successfully!")
-
-
-if __name__ == "__main__":
-    import argparse
-    parser = argparse.ArgumentParser(description="Test embedding models via interface")
-    parser.add_argument(
-        "--model_path",
-        type=str,
-        default="nexaml/jina-v2-fp16-mlx",
-        help="Model path (local) or HuggingFace model ID"
-    )
-    parser.add_argument(
-        "--local",
-        action="store_true",
-        help="Indicate if model_path is a local directory"
-    )
-    args = parser.parse_args()
-
-    test_embedding_interface(args.model_path, args.local)
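The deleted harness above prints np.dot(embeddings_np, embeddings_np.T) as a similarity matrix; that plain dot product is a cosine similarity only because the embeddings were L2-normalized first. A minimal NumPy sketch of that identity, using stand-in random vectors rather than real model output (this sketch is illustrative and not part of the diff):

import numpy as np

# After L2 normalization every row has unit norm, so the dot product of
# two rows equals the cosine of the angle between them.
emb = np.random.randn(5, 768).astype(np.float32)   # stand-in embeddings
emb /= np.linalg.norm(emb, axis=1, keepdims=True)  # L2-normalize each row
sim = emb @ emb.T                                  # cosine similarity matrix
assert np.allclose(np.diag(sim), 1.0, atol=1e-5)   # self-similarity is ~1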
--- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py
+++ /dev/null
@@ -1,399 +0,0 @@
-# Copyright © Nexa AI
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import math
-from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Union
-
-import mlx.core as mx
-import mlx.nn as nn
-
-import os
-import sys
-
-curr_dir = os.path.dirname(os.path.abspath(__file__))
-llm_common_dir = os.path.join(curr_dir, "..", "..")
-sys.path.append(llm_common_dir)
-
-from mlx_lm.models.base import (
-    BaseModelArgs,
-    scaled_dot_product_attention,
-)
-from tokenizers import Tokenizer
-
-@dataclass
-class ModelArgs(BaseModelArgs):
-    model_type: str = "bert"
-    vocab_size: int = 61056  # Updated from config
-    hidden_size: int = 768
-    num_hidden_layers: int = 12
-    num_attention_heads: int = 12
-    intermediate_size: int = 3072
-    hidden_act: str = "gelu"
-    hidden_dropout_prob: float = 0.1
-    attention_probs_dropout_prob: float = 0.1
-    max_position_embeddings: int = 8192  # Updated from config
-    type_vocab_size: int = 2
-    initializer_range: float = 0.02
-    layer_norm_eps: float = 1e-12
-    pad_token_id: int = 0
-    position_embedding_type: str = "alibi"  # Updated from config
-    use_cache: bool = True
-    classifier_dropout: Optional[float] = None
-    feed_forward_type: str = "geglu"  # Updated from config
-    emb_pooler: str = "mean"  # Updated from config
-    attn_implementation: str = "torch"
-
-
-class JinaBertEmbeddings(nn.Module):
-    def __init__(self, config: ModelArgs):
-        super().__init__()
-        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
-        # Use PyTorch-style naming for weight loading compatibility
-        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-
-    def __call__(
-        self,
-        input_ids: Optional[mx.array] = None,
-        token_type_ids: Optional[mx.array] = None,
-    ) -> mx.array:
-        if token_type_ids is None:
-            input_shape = input_ids.shape
-            token_type_ids = mx.zeros(input_shape, dtype=mx.int64)
-
-        inputs_embeds = self.word_embeddings(input_ids)
-        token_type_embeddings = self.token_type_embeddings(token_type_ids)
-
-        embeddings = inputs_embeds + token_type_embeddings
-        embeddings = self.LayerNorm(embeddings)
-        return embeddings
-
-
-class JinaBertSelfAttention(nn.Module):
-    def __init__(self, config: ModelArgs, position_embedding_type=None):
-        super().__init__()
-        if config.hidden_size % config.num_attention_heads != 0:
-            raise ValueError(
-                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
-                f"heads ({config.num_attention_heads})"
-            )
-
-        self.attn_implementation = config.attn_implementation
-        self.num_attention_heads = config.num_attention_heads
-        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
-        self.all_head_size = self.num_attention_heads * self.attention_head_size
-
-        self.query = nn.Linear(config.hidden_size, self.all_head_size)
-        self.key = nn.Linear(config.hidden_size, self.all_head_size)
-        self.value = nn.Linear(config.hidden_size, self.all_head_size)
-
-        self.position_embedding_type = position_embedding_type or getattr(
-            config, "position_embedding_type", "absolute"
-        )
-
-    def transpose_for_scores(self, x: mx.array) -> mx.array:
-        new_x_shape = x.shape[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.reshape(new_x_shape)
-        return x.transpose(0, 2, 1, 3)
-
-    def __call__(
-        self,
-        hidden_states: mx.array,
-        attention_mask: Optional[mx.array] = None,
-        bias: Optional[mx.array] = None,
-    ) -> mx.array:
-        mixed_query_layer = self.query(hidden_states)
-
-        key_layer = self.transpose_for_scores(self.key(hidden_states))
-        value_layer = self.transpose_for_scores(self.value(hidden_states))
-        query_layer = self.transpose_for_scores(mixed_query_layer)
-
-        scale = 1.0 / math.sqrt(self.attention_head_size)
-
-        mask = None
-        if attention_mask is not None or bias is not None:
-            if attention_mask is not None and bias is not None:
-                mask = attention_mask + bias
-            elif attention_mask is not None:
-                mask = attention_mask
-            else:
-                mask = bias
-
-        # Cast mask to same dtype as hidden_states
-        if mask is not None:
-            mask = mask.astype(hidden_states.dtype)
-
-        context_layer = scaled_dot_product_attention(
-            query_layer, key_layer, value_layer, cache=None, scale=scale, mask=mask
-        )
-
-        context_layer = context_layer.transpose(0, 2, 1, 3)
-        new_context_layer_shape = context_layer.shape[:-2] + (self.all_head_size,)
-        context_layer = context_layer.reshape(new_context_layer_shape)
-        return context_layer
-
-
-class JinaBertSelfOutput(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
-        # Use PyTorch-style naming for weight loading compatibility
-        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-
-    def __call__(self, hidden_states: mx.array, input_tensor: mx.array) -> mx.array:
-        hidden_states = self.dense(hidden_states)
-        hidden_states = self.LayerNorm(hidden_states + input_tensor)
-        return hidden_states
-
-
-class JinaBertAttention(nn.Module):
-    def __init__(self, config, position_embedding_type=None):
-        super().__init__()
-        self.self = JinaBertSelfAttention(config, position_embedding_type=position_embedding_type)
-        self.output = JinaBertSelfOutput(config)
-
-    def __call__(
-        self,
-        hidden_states: mx.array,
-        attention_mask: Optional[mx.array] = None,
-        bias: Optional[mx.array] = None,
-    ) -> mx.array:
-        self_outputs = self.self(hidden_states, attention_mask, bias)
-        attention_output = self.output(self_outputs, hidden_states)
-        return attention_output
-
-
-class JinaBertGLUMLP(nn.Module):
-    def __init__(self, config: ModelArgs):
-        super().__init__()
-        self.config = config
-        self.gated_layers = nn.Linear(config.hidden_size, config.intermediate_size * 2, bias=False)
-        self.wo = nn.Linear(config.intermediate_size, config.hidden_size)
-        # Use PyTorch-style naming for weight loading compatibility
-        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-
-    def __call__(self, hidden_states: mx.array) -> mx.array:
-        residual_connection = hidden_states
-        hidden_states = self.gated_layers(hidden_states)
-
-        if self.config.feed_forward_type == "geglu":
-            gated = hidden_states[..., : self.config.intermediate_size]
-            non_gated = hidden_states[..., self.config.intermediate_size :]
-            hidden_states = nn.gelu(gated) * non_gated
-        else:
-            # Original GLU
-            gated = hidden_states[..., : self.config.intermediate_size]
-            non_gated = hidden_states[..., self.config.intermediate_size :]
-            hidden_states = nn.gelu(gated) * non_gated
-
-        hidden_states = self.wo(hidden_states)
-        hidden_states = self.layernorm(hidden_states + residual_connection)
-        return hidden_states
-
-
-class JinaBertLayer(nn.Module):
-    def __init__(self, config: ModelArgs):
-        super().__init__()
-        self.attention = JinaBertAttention(config)
-        self.feed_forward_type = config.feed_forward_type
-        self.mlp = JinaBertGLUMLP(config)
-
-    def __call__(
-        self,
-        hidden_states: mx.array,
-        attention_mask: Optional[mx.array] = None,
-        bias: Optional[mx.array] = None,
-    ) -> mx.array:
-        attention_output = self.attention(hidden_states, attention_mask, bias=bias)
-        layer_output = self.mlp(attention_output)
-        return layer_output
-
-
-class JinaBertEncoder(nn.Module):
-    def __init__(self, config: ModelArgs):
-        super().__init__()
-        self.config = config
-        # Use list instead of ModuleList for PyTorch compatibility
-        self.layer = [JinaBertLayer(config) for _ in range(config.num_hidden_layers)]
-        self.gradient_checkpointing = False
-        self.num_attention_heads = config.num_attention_heads
-        self._current_alibi_size = config.max_position_embeddings
-
-        # Build ALiBi tensor
-        # self.alibi = self.rebuild_alibi_tensor(size=config.max_position_embeddings)
-
-    def rebuild_alibi_tensor(self, size: int) -> mx.array:
-        """Build ALiBi bias tensor"""
-        n_heads = self.num_attention_heads
-
-        def _get_alibi_head_slopes(n_heads: int) -> List[float]:
-            def get_slopes_power_of_2(n):
-                start = 2 ** (-(2 ** -(math.log2(n) - 3)))
-                ratio = start
-                return [start * ratio**i for i in range(n)]
-
-            if math.log2(n_heads).is_integer():
-                return get_slopes_power_of_2(n_heads)
-            else:
-                closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
-                return (
-                    get_slopes_power_of_2(closest_power_of_2)
-                    + _get_alibi_head_slopes(2 * closest_power_of_2)[0::2][
-                        : n_heads - closest_power_of_2
-                    ]
-                )
-
-        context_position = mx.arange(size)[:, None]
-        memory_position = mx.arange(size)[None, :]
-        relative_position = mx.abs(memory_position - context_position)
-        relative_position = mx.expand_dims(relative_position, axis=0)
-        relative_position = mx.repeat(relative_position, n_heads, axis=0)
-
-        slopes = mx.array(_get_alibi_head_slopes(n_heads)) * -1
-        slopes = mx.expand_dims(mx.expand_dims(slopes, axis=1), axis=2)
-        alibi = slopes * relative_position
-        alibi = mx.expand_dims(alibi, axis=0)
-
-        self._current_alibi_size = size
-        return alibi
-
-    def __call__(
-        self,
-        hidden_states: mx.array,
-        attention_mask: Optional[mx.array] = None,
-    ) -> mx.array:
-        _, seqlen, _ = hidden_states.shape
-        alibi_bias = self.rebuild_alibi_tensor(seqlen)
-
-        for i, layer_module in enumerate(self.layer):
-            layer_outputs = layer_module(hidden_states, attention_mask, alibi_bias)
-            hidden_states = layer_outputs
-
-        return hidden_states
-
-
-class JinaBertPooler(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
-        self.activation = nn.tanh
-
-    def __call__(self, hidden_states: mx.array) -> mx.array:
-        # We "pool" the model by simply taking the hidden state corresponding
-        # to the first token.
-        first_token_tensor = hidden_states[:, 0]
-        pooled_output = self.dense(first_token_tensor)
-        pooled_output = self.activation(pooled_output)
-        return pooled_output
-
-
-class JinaBertModel(nn.Module):
-    def __init__(self, config: ModelArgs):
-        super().__init__()
-        self.config = config
-        self.embeddings = JinaBertEmbeddings(config)
-        self.encoder = JinaBertEncoder(config)
-        # Add pooler layer for weight compatibility
-        self.pooler = JinaBertPooler(config)
-
-    def get_extended_attention_mask(self, attention_mask: mx.array, input_shape: tuple) -> mx.array:
-        """Convert attention mask to extended format"""
-        if attention_mask.ndim == 3:
-            extended_attention_mask = attention_mask[:, None, :, :]
-        elif attention_mask.ndim == 2:
-            extended_attention_mask = attention_mask[:, None, None, :]
-        else:
-            raise ValueError(f"Wrong shape for attention_mask (shape {attention_mask.shape})")
-
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
-        return extended_attention_mask
-
-    def mean_pooling(self, token_embeddings: mx.array, attention_mask: mx.array) -> mx.array:
-        input_mask_expanded = mx.expand_dims(attention_mask, axis=-1) * mx.ones_like(
-            token_embeddings
-        )
-        return mx.sum(token_embeddings * input_mask_expanded, axis=1) / mx.clip(
-            mx.sum(input_mask_expanded, axis=1), 1e-9, None
-        )
-
-    def __call__(
-        self,
-        input_ids: Optional[mx.array] = None,
-        attention_mask: Optional[mx.array] = None,
-        token_type_ids: Optional[mx.array] = None,
-    ) -> mx.array:
-        input_shape = input_ids.shape
-
-        if attention_mask is not None:
-            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
-        else:
-            extended_attention_mask = None
-
-        embedding_output = self.embeddings(input_ids=input_ids, token_type_ids=token_type_ids)
-        encoder_outputs = self.encoder(embedding_output, attention_mask=extended_attention_mask)
-
-        return encoder_outputs
-
-    def encode(
-        self,
-        input_ids: mx.array,
-        attention_mask: mx.array,
-        token_type_ids: Optional[mx.array] = None,
-    ) -> mx.array:
-        """Encode inputs and return mean-pooled embeddings"""
-        token_embs = self(
-            input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
-        )
-        embeddings = self.mean_pooling(token_embs, attention_mask)
-        return embeddings
-
-
-class Model(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        self.args = args
-        self.model_type = args.model_type
-        self.model = JinaBertModel(args)
-
-    def __call__(
-        self,
-        input_ids: mx.array,
-        attention_mask: Optional[mx.array] = None,
-        token_type_ids: Optional[mx.array] = None,
-    ) -> mx.array:
-        return self.model(
-            input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
-        )
-
-    def encode(
-        self,
-        input_ids: mx.array,
-        attention_mask: mx.array,
-        token_type_ids: Optional[mx.array] = None,
-    ) -> mx.array:
-        """Encode inputs and return mean-pooled embeddings"""
-        return self.model.encode(
-            input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
-        )
-
-    def sanitize(self, weights):
-        """Remove parameters that don't exist in our model"""
-        # No longer need to remove pooler weights since we now have them
-        return weights
-
-    @property
-    def layers(self):
-        return self.model.encoder.layer
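The bias that rebuild_alibi_tensor() builds in the deleted module above is the standard ALiBi penalty: a head with slope m adds -m * |i - j| to the attention score for positions i and j, so attention decays linearly with distance at a head-specific rate. A self-contained NumPy sketch of the same computation for a power-of-two head count (the deleted file also handles non-power-of-two counts, such as this config's 12 heads, by interleaving slopes from the next power of two; this sketch is illustrative and not part of the diff):

import math
import numpy as np

def alibi_slopes(n_heads: int) -> list:
    # Geometric slope schedule; matches get_slopes_power_of_2() above,
    # where start * ratio**i with ratio == start collapses to start**(i + 1).
    start = 2 ** (-(2 ** -(math.log2(n_heads) - 3)))
    return [start ** (i + 1) for i in range(n_heads)]

def alibi_bias(n_heads: int, size: int) -> np.ndarray:
    # Penalty -m * |i - j| per head; result has shape (n_heads, size, size).
    dist = np.abs(np.arange(size)[None, :] - np.arange(size)[:, None])
    slopes = np.array(alibi_slopes(n_heads))[:, None, None]
    return -slopes * dist[None, :, :]

bias = alibi_bias(n_heads=8, size=8)  # e.g. slopes 1/2, 1/4, ..., 1/256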
--- nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Image generation module for MLX backend