nexaai 1.0.19rc7__cp310-cp310-macosx_14_0_universal2.whl → 1.0.19rc8__cp310-cp310-macosx_14_0_universal2.whl
This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release.
This version of nexaai might be problematic.
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/libnexa_bridge.dylib +0 -0
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/METADATA +1 -1
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/RECORD +7 -196
- nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +0 -12
- nexaai/binds/nexa_mlx/py-lib/asr/interface.py +0 -122
- nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/common/utils.py +0 -25
- nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/cv/generate.py +0 -195
- nexaai/binds/nexa_mlx/py-lib/cv/interface.py +0 -151
- nexaai/binds/nexa_mlx/py-lib/cv/main.py +0 -81
- nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +0 -1736
- nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +0 -333
- nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +0 -617
- nexaai/binds/nexa_mlx/py-lib/embedding/main.py +0 -173
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +0 -399
- nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +0 -1
- nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +0 -244
- nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +0 -82
- nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +0 -281
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +0 -306
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +0 -116
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +0 -65
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +0 -386
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +0 -105
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +0 -100
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +0 -460
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +0 -274
- nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/llm/generate.py +0 -149
- nexaai/binds/nexa_mlx/py-lib/llm/interface.py +0 -764
- nexaai/binds/nexa_mlx/py-lib/llm/main.py +0 -68
- nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +0 -174
- nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +0 -287
- nexaai/binds/nexa_mlx/py-lib/rerank/main.py +0 -127
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +0 -330
- nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +0 -1
- nexaai/binds/nexa_mlx/py-lib/sd/interface.py +0 -362
- nexaai/binds/nexa_mlx/py-lib/sd/main.py +0 -286
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +0 -306
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +0 -116
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +0 -65
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +0 -385
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +0 -105
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +0 -100
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +0 -460
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +0 -274
- nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +0 -12
- nexaai/binds/nexa_mlx/py-lib/tts/interface.py +0 -276
- nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +0 -3
- nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +0 -572
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +0 -294
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +0 -276
- nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +0 -504
- nexaai/binds/nexa_mlx/py-lib/vlm/main.py +0 -320
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +0 -68
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +0 -193
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +0 -186
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +0 -233
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +0 -503
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +0 -202
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +0 -10
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +0 -264
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +0 -472
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +0 -591
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +0 -526
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +0 -356
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +0 -366
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +0 -488
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +0 -591
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +0 -213
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +0 -315
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +0 -238
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +0 -1038
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +0 -139
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +0 -322
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +0 -629
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +0 -1022
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +0 -294
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +0 -191
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +0 -267
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +0 -175
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +0 -192
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +0 -233
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +0 -140
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +0 -220
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +0 -393
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +0 -293
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +0 -307
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +0 -143
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +0 -509
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +0 -522
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +0 -386
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +0 -138
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +0 -560
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +0 -240
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +0 -153
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +0 -259
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +0 -236
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +0 -256
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +0 -283
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +0 -416
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +0 -172
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +0 -499
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +0 -133
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +0 -465
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +0 -10
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +0 -385
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +0 -557
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +0 -526
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +0 -282
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +0 -242
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +0 -21
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +0 -71
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +0 -324
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +0 -229
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +0 -161
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +0 -320
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +0 -108
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +0 -490
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +0 -168
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +0 -414
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +0 -104
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +0 -490
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +0 -167
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +0 -312
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +0 -117
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +0 -531
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +0 -701
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +0 -255
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +0 -407
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +0 -476
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +0 -1223
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +0 -117
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +0 -531
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +0 -701
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +0 -255
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +0 -407
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +0 -476
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +0 -1309
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +0 -210
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +0 -62
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +0 -209
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +0 -215
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +0 -474
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +0 -39
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +0 -344
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +0 -70
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +0 -296
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +0 -928
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/WHEEL +0 -0
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/top_level.txt +0 -0
nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py
@@ -1,928 +0,0 @@
-import copy
-import glob
-import importlib
-import inspect
-import json
-import logging
-import shutil
-from io import BytesIO
-from pathlib import Path
-from textwrap import dedent
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
-
-import mlx.core as mx
-import mlx.nn as nn
-import numpy as np
-import requests
-import scipy.signal as signal
-import soundfile as sf
-from huggingface_hub import snapshot_download
-from mlx.utils import tree_flatten, tree_map_with_path, tree_reduce, tree_unflatten
-from mlx_lm.utils import quantize_model
-from PIL import Image, ImageOps
-from transformers import (
-    AutoConfig,
-    AutoProcessor,
-    PreTrainedTokenizer,
-    PreTrainedTokenizerFast,
-)
-
-from .models.base import BaseImageProcessor
-from .tokenizer_utils import load_tokenizer
-from .trainer import apply_lora_layers
-
-# Constants
-MODEL_REMAPPING = {"llava-qwen2": "llava_bunny", "bunny-llama": "llava_bunny"}
-
-MAX_FILE_SIZE_GB = 5
-
-MODEL_CONVERSION_DTYPES = ["float16", "bfloat16", "float32"]
-
-
-def skip_multimodal_module(path: str) -> bool:
-    """
-    Check if a multimodal module (vision/audio) should skip quantization.
-
-    Args:
-        path: The module path to check
-
-    Returns:
-        bool: True if the module is multimodal and should skip quantization, False otherwise
-    """
-    return (
-        "vision_model" in path
-        or "vision_tower" in path
-        or "audio_model" in path
-        or "audio_tower" in path
-    )
-
-
-def get_model_and_args(config: dict):
-    """
-    Retrieve the model object based on the configuration.
-
-    Args:
-        config (dict): The model configuration.
-
-    Returns:
-        A tuple containing the Model class and the ModelArgs class.
-    """
-    model_type = config["model_type"]
-    model_type = MODEL_REMAPPING.get(model_type, model_type)
-    try:
-        # ===== NEXAAI CHANGES BEGIN =====
-        arch = importlib.import_module(f"vlm.modeling.models.{model_type}")
-        # ===== NEXAAI CHANGES END =====
-    except ImportError:
-        msg = f"Model type {model_type} not supported."
-        logging.error(msg)
-        raise ValueError(msg)
-
-    return arch, model_type
-
-
-def get_model_path(
-    path_or_hf_repo: str, revision: Optional[str] = None, force_download: bool = False
-) -> Path:
-    """
-    Ensures the model is available locally. If the path does not exist locally,
-    it is downloaded from the Hugging Face Hub.
-
-    Args:
-        path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
-        revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
-
-    Returns:
-        Path: The path to the model.
-    """
-    model_path = Path(path_or_hf_repo)
-    if not model_path.exists():
-        model_path = Path(
-            snapshot_download(
-                repo_id=path_or_hf_repo,
-                revision=revision,
-                allow_patterns=[
-                    "*.json",
-                    "*.safetensors",
-                    "*.py",
-                    "*.model",
-                    "*.tiktoken",
-                    "*.txt",
-                    "*.jinja",
-                ],
-                force_download=force_download,
-            )
-        )
-    return model_path
-
-
-def load_model(model_path: Path, lazy: bool = False, **kwargs) -> nn.Module:
-    """
-    Load and initialize the model from a given path.
-
-    Args:
-        model_path (Path): The path to load the model from.
-        lazy (bool): If False eval the model parameters to make sure they are
-            loaded in memory before returning, otherwise they will be loaded
-            when needed. Default: ``False``
-        revision (str, optional): A revision id which can be a branch name,
-            a tag, or a commit hash. Default: ``None``.
-
-    Returns:
-        nn.Module: The loaded and initialized model.
-
-    Raises:
-        FileNotFoundError: If the weight files (.safetensors) are not found.
-        ValueError: If the model class or args class are not found or cannot be instantiated.
-    """
-    config = load_config(model_path, **kwargs)
-    quantization = config.get("quantization", None)
-
-    weight_files = glob.glob(str(model_path / "*.safetensors"))
-    if not weight_files:
-        logging.error(f"No safetensors found in {model_path}")
-        message = f"""
-No safetensors found in {model_path}
-Create safetensors using the following code:
-```
-from transformers import AutoModelForCausalLM, AutoProcessor
-
-model_id= "<huggingface_model_id>"
-model = AutoModelForCausalLM.from_pretrained(model_id)
-processor = AutoProcessor.from_pretrained(model_id)
-
-model.save_pretrained("<local_dir>")
-processor.save_pretrained("<local_dir>")
-```
-Then use the <local_dir> as the --hf-path in the convert script.
-```
-python -m mlx_vlm.convert --hf-path <local_dir> --mlx-path <mlx_dir>
-```
-        """
-        raise FileNotFoundError(message)
-
-    weights = {}
-    for wf in weight_files:
-        weights.update(mx.load(wf))
-
-    model_class, model_type = get_model_and_args(config=config)
-
-    # Initialize text and vision configs if not present
-    config.setdefault("text_config", {})
-    config.setdefault("vision_config", {})
-    config.setdefault("audio_config", {})
-
-    # Initialize model config and update it with module configs
-    model_config = model_class.ModelConfig.from_dict(config)
-    modules = ["text", "vision", "perceiver", "projector", "audio"]
-    model_config = update_module_configs(model_config, model_class, config, modules)
-
-    model = model_class.Model(model_config)
-
-    # Sanitize weights
-    weights = sanitize_weights(model, weights)
-    weights = sanitize_weights(
-        model_class.VisionModel, weights, model_config.vision_config
-    )
-    weights = sanitize_weights(
-        model_class.LanguageModel, weights, model_config.text_config
-    )
-    if hasattr(model_class, "AudioModel"):
-        weights = sanitize_weights(
-            model_class.AudioModel, weights, model_config.audio_config
-        )
-
-    if (quantization := config.get("quantization", None)) is not None:
-        # Handle legacy models which may or may not have vision quantized
-        # TODO: Re-upload the models with the new quantization config and remove this
-        skip_vision = config.get("vision_config", {}).get("skip_vision", False)
-
-        def get_class_predicate(p, m):
-            # Always skip vision and audio models
-            if skip_multimodal_module(p) and skip_vision:
-                return False
-            # Handle custom per layer quantizations
-            if p in config["quantization"]:
-                return config["quantization"][p]
-            if not hasattr(m, "to_quantized"):
-                return False
-            # Skip layers not divisible by 64
-            if hasattr(m, "weight") and m.weight.size % 64 != 0:
-                return False
-            # Handle legacy models which may not have everything quantized
-            return f"{p}.scales" in weights
-
-        nn.quantize(
-            model,
-            group_size=quantization["group_size"],
-            bits=quantization["bits"],
-            class_predicate=get_class_predicate,
-        )
-
-    model.load_weights(list(weights.items()))
-    if not lazy:
-        mx.eval(model.parameters())
-
-    model.eval()
-    return model
-
-
-def sanitize_weights(model_obj, weights, config=None):
-    """Helper function to sanitize weights if the model has a sanitize method"""
-    if hasattr(model_obj, "sanitize"):
-        if config is not None:
-            model_obj = model_obj(config)
-        weights = model_obj.sanitize(weights)
-    return weights
-
-
-def update_module_configs(model_config, model_class, config, modules):
-    """Updates configuration for model modules like text and vision modules.
-
-    Args:
-        model_config: The model configuration object that will be updated
-        model_class: The model class containing component config classes
-        config: Dictionary containing configuration parameters
-        modules: List of module names to update configs for (e.g. ["text", "vision"])
-
-    Returns:
-        The updated model_config object
-    """
-    for config_name in modules:
-        config_attr = f"{config_name}_config"
-        if hasattr(model_config, config_attr):
-            config_class = getattr(model_class, f"{config_name.title()}Config")
-            setattr(
-                model_config, config_attr, config_class.from_dict(config[config_attr])
-            )
-    return model_config
-
-
-def load(
-    path_or_hf_repo: str,
-    adapter_path: Optional[str] = None,
-    lazy: bool = False,
-    revision: Optional[str] = None,
-    **kwargs,
-) -> Tuple[nn.Module, Union[PreTrainedTokenizer, PreTrainedTokenizerFast]]:
-    """
-    Load the model and tokenizer from a given path or a huggingface repository.
-
-    Args:
-        path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
-        tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
-            Defaults to an empty dictionary.
-        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
-            to the model. Default: ``None``.
-        lazy (bool): If False eval the model parameters to make sure they are
-            loaded in memory before returning, otherwise they will be loaded
-            when needed. Default: ``False``
-        revision (str, optional): A revision id which can be a branch name,
-            a tag, or a commit hash. Default: ``None``.
-    Returns:
-        Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
-
-    Raises:
-        FileNotFoundError: If config file or safetensors are not found.
-        ValueError: If model class or args class are not found.
-    """
-    force_download = kwargs.get("force_download", False)
-    model_path = get_model_path(
-        path_or_hf_repo, force_download=force_download, revision=revision
-    )
-    model = load_model(model_path, lazy, **kwargs)
-    if adapter_path is not None:
-        model = apply_lora_layers(model, adapter_path)
-        model.eval()
-
-    image_processor = load_image_processor(model_path, **kwargs)
-
-    # Get the eos_token_id from the model config
-    eos_token_id = getattr(model.config, "eos_token_id", None)
-
-    processor = load_processor(model_path, True, eos_token_ids=eos_token_id, **kwargs)
-
-    if image_processor is not None:
-        processor.image_processor = image_processor
-
-    return model, processor
-
-
-def load_config(model_path: Union[str, Path], **kwargs) -> dict:
-    """Load model configuration from a path or Hugging Face repo.
-
-    Args:
-        model_path: Local path or Hugging Face repo ID to load config from
-        **kwargs: Additional keyword arguments to pass to the config loader
-
-    Returns:
-        dict: Model configuration
-
-    Raises:
-        FileNotFoundError: If config.json is not found at the path
-    """
-    if isinstance(model_path, str):
-        model_path = get_model_path(model_path)
-
-    try:
-        return AutoConfig.from_pretrained(model_path, **kwargs).to_dict()
-    except ValueError:
-        try:
-            with open(model_path / "config.json", encoding="utf-8") as f:
-                return json.load(f)
-        except FileNotFoundError as exc:
-            raise FileNotFoundError(f"Config not found at {model_path}") from exc
-
-
-def load_image_processor(model_path: Union[str, Path], **kwargs) -> BaseImageProcessor:
-    if isinstance(model_path, str):
-        model_path = get_model_path(model_path)
-
-    if not kwargs:
-        config = load_config(model_path, trust_remote_code=True)
-    else:
-        config = load_config(model_path, **kwargs)
-
-    model_class, _ = get_model_and_args(config)
-    image_processor = None
-
-    if hasattr(model_class, "ImageProcessor"):
-        init_signature = inspect.signature(model_class.ImageProcessor.__init__)
-
-        if "config" in init_signature.parameters:
-            image_processor = model_class.ImageProcessor(config=config)
-        else:
-            image_processor = model_class.ImageProcessor()
-
-    return image_processor
-
-
-def load_processor(
-    model_path, add_detokenizer=True, eos_token_ids=None, **kwargs
-) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
-    #import ipdb; ipdb.set_trace()
-    config = load_config(model_path, **kwargs)
-
-    if "qwen2_5_vl" == str(config.get("model_type", "")):
-        from .processing_qwen2_5_vl import Qwen2_5_VLProcessor
-        processor = Qwen2_5_VLProcessor.from_pretrained(model_path, **kwargs)
-    elif "qwen2_vl" == str(config.get("model_type", "")):
-        from .processing_qwen2_vl import Qwen2VLProcessor
-        processor = Qwen2VLProcessor.from_pretrained(model_path, **kwargs)
-    else:
-        processor = AutoProcessor.from_pretrained(model_path, **kwargs)
-
-    if add_detokenizer:
-        detokenizer_class = load_tokenizer(model_path, return_tokenizer=False)
-
-        # Get the tokenizer object
-        tokenizer_obj = (
-            processor.tokenizer if hasattr(processor, "tokenizer") else processor
-        )
-
-        # Instantiate the detokenizer
-        processor.detokenizer = detokenizer_class(tokenizer_obj)
-
-        # Determine the EOS token IDs, prioritizing the function argument
-        final_eos_token_ids = (
-            eos_token_ids if eos_token_ids is not None else tokenizer_obj.eos_token_ids
-        )
-
-        # Create and assign the StoppingCriteria
-        criteria = StoppingCriteria(final_eos_token_ids, tokenizer_obj)
-        if hasattr(processor, "tokenizer"):
-            processor.tokenizer.stopping_criteria = criteria
-        else:
-            processor.stopping_criteria = criteria
-
-    return processor
-
-
-def fetch_from_hub(
-    model_path: Path, lazy: bool = False, **kwargs
-) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
-    model = load_model(model_path, lazy, **kwargs)
-    config = load_config(model_path, **kwargs)
-    processor = load_processor(
-        model_path,
-        add_detokenizer=False,
-        eos_token_ids=config.get("eos_token_id", None),
-        **kwargs,
-    )
-    return model, config, processor
-
-
-def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
-    """
-    Splits the weights into smaller shards.
-
-    Args:
-        weights (dict): Model weights.
-        max_file_size_gb (int): Maximum size of each shard in gigabytes.
-
-    Returns:
-        list: List of weight shards.
-    """
-    max_file_size_bytes = max_file_size_gb << 30
-    shards = []
-    shard, shard_size = {}, 0
-    for k, v in weights.items():
-        if shard_size + v.nbytes > max_file_size_bytes:
-            shards.append(shard)
-            shard, shard_size = {}, 0
-        shard[k] = v
-        shard_size += v.nbytes
-    shards.append(shard)
-    return shards
-
-
-def upload_to_hub(path: str, upload_repo: str, hf_path: str):
-    """
-    Uploads the model to Hugging Face hub.
-
-    Args:
-        path (str): Local path to the model.
-        upload_repo (str): Name of the HF repo to upload to.
-        hf_path (str): Path to the original Hugging Face model.
-    """
-    import os
-
-    from huggingface_hub import HfApi, ModelCard, logging
-
-    from . import __version__
-
-    card = ModelCard.load(hf_path)
-    card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
-    card.text = dedent(
-        f"""
-        # {upload_repo}
-        This model was converted to MLX format from [`{hf_path}`]() using mlx-vlm version **{__version__}**.
-        Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
-        ## Use with mlx
-
-        ```bash
-        pip install -U mlx-vlm
-        ```
-
-        ```bash
-        python -m mlx_vlm.generate --model {upload_repo} --max-tokens 100 --temperature 0.0 --prompt "Describe this image." --image <path_to_image>
-        ```
-        """
-    )
-    card.save(os.path.join(path, "README.md"))
-
-    logging.set_verbosity_info()
-
-    api = HfApi()
-    api.create_repo(repo_id=upload_repo, exist_ok=True)
-    api.upload_folder(
-        folder_path=path,
-        repo_id=upload_repo,
-        repo_type="model",
-    )
-    print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
-
-
-def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: float):
-    """
-    Apply repetition penalty to specific logits based on the given context.
-
-    Paper: https://arxiv.org/abs/1909.05858
-
-    Args:
-        logits (mx.array): The logits produced by the language model.
-        generated_tokens (any): A list of N previous tokens.
-        penalty (float): The repetition penalty factor to be applied.
-
-    Returns:
-        logits (mx.array): Logits with repetition penalty applied to generated tokens.
-    """
-    if len(generated_tokens) > 0:
-        indices = mx.array([token for token in generated_tokens])
-        selected_logits = logits[:, indices]
-        selected_logits = mx.where(
-            selected_logits < 0, selected_logits * penalty, selected_logits / penalty
-        )
-        logits[:, indices] = selected_logits
-    return logits
-
-
-def save_weights(
-    save_path: Union[str, Path],
-    model: nn.Module,
-    *,
-    donate_weights: bool = False,
-) -> None:
-    """Save model weights into specified directory."""
-    if isinstance(save_path, str):
-        save_path = Path(save_path)
-
-    weights = dict(tree_flatten(model.parameters()))
-    del model
-
-    save_path.mkdir(parents=True, exist_ok=True)
-
-    shards = make_shards(weights)
-    shards_count = len(shards)
-    shard_file_format = (
-        "model-{:05d}-of-{:05d}.safetensors"
-        if shards_count > 1
-        else "model.safetensors"
-    )
-
-    total_size = sum(v.nbytes for v in weights.values())
-    index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
-
-    # Write the weights and make sure no references are kept other than the
-    # necessary ones
-    if donate_weights:
-        weights.clear()
-        del weights
-
-    for i in range(len(shards)):
-        shard = shards[i]
-        shards[i] = None
-        shard_name = shard_file_format.format(i + 1, shards_count)
-        shard_path = save_path / shard_name
-
-        mx.save_safetensors(str(shard_path), shard, metadata={"format": "mlx"})
-
-        for weight_name in shard.keys():
-            index_data["weight_map"][weight_name] = shard_name
-        del shard
-
-    index_data["weight_map"] = {
-        k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
-    }
-
-    with open(save_path / "model.safetensors.index.json", "w") as f:
-        json.dump(
-            index_data,
-            f,
-            indent=4,
-        )
-
-
-def save_config(
-    config: dict,
-    config_path: Union[str, Path],
-) -> None:
-    """Save the model configuration to the ``config_path``.
-
-    The final configuration will be sorted before saving for better readability.
-
-    Args:
-        config (dict): The model configuration.
-        config_path (Union[str, Path]): Model configuration file path.
-    """
-    # Clean unused keys
-    config.pop("_name_or_path", None)
-    config.pop("torch_dtype", None)
-
-    # sort the config for better readability
-    config = dict(sorted(config.items()))
-
-    # write the updated config to the config_path (if provided)
-    with open(config_path, "w") as fid:
-        json.dump(config, fid, indent=4)
-
-
-def load_image(image_source: Union[str, Path, BytesIO], timeout: int = 10):
-    """
-    Helper function to load an image from either a URL or file.
-    """
-    if isinstance(image_source, BytesIO) or Path(image_source).is_file():
-        # for base64 encoded images
-        try:
-            image = Image.open(image_source)
-        except IOError as e:
-            raise ValueError(
-                f"Failed to load image from {image_source} with error: {e}"
-            ) from e
-    elif image_source.startswith(("http://", "https://")):
-        try:
-            response = requests.get(image_source, stream=True, timeout=timeout)
-            response.raise_for_status()
-            image = Image.open(response.raw)
-        except Exception as e:
-            raise ValueError(
-                f"Failed to load image from URL: {image_source} with error {e}"
-            ) from e
-    else:
-        raise ValueError(
-            f"The image {image_source} must be a valid URL or existing file."
-        )
-
-    image = ImageOps.exif_transpose(image)
-    image = image.convert("RGB")
-    return image
-
-
-def resize_image(img, max_size):
-    ratio = min(max_size[0] / img.width, max_size[1] / img.height)
-    new_size = (int(img.width * ratio), int(img.height * ratio))
-    return img.resize(new_size)
-
-
-def process_image(img, resize_shape, image_processor):
-    if isinstance(img, str):
-        img = load_image(img)
-    if resize_shape is not None and not isinstance(image_processor, BaseImageProcessor):
-        img = resize_image(img, resize_shape)
-    return img
-
-
-def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
-    gcd = np.gcd(orig_sr, target_sr)
-    up = target_sr // gcd
-    down = orig_sr // gcd
-    resampled = signal.resample_poly(audio, up, down, padtype="edge")
-    return resampled
-
-
-def load_audio(
-    file: str,
-    sr: int,
-    timeout: int = 10,
-):
-    """
-    Helper function to load audio from either a URL or file.
-    """
-    if file.startswith(("http://", "https://")):
-        try:
-            response = requests.get(file, stream=True, timeout=timeout)
-            response.raise_for_status()
-            audio, sample_rate = sf.read(BytesIO(response.content), always_2d=True)
-        except Exception as e:
-            raise ValueError(
-                f"Failed to load audio from URL: {file} with error {e}"
-            ) from e
-    else:
-        audio, sample_rate = sf.read(file, always_2d=True)
-
-    if sample_rate != sr:
-        audio = resample_audio(audio, sample_rate, sr)
-    return np.array(audio).mean(axis=1)
-
-
-def process_inputs(
-    processor,
-    prompts,
-    images=None,
-    audio=None,
-    add_special_tokens=False,
-    return_tensors="mlx",
-):
-    # Get the process method from the processor
-    process_method = getattr(processor, "process", processor)
-
-    # Prepare arguments
-    args = {
-        "text": prompts,
-        "images": images,
-        "padding": True,
-        "return_tensors": return_tensors,
-    }
-
-    # Add special tokens if supported
-    if "add_special_tokens" in inspect.signature(process_method).parameters:
-        args["add_special_tokens"] = add_special_tokens
-
-    # Add audio if provided and supported
-    if audio is not None:
-        if "audio" in inspect.signature(process_method).parameters:
-            args["audio"] = audio
-        else:
-            raise ValueError(f"Processor {processor} does not support audio parameter")
-
-    return process_method(**args)
-
-
-def process_inputs_with_fallback(
-    processor, prompts, images, audio, add_special_tokens=False, return_tensors="mlx"
-):
-    # First attempt with specified return_tensors
-    try:
-        return process_inputs(
-            processor,
-            prompts=prompts,
-            images=images,
-            audio=audio,
-            add_special_tokens=add_special_tokens,
-            return_tensors=return_tensors,
-        )
-    except Exception as e:
-        # Fallback to PyTorch tensors if MLX fails
-        if return_tensors != "pt":
-            try:
-                return process_inputs(
-                    processor,
-                    prompts=prompts,
-                    images=images,
-                    audio=audio,
-                    add_special_tokens=add_special_tokens,
-                    return_tensors="pt",
-                )
-            except Exception as fallback_error:
-                raise ValueError(
-                    f"Failed to process inputs with error: {fallback_error}"
-                )
-
-        raise ValueError(f"Failed to process inputs with error: {e}")
-
-
-def prepare_inputs(
-    processor,
-    images=None,
-    audio=None,
-    prompts=None,
-    image_token_index=None,
-    resize_shape=None,
-    add_special_tokens=False,
-):
-
-    if not images and not audio:
-        tokenizer = (
-            processor.tokenizer if hasattr(processor, "tokenizer") else processor
-        )
-        inputs = tokenizer(prompts, add_special_tokens=add_special_tokens)
-        input_ids = mx.array([inputs.input_ids])
-        mask = mx.array([inputs.attention_mask])
-        return {
-            "input_ids": input_ids,
-            "attention_mask": mask,
-        }
-
-    # Process images
-    if images is not None:
-        if not isinstance(images, list):
-            images = [images]
-
-        image_processor = (
-            processor.image_processor if hasattr(processor, "image_processor") else None
-        )
-        images = [process_image(img, resize_shape, image_processor) for img in images]
-
-    # Process audio
-    if audio:
-        if not isinstance(audio, list):
-            audio = [audio]
-
-        if len(audio) > 1:
-            print(
-                "\033[33mWarning\033[0m: Single prompt with multiple audio files is not supported yet. Using the first audio file.\n"
-            )
-            audio = audio[:1]
-
-        audio = [
-            load_audio(audio_file, sr=processor.feature_extractor.sampling_rate)
-            for audio_file in audio
-        ]
-    else:
-        audio = None
-
-    model_inputs = {}
-
-    if hasattr(processor, "image_processor") and isinstance(
-        processor.image_processor, BaseImageProcessor
-    ):
-        if not isinstance(prompts, list):
-            prompts = [prompts]
-
-        processor.pad_token = processor.eos_token
-        text_chunks = [
-            [processor(chunk).input_ids for chunk in prompt.split("<image>")]
-            for prompt in prompts
-        ]
-
-        # Find the maximum length for padding
-        max_length = max(
-            sum(len(chunk) for chunk in chunks) + 1 for chunks in text_chunks
-        )
-
-        # Pad and create input_ids
-        input_ids = []
-        for chunks in text_chunks:
-            ids = chunks[0] + [image_token_index] + chunks[1]
-            padding = [processor.pad_token_id] * (max_length - len(ids))
-            input_ids.append(mx.array(ids + padding))
-
-        model_inputs["input_ids"] = mx.array(input_ids)
-        pixel_values = processor.image_processor.preprocess(images=images)
-        model_inputs["pixel_values"] = mx.array(np.stack(pixel_values))
-        model_inputs["attention_mask"] = mx.array(
-            [(ids != processor.pad_token_id) for ids in input_ids]
-        ).astype(mx.int32)
-
-    else:
-        if hasattr(processor, "tokenizer"):
-            processor.tokenizer.pad_token = processor.tokenizer.eos_token
-
-        inputs = process_inputs_with_fallback(
-            processor,
-            images=images,
-            audio=audio,
-            prompts=prompts,
-            add_special_tokens=add_special_tokens,
-        )
-
-        if "images" in inputs:
-            inputs["pixel_values"] = inputs["images"]
-            inputs.pop("images")
-
-        model_inputs["attention_mask"] = (
-            mx.array(inputs["attention_mask"]) if "attention_mask" in inputs else None
-        )
-        # Convert inputs to model_inputs with mx.array if present
-        for key, value in inputs.items():
-            if key not in model_inputs and not isinstance(value, (str, list)):
-                model_inputs[key] = mx.array(value)
-
-    return model_inputs
-
-
-class StoppingCriteria:
-    def __init__(self, eos_token_ids: List[int], tokenizer=None):
-
-        if isinstance(eos_token_ids, int):
-            self.eos_token_ids = [eos_token_ids]
-        else:
-            self.eos_token_ids = eos_token_ids
-
-        self.tokenizer = tokenizer
-
-    def add_eos_token_ids(self, new_eos_token_ids: Union[int, List[int]] = None):
-        """
-        Add new token IDs to the list of EOS token IDs.
-
-        Args:
-            new_eos_token_ids: Integer, string, or list of integers/strings representing token IDs to add.
-                If strings are provided, they will be converted to integers if possible.
-        """
-        if new_eos_token_ids is None:
-            return
-
-        if self.tokenizer is None:
-            raise ValueError("Processor is not provided")
-
-        if new_eos_token_ids is not None:
-            if isinstance(new_eos_token_ids, str):
-                new_eos_token_ids = [new_eos_token_ids]
-            new_eos_token_ids = [
-                self.tokenizer.encode(" " + token, add_special_tokens=False)[-1]
-                for token in new_eos_token_ids
-            ]
-            self.eos_token_ids.extend(new_eos_token_ids)
-
-    def reset(self, eos_token_ids: List[int] = None):
-        eos_token_ids = (
-            eos_token_ids if eos_token_ids is not None else self.tokenizer.eos_token_ids
-        )
-
-        if isinstance(eos_token_ids, int):
-            eos_token_ids = [eos_token_ids]
-
-        if self.eos_token_ids != eos_token_ids:
-            self.eos_token_ids = eos_token_ids
-
-    def __call__(self, input_ids: mx.array) -> bool:
-        return input_ids in self.eos_token_ids
-
-
-def print_array_report(t: mx.array, label: Optional[str]) -> dict:
-    """
-    Return a dictionary report of an MLX array similar to PyTorch's tensor representation.
-    Args:
-        arr: MLX array to analyze
-    Returns:
-        Dictionary containing shape, dtype, value representation, and statistics
-    """
-
-    from pprint import pprint
-
-    # Get basic statistics
-    mean_val = mx.mean(t)
-    std_val = mx.std(t)
-    min_val = mx.min(t)
-    max_val = mx.max(t)
-
-    report = {
-        "shape": f"{tuple(t.shape)}",
-        "dtype": str(t.dtype),
-        "value": repr(t),
-        "mean": f"array({mean_val}, dtype={t.dtype})",
-        "std": f"array({std_val}, dtype={t.dtype})",
-        "min": f"array({min_val}, dtype={t.dtype})",
-        "max": f"array({max_val}, dtype={t.dtype})",
-        "label": label if label else "array",
-    }
-
-    # Print each field, handling 'value' specially
-    print("{")
-    for key, value in report.items():
-        if key == "value":
-            print(f" '{key}': {value},")  # No quotes around value
-        else:
-            print(f" '{key}': {repr(value)},")
-    print("}")
-    return report