nexaai 1.0.19rc7__cp310-cp310-macosx_14_0_universal2.whl → 1.0.19rc9__cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nexaai might be problematic. Click here for more details.
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl.py +14 -31
- nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +15 -32
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +7 -23
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +8 -24
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/METADATA +1 -1
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/RECORD +11 -200
- nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +0 -12
- nexaai/binds/nexa_mlx/py-lib/asr/interface.py +0 -122
- nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/common/utils.py +0 -25
- nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/cv/generate.py +0 -195
- nexaai/binds/nexa_mlx/py-lib/cv/interface.py +0 -151
- nexaai/binds/nexa_mlx/py-lib/cv/main.py +0 -81
- nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +0 -1736
- nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +0 -333
- nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +0 -617
- nexaai/binds/nexa_mlx/py-lib/embedding/main.py +0 -173
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +0 -399
- nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +0 -1
- nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +0 -244
- nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +0 -82
- nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +0 -281
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +0 -306
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +0 -116
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +0 -65
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +0 -386
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +0 -105
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +0 -100
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +0 -460
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +0 -274
- nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/llm/generate.py +0 -149
- nexaai/binds/nexa_mlx/py-lib/llm/interface.py +0 -764
- nexaai/binds/nexa_mlx/py-lib/llm/main.py +0 -68
- nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +0 -174
- nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +0 -287
- nexaai/binds/nexa_mlx/py-lib/rerank/main.py +0 -127
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +0 -330
- nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +0 -1
- nexaai/binds/nexa_mlx/py-lib/sd/interface.py +0 -362
- nexaai/binds/nexa_mlx/py-lib/sd/main.py +0 -286
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +0 -306
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +0 -116
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +0 -65
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +0 -385
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +0 -105
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +0 -100
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +0 -460
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +0 -274
- nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +0 -12
- nexaai/binds/nexa_mlx/py-lib/tts/interface.py +0 -276
- nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +0 -3
- nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +0 -572
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +0 -294
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +0 -276
- nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +0 -504
- nexaai/binds/nexa_mlx/py-lib/vlm/main.py +0 -320
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +0 -68
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +0 -193
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +0 -186
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +0 -233
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +0 -503
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +0 -202
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +0 -10
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +0 -264
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +0 -472
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +0 -591
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +0 -526
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +0 -356
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +0 -366
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +0 -488
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +0 -591
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +0 -213
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +0 -315
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +0 -238
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +0 -1038
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +0 -139
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +0 -322
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +0 -629
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +0 -1022
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +0 -294
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +0 -191
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +0 -267
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +0 -175
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +0 -192
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +0 -233
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +0 -140
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +0 -220
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +0 -393
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +0 -293
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +0 -307
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +0 -143
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +0 -509
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +0 -522
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +0 -386
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +0 -138
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +0 -560
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +0 -240
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +0 -153
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +0 -259
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +0 -236
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +0 -256
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +0 -283
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +0 -416
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +0 -172
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +0 -499
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +0 -133
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +0 -465
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +0 -10
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +0 -385
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +0 -557
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +0 -526
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +0 -282
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +0 -242
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +0 -21
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +0 -71
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +0 -324
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +0 -229
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +0 -161
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +0 -320
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +0 -108
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +0 -490
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +0 -168
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +0 -414
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +0 -104
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +0 -490
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +0 -167
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +0 -312
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +0 -117
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +0 -531
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +0 -701
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +0 -255
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +0 -407
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +0 -476
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +0 -1223
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +0 -117
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +0 -531
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +0 -701
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +0 -255
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +0 -407
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +0 -476
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +0 -1309
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +0 -210
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +0 -62
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +0 -209
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +0 -215
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +0 -474
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +0 -39
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +0 -344
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +0 -70
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +0 -296
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +0 -928
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/WHEEL +0 -0
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/top_level.txt +0 -0
|
@@ -1,386 +0,0 @@
|
|
|
1
|
-
# Copyright © 2023-2024 Apple Inc.
|
|
2
|
-
|
|
3
|
-
import json
|
|
4
|
-
import os
|
|
5
|
-
from typing import Optional
|
|
6
|
-
|
|
7
|
-
import mlx.core as mx
|
|
8
|
-
from huggingface_hub import hf_hub_download
|
|
9
|
-
from mlx.utils import tree_unflatten
|
|
10
|
-
|
|
11
|
-
from .clip import CLIPTextModel
|
|
12
|
-
from .config import AutoencoderConfig, CLIPTextModelConfig, DiffusionConfig, UNetConfig
|
|
13
|
-
from .tokenizer import Tokenizer
|
|
14
|
-
from .unet import UNetModel
|
|
15
|
-
from .vae import Autoencoder
|
|
16
|
-
|
|
17
|
-
# Repository id used by every loader in this module when no key is given.
_DEFAULT_MODEL = "stabilityai/stable-diffusion-2-1-base"

# Registry of known checkpoints: repository id -> {component name -> file path
# relative to the repository (or local directory) root}.
_MODELS = {
    # See https://huggingface.co/stabilityai/sdxl-turbo for the model details and license
    "stabilityai/sdxl-turbo": dict(
        unet_config="unet/config.json",
        unet="unet/diffusion_pytorch_model.safetensors",
        text_encoder_config="text_encoder/config.json",
        text_encoder="text_encoder/model.safetensors",
        text_encoder_2_config="text_encoder_2/config.json",
        text_encoder_2="text_encoder_2/model.safetensors",
        vae_config="vae/config.json",
        vae="vae/diffusion_pytorch_model.safetensors",
        diffusion_config="scheduler/scheduler_config.json",
        tokenizer_vocab="tokenizer/vocab.json",
        tokenizer_merges="tokenizer/merges.txt",
        tokenizer_2_vocab="tokenizer_2/vocab.json",
        tokenizer_2_merges="tokenizer_2/merges.txt",
    ),
    # See https://huggingface.co/stabilityai/stable-diffusion-2-1-base for the model details and license
    "stabilityai/stable-diffusion-2-1-base": dict(
        unet_config="unet/config.json",
        unet="unet/diffusion_pytorch_model.safetensors",
        text_encoder_config="text_encoder/config.json",
        text_encoder="text_encoder/model.safetensors",
        vae_config="vae/config.json",
        vae="vae/diffusion_pytorch_model.safetensors",
        diffusion_config="scheduler/scheduler_config.json",
        tokenizer_vocab="tokenizer/vocab.json",
        tokenizer_merges="tokenizer/merges.txt",
    ),
}
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
def map_unet_weights(key, value):
    """Rename one UNet checkpoint entry and adjust its weight layout.

    Args:
        key: Parameter name as stored in the checkpoint.
        value: The array stored under ``key``.

    Returns:
        A list of ``(name, array)`` pairs. Most entries yield a single pair;
        a key containing ``ff.net.0`` is split in two halves and yields the
        ``linear1`` and ``linear2`` pairs.
    """
    # Substring renames, applied in order: up/downsampling, the mid block,
    # attention projections, then the last transformer ffn layer.
    renames = (
        ("downsamplers.0.conv", "downsample"),
        ("upsamplers.0.conv", "upsample"),
        ("mid_block.resnets.0", "mid_blocks.0"),
        ("mid_block.attentions.0", "mid_blocks.1"),
        ("mid_block.resnets.1", "mid_blocks.2"),
        ("to_k", "key_proj"),
        ("to_out.0", "out_proj"),
        ("to_q", "query_proj"),
        ("to_v", "value_proj"),
        ("ff.net.2", "linear3"),
    )
    for old, new in renames:
        key = key.replace(old, new)

    # The first ffn projection is stored as one tensor; split it into the two
    # linear layers this model expects.
    if "ff.net.0" in key:
        first_half, second_half = mx.split(value, 2)
        return [
            (key.replace("ff.net.0.proj", "linear1"), first_half),
            (key.replace("ff.net.0.proj", "linear2"), second_half),
        ]

    if "conv_shortcut.weight" in key:
        value = value.squeeze()

    # Transform the weights from 1x1 convs to linear
    is_projection = "proj_in" in key or "proj_out" in key
    if is_projection and len(value.shape) == 4:
        value = value.squeeze()

    # Any weight that is still 4-D gets axis 1 moved to the end; the
    # flatten/reshape round trip is kept from the original (presumably to
    # materialise the permuted layout - confirm).
    if len(value.shape) == 4:
        permuted = value.transpose(0, 2, 3, 1)
        value = permuted.reshape(-1).reshape(permuted.shape)

    return [(key, value)]
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
def map_clip_text_encoder_weights(key, value):
    """Rename one CLIP text-encoder checkpoint entry; the value is untouched.

    Args:
        key: Parameter name as stored in the checkpoint.
        value: The array stored under ``key``; returned unchanged.

    Returns:
        A single-element list holding the ``(renamed_key, value)`` pair.
    """
    # Remove prefixes, one after another, each only if it leads the key.
    for prefix in ("text_model.", "embeddings.", "encoder."):
        if key.startswith(prefix):
            key = key[len(prefix):]

    # Attention layers first, then the ffn layers.
    renames = (
        ("self_attn.", "attention."),
        ("q_proj.", "query_proj."),
        ("k_proj.", "key_proj."),
        ("v_proj.", "value_proj."),
        ("mlp.fc1", "linear1"),
        ("mlp.fc2", "linear2"),
    )
    for old, new in renames:
        key = key.replace(old, new)

    return [(key, value)]
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
def map_vae_weights(key, value):
    """Rename one autoencoder checkpoint entry and adjust its weight layout.

    Args:
        key: Parameter name as stored in the checkpoint.
        value: The array stored under ``key``.

    Returns:
        A single-element list holding the ``(renamed_key, array)`` pair.
    """
    # Substring renames, applied in order: up/downsampling, attention
    # projections, then the mid block.
    renames = (
        ("downsamplers.0.conv", "downsample"),
        ("upsamplers.0.conv", "upsample"),
        ("to_k", "key_proj"),
        ("to_out.0", "out_proj"),
        ("to_q", "query_proj"),
        ("to_v", "value_proj"),
        ("mid_block.resnets.0", "mid_blocks.0"),
        ("mid_block.attentions.0", "mid_blocks.1"),
        ("mid_block.resnets.1", "mid_blocks.2"),
    )
    for old, new in renames:
        key = key.replace(old, new)

    # Map the quant/post_quant layers: renamed and squeezed.
    if "quant_conv" in key:
        key = key.replace("quant_conv", "quant_proj")
        value = value.squeeze()

    # Map the conv_shortcut to linear
    if "conv_shortcut.weight" in key:
        value = value.squeeze()

    # Any weight that is still 4-D gets axis 1 moved to the end; the
    # flatten/reshape round trip is kept from the original.
    if len(value.shape) == 4:
        permuted = value.transpose(0, 2, 3, 1)
        value = permuted.reshape(-1).reshape(permuted.shape)

    return [(key, value)]
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
def _flatten(params):
    """Concatenate an iterable of ``(key, value)`` pair lists into one list."""
    pairs = []
    for group in params:
        # Unpack and re-pack so every element comes out as a 2-tuple.
        pairs.extend((k, v) for (k, v) in group)
    return pairs
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
def _load_safetensor_weights(mapper, model, weight_file, float16: bool = False):
    """Load ``weight_file``, remap every entry with ``mapper`` and update ``model``.

    Args:
        mapper: Callable ``(key, array) -> list of (key, array)`` pairs.
        model: Object whose ``update`` method receives the unflattened weights.
        weight_file: Path handed to ``mx.load``.
        float16: Cast every array to ``mx.float16`` when true, else ``mx.float32``.
    """
    if float16:
        target_dtype = mx.float16
    else:
        target_dtype = mx.float32

    remapped = []
    for name, tensor in mx.load(weight_file).items():
        # The cast happens before mapping, so mappers always see the final dtype.
        remapped.append(mapper(name, tensor.astype(target_dtype)))

    model.update(tree_unflatten(_flatten(remapped)))
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
def _check_key(key: str, part: str):
    """Validate a model key, raising ``ValueError`` for unknown names.

    A key is accepted when it exists on disk, contains a path separator, or
    is registered in ``_MODELS``.

    Args:
        key: Model identifier or filesystem path.
        part: Name of the calling loader, used as the error-message prefix.

    Raises:
        ValueError: If ``key`` is neither path-like nor a registered model.
    """
    # NOTE: every id registered in _MODELS contains "/", so those ids are
    # already accepted by the separator test before the registry is consulted.
    looks_like_path = os.path.exists(key) or "/" in key or "\\" in key
    if looks_like_path or key in _MODELS:
        return
    raise ValueError(
        f"[{part}] '{key}' model not found, choose one of {{{','.join(_MODELS.keys())}}}"
    )
|
|
188
|
-
|
|
189
|
-
def _get_model_path(key: str, file_path: str):
    """Get the full path for a model file, supporting both local and HuggingFace paths.

    Resolution order:

    1. ``key`` exists on disk: it is a local directory, join ``file_path`` to it.
    2. ``key`` is registered in ``_MODELS``: fetch the file with
       ``hf_hub_download``.
    3. ``key`` contains a path separator: treat it as a (not yet existing)
       local path, as before.
    4. Anything else is passed to ``hf_hub_download``.

    Step 2 is the fix: every registered repository id contains "/", so the
    previous ``'/' in key`` test classified them all as local paths and the
    download branch could never run for them.

    Args:
        key: Local model directory or Hugging Face repository id.
        file_path: File path relative to the model root.

    Returns:
        The value of ``os.path.join(key, file_path)`` for local keys, otherwise
        whatever ``hf_hub_download(key, file_path)`` returns.
    """
    if os.path.exists(key):
        # Local path
        return os.path.join(key, file_path)

    if key in _MODELS:
        # Registered HuggingFace repository that is not mirrored locally.
        return hf_hub_download(key, file_path)

    if "/" in key or "\\" in key:
        # Unregistered and path-like: keep the original local-path behaviour.
        return os.path.join(key, file_path)

    # HuggingFace path
    return hf_hub_download(key, file_path)
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
    """Build a ``UNetModel`` from its JSON config and load its safetensors weights.

    Args:
        key: Local model directory or registered model id.
        float16: Load the weights as ``mx.float16`` instead of ``mx.float32``.

    Returns:
        The ``UNetModel`` with the checkpoint weights applied.

    Raises:
        ValueError: From ``_check_key`` when ``key`` is not accepted.
    """
    _check_key(key, "load_unet")

    # Path-like keys use the fixed SDXL Turbo directory structure; other keys
    # take their file names from the registry.
    path_like = os.path.exists(key) or "/" in key or "\\" in key

    config_file = "unet/config.json" if path_like else _MODELS[key]["unet_config"]
    with open(_get_model_path(key, config_file)) as f:
        config = json.load(f)

    n_blocks = len(config["block_out_channels"])
    unet_settings = UNetConfig(
        in_channels=config["in_channels"],
        out_channels=config["out_channels"],
        block_out_channels=config["block_out_channels"],
        # Scalar config entries are repeated once per block.
        layers_per_block=[config["layers_per_block"]] * n_blocks,
        transformer_layers_per_block=config.get(
            "transformer_layers_per_block", (1,) * 4
        ),
        num_attention_heads=(
            config["attention_head_dim"]
            if not isinstance(config["attention_head_dim"], int)
            else [config["attention_head_dim"]] * n_blocks
        ),
        cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
        norm_num_groups=config["norm_num_groups"],
        down_block_types=config["down_block_types"],
        # The up-block list is handed over in reverse order.
        up_block_types=config["up_block_types"][::-1],
        addition_embed_type=config.get("addition_embed_type", None),
        addition_time_embed_dim=config.get("addition_time_embed_dim", None),
        projection_class_embeddings_input_dim=config.get(
            "projection_class_embeddings_input_dim", None
        ),
    )
    model = UNetModel(unet_settings)

    # Locate the weights and map them into the model
    if path_like:
        weights_name = "unet/diffusion_pytorch_model.safetensors"
    else:
        weights_name = _MODELS[key]["unet"]
    _load_safetensor_weights(
        map_unet_weights, model, _get_model_path(key, weights_name), float16
    )

    return model
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
def load_text_encoder(
    key: str = _DEFAULT_MODEL,
    float16: bool = False,
    model_key: str = "text_encoder",
    config_key: Optional[str] = None,
):
    """Build a ``CLIPTextModel`` from its JSON config and load its weights.

    Args:
        key: Local model directory or registered model id.
        float16: Load the weights as ``mx.float16`` instead of ``mx.float32``.
        model_key: Component name; also the sub-directory used for path-like keys.
        config_key: Registry entry holding the config path; defaults to
            ``model_key + "_config"``.

    Returns:
        The ``CLIPTextModel`` with the checkpoint weights applied.

    Raises:
        ValueError: From ``_check_key`` when ``key`` is not accepted.
    """
    _check_key(key, "load_text_encoder")

    if not config_key:
        config_key = model_key + "_config"

    # Path-like keys use the fixed SDXL Turbo directory structure; other keys
    # take their file names from the registry.
    path_like = os.path.exists(key) or "/" in key or "\\" in key

    if path_like:
        config_file = f"{model_key}/config.json"
    else:
        config_file = _MODELS[key][config_key]
    with open(_get_model_path(key, config_file)) as f:
        config = json.load(f)

    # Only architectures named "...WithProjection" carry a projection_dim.
    with_projection = "WithProjection" in config["architectures"][0]

    encoder_settings = CLIPTextModelConfig(
        num_layers=config["num_hidden_layers"],
        model_dims=config["hidden_size"],
        num_heads=config["num_attention_heads"],
        max_length=config["max_position_embeddings"],
        vocab_size=config["vocab_size"],
        projection_dim=config["projection_dim"] if with_projection else None,
        hidden_act=config.get("hidden_act", "quick_gelu"),
    )
    model = CLIPTextModel(encoder_settings)

    # Locate the weights and map them into the model
    if path_like:
        weights_name = f"{model_key}/model.safetensors"
    else:
        weights_name = _MODELS[key][model_key]
    _load_safetensor_weights(
        map_clip_text_encoder_weights,
        model,
        _get_model_path(key, weights_name),
        float16,
    )

    return model
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False):
    """Build an ``Autoencoder`` from its JSON config and load its weights.

    Args:
        key: Local model directory or registered model id.
        float16: Load the weights as ``mx.float16`` instead of ``mx.float32``.

    Returns:
        The ``Autoencoder`` with the checkpoint weights applied.

    Raises:
        ValueError: From ``_check_key`` when ``key`` is not accepted.
    """
    _check_key(key, "load_autoencoder")

    # Path-like keys use the fixed SDXL Turbo directory structure; other keys
    # take their file names from the registry.
    path_like = os.path.exists(key) or "/" in key or "\\" in key

    config_file = "vae/config.json" if path_like else _MODELS[key]["vae_config"]
    with open(_get_model_path(key, config_file)) as f:
        config = json.load(f)

    vae_settings = AutoencoderConfig(
        in_channels=config["in_channels"],
        out_channels=config["out_channels"],
        # The output side is configured with twice the latent channel count.
        latent_channels_out=2 * config["latent_channels"],
        latent_channels_in=config["latent_channels"],
        block_out_channels=config["block_out_channels"],
        layers_per_block=config["layers_per_block"],
        norm_num_groups=config["norm_num_groups"],
        scaling_factor=config.get("scaling_factor", 0.18215),
    )
    model = Autoencoder(vae_settings)

    # Locate the weights and map them into the model
    if path_like:
        weights_name = "vae/diffusion_pytorch_model.safetensors"
    else:
        weights_name = _MODELS[key]["vae"]
    _load_safetensor_weights(
        map_vae_weights, model, _get_model_path(key, weights_name), float16
    )

    return model
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
def load_diffusion_config(key: str = _DEFAULT_MODEL):
    """Read the scheduler JSON for ``key`` and return it as a ``DiffusionConfig``.

    Args:
        key: Local model directory or registered model id.

    Returns:
        A ``DiffusionConfig`` filled from the ``beta_start``, ``beta_end``,
        ``beta_schedule`` and ``num_train_timesteps`` entries of the file.

    Raises:
        ValueError: From ``_check_key`` when ``key`` is not accepted.
    """
    _check_key(key, "load_diffusion_config")

    # Path-like keys use the fixed SDXL Turbo directory structure; other keys
    # take the file name from the registry.
    path_like = os.path.exists(key) or "/" in key or "\\" in key
    if path_like:
        config_file = "scheduler/scheduler_config.json"
    else:
        config_file = _MODELS[key]["diffusion_config"]

    with open(_get_model_path(key, config_file)) as f:
        config = json.load(f)

    settings = {
        "beta_start": config["beta_start"],
        "beta_end": config["beta_end"],
        "beta_schedule": config["beta_schedule"],
        # The JSON calls this "num_train_timesteps"; the config field is shorter.
        "num_train_steps": config["num_train_timesteps"],
    }
    return DiffusionConfig(**settings)
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
def load_tokenizer(
    key: str = _DEFAULT_MODEL,
    vocab_key: str = "tokenizer_vocab",
    merges_key: str = "tokenizer_merges",
):
    """Build a ``Tokenizer`` from a vocabulary JSON and a BPE merges file.

    Args:
        key: Local model directory or registered model id.
        vocab_key: Registry entry naming the vocabulary file.
        merges_key: Registry entry naming the merges file.

    Returns:
        A ``Tokenizer`` built from the merge-rank table and the vocabulary.

    Raises:
        ValueError: From ``_check_key`` when ``key`` is not accepted.
    """
    _check_key(key, "load_tokenizer")

    path_like = os.path.exists(key) or "/" in key or "\\" in key
    if path_like:
        # Path-like keys use the SDXL Turbo structure and always read the
        # main tokenizer files; vocab_key / merges_key are not consulted.
        vocab_name = "tokenizer/vocab.json"
        merges_name = "tokenizer/merges.txt"
    else:
        vocab_name = _MODELS[key][vocab_key]
        merges_name = _MODELS[key][merges_key]
    vocab_file = _get_model_path(key, vocab_name)
    merges_file = _get_model_path(key, merges_name)

    with open(vocab_file, encoding="utf-8") as f:
        vocab = json.load(f)

    with open(merges_file, encoding="utf-8") as f:
        merge_lines = f.read().strip().split("\n")

    # Skip the first line and keep the slice bounds used by the original.
    merge_lines = merge_lines[1 : 49152 - 256 - 2 + 1]

    # Map each merge pair to its position in the file.
    bpe_ranks = {}
    for rank, line in enumerate(merge_lines):
        bpe_ranks[tuple(line.split())] = rank

    return Tokenizer(bpe_ranks, vocab)
|
|
@@ -1,105 +0,0 @@
|
|
|
1
|
-
# Copyright © 2023 Apple Inc.
|
|
2
|
-
|
|
3
|
-
import mlx.core as mx
|
|
4
|
-
|
|
5
|
-
from .config import DiffusionConfig
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
def _linspace(a, b, num):
    """Return ``num`` evenly spaced values starting at ``a`` and ending at ``b``."""
    # Fractions 0, 1/(num-1), ..., 1 along the interval.
    fractions = mx.arange(0, num) / (num - 1)
    span = b - a
    return span * fractions + a
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
def _interp(y, x_new):
|
|
14
|
-
"""Interpolate the function defined by (arange(0, len(y)), y) at positions x_new."""
|
|
15
|
-
x_low = x_new.astype(mx.int32)
|
|
16
|
-
x_high = mx.minimum(x_low + 1, len(y) - 1)
|
|
17
|
-
|
|
18
|
-
y_low = y[x_low]
|
|
19
|
-
y_high = y[x_high]
|
|
20
|
-
delta_x = x_new - x_low
|
|
21
|
-
y_new = y_low * (1 - delta_x) + delta_x * y_high
|
|
22
|
-
|
|
23
|
-
return y_new
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
class SimpleEulerSampler:
    """Simple Euler integrator used to sample from the diffusion models.

    A table of sigmas is built once from the configured beta schedule;
    ``step()`` then performs one Euler update from ``x_t`` to ``x_t_prev``.
    """

    def __init__(self, config: DiffusionConfig):
        """Precompute the sigma table from ``config``.

        Raises:
            NotImplementedError: If ``config.beta_schedule`` is neither
                ``"linear"`` nor ``"scaled_linear"``.
        """
        schedule = config.beta_schedule
        if schedule == "scaled_linear":
            # Linear in sqrt(beta), squared back afterwards.
            betas = _linspace(
                config.beta_start**0.5, config.beta_end**0.5, config.num_train_steps
            ).square()
        elif schedule == "linear":
            betas = _linspace(
                config.beta_start, config.beta_end, config.num_train_steps
            )
        else:
            raise NotImplementedError(f"{config.beta_schedule} is not implemented.")

        cumulative_alphas = mx.cumprod(1 - betas)
        noise_levels = ((1 - cumulative_alphas) / cumulative_alphas).sqrt()

        # Index 0 holds sigma = 0, i.e. the noise-free end of the schedule.
        self._sigmas = mx.concatenate([mx.zeros(1), noise_levels])

    @property
    def max_time(self):
        """Largest valid time index into the sigma table."""
        return len(self._sigmas) - 1

    def sample_prior(self, shape, dtype=mx.float32, key=None):
        """Draw normal noise of ``shape`` scaled by the largest sigma.

        The result is cast to ``dtype``; ``key`` is forwarded to
        ``mx.random.normal``.
        """
        sigma_max = self._sigmas[-1]
        draw = mx.random.normal(shape, key=key)
        scaled = draw * sigma_max * (sigma_max.square() + 1).rsqrt()
        return scaled.astype(dtype)

    def add_noise(self, x, t, key=None):
        """Return ``x`` with noise of level ``sigmas(t)`` added and rescaled."""
        draw = mx.random.normal(x.shape, key=key)
        level = self.sigmas(t)
        return (x + draw * level) * (level.square() + 1).rsqrt()

    def sigmas(self, t):
        """Interpolate the sigma table at the (possibly fractional) time ``t``."""
        return _interp(self._sigmas, t)

    def timesteps(self, num_steps: int, start_time=None, dtype=mx.float32):
        """Return ``num_steps`` consecutive ``(t, t_prev)`` pairs ending at 0.

        ``start_time`` falls back to the last index of the sigma table when
        it is falsy, and must lie in ``(0, last index]``.
        """
        last_index = len(self._sigmas) - 1
        start_time = start_time or last_index
        assert 0 < start_time <= last_index
        grid = _linspace(start_time, 0, num_steps + 1).astype(dtype)
        return list(zip(grid, grid[1:]))

    def step(self, eps_pred, x_t, t, t_prev):
        """Advance ``x_t`` from time ``t`` to ``t_prev`` with one Euler step.

        Args:
            eps_pred: Predicted noise; its dtype is used for the sigmas.
            x_t: Current sample at time ``t``.
            t: Current time.
            t_prev: Time to step to.

        Returns:
            The sample at ``t_prev``.
        """
        sigma, sigma_prev = (
            self.sigmas(moment).astype(eps_pred.dtype) for moment in (t, t_prev)
        )

        # Undo the scaling at ``t``, take the Euler step in sigma, then apply
        # the scaling that belongs to ``t_prev``.
        stepped = (sigma.square() + 1).sqrt() * x_t + eps_pred * (sigma_prev - sigma)
        return stepped * (sigma_prev.square() + 1).rsqrt()
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
class SimpleEulerAncestralSampler(SimpleEulerSampler):
    """Euler sampler variant that injects fresh noise at every step."""

    def step(self, eps_pred, x_t, t, t_prev, key=None):
        """Advance ``x_t`` from ``t`` to ``t_prev`` with an ancestral Euler step.

        Args:
            eps_pred: Predicted noise; its dtype is used for the sigmas.
            x_t: Current sample at time ``t``.
            t: Current time.
            t_prev: Time to step to.
            key: Optional PRNG key forwarded to ``mx.random.normal`` for the
                injected noise, mirroring ``sample_prior`` and ``add_noise``.
                ``None`` (the default) keeps the previous behaviour of
                drawing from the global random state. Pass a different key
                for each step, otherwise the same noise is reused.

        Returns:
            The sample at ``t_prev``.
        """
        sigma = self.sigmas(t).astype(eps_pred.dtype)
        sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)

        # Split the target noise level into a part reached by the
        # deterministic step (sigma_down) and a part re-injected as fresh
        # noise (sigma_up); sigma_down**2 + sigma_up**2 == sigma_prev**2.
        sigma2 = sigma.square()
        sigma_prev2 = sigma_prev.square()
        sigma_up = (sigma_prev2 * (sigma2 - sigma_prev2) / sigma2).sqrt()
        sigma_down = (sigma_prev2 - sigma_up**2).sqrt()

        dt = sigma_down - sigma
        x_t_prev = (sigma2 + 1).sqrt() * x_t + eps_pred * dt
        noise = mx.random.normal(x_t_prev.shape, key=key).astype(x_t_prev.dtype)
        x_t_prev = x_t_prev + noise * sigma_up

        # Apply the scaling that belongs to ``t_prev``.
        x_t_prev = x_t_prev * (sigma_prev2 + 1).rsqrt()

        return x_t_prev
|
|
@@ -1,100 +0,0 @@
|
|
|
1
|
-
# Copyright © 2023 Apple Inc.
|
|
2
|
-
|
|
3
|
-
import regex
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
class Tokenizer:
    """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""

    def __init__(self, bpe_ranks, vocab):
        """Store the merge table and vocabulary and compile the split pattern.

        Args:
            bpe_ranks: Mapping from a pair of symbols to its merge rank
                (lower rank is merged first).
            vocab: Mapping from a BPE symbol to its token id.
        """
        self.bpe_ranks = bpe_ranks
        self.vocab = vocab
        self.pat = regex.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            regex.IGNORECASE,
        )

        # ``bpe`` returns a list of symbols, so the special tokens must be
        # cached as one-element lists. Caching the bare string would make
        # ``tokenize`` iterate over it character by character.
        self._cache = {self.bos: [self.bos], self.eos: [self.eos]}

    @property
    def bos(self):
        """Text of the beginning-of-sequence marker."""
        return "<|startoftext|>"

    @property
    def bos_token(self):
        """Token id of the beginning-of-sequence marker."""
        return self.vocab[self.bos]

    @property
    def eos(self):
        """Text of the end-of-sequence marker."""
        return "<|endoftext|>"

    @property
    def eos_token(self):
        """Token id of the end-of-sequence marker."""
        return self.vocab[self.eos]

    def bpe(self, text):
        """Split one pre-tokenized word into its list of BPE symbols.

        The final character gets the ``</w>`` suffix, then the known bigram
        with the lowest rank is merged repeatedly until none applies.
        Results of the merge loop are memoized in ``self._cache``.

        Args:
            text: A single word as produced by ``self.pat``.

        Returns:
            The list of BPE symbols; an empty list for empty ``text``.
        """
        if text in self._cache:
            return self._cache[text]

        # Nothing to split; also avoids indexing text[-1] on an empty string.
        if not text:
            return []

        unigrams = list(text[:-1]) + [text[-1] + "</w>"]
        unique_bigrams = set(zip(unigrams, unigrams[1:]))

        if not unique_bigrams:
            return unigrams

        # In every iteration try to merge the two most likely bigrams. If none
        # was merged we are done.
        #
        # Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
        while unique_bigrams:
            bigram = min(
                unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
            )
            if bigram not in self.bpe_ranks:
                break

            new_unigrams = []
            # ``skip`` is set right after a merge so the pair starting at the
            # already-consumed right-hand symbol is not looked at.
            skip = False
            for a, b in zip(unigrams, unigrams[1:]):
                if skip:
                    skip = False
                    continue

                if (a, b) == bigram:
                    new_unigrams.append(a + b)
                    skip = True

                else:
                    new_unigrams.append(a)

            # The last symbol is only emitted by the loop when it was part of
            # a merge; otherwise add it here.
            if not skip:
                new_unigrams.append(b)

            unigrams = new_unigrams
            unique_bigrams = set(zip(unigrams, unigrams[1:]))

        self._cache[text] = unigrams

        return unigrams

    def tokenize(self, text, prepend_bos=True, append_eos=True):
        """Convert text into a list of token ids.

        Args:
            text: A string, or a list of strings (each is tokenized
                separately and a list of lists is returned).
            prepend_bos: Whether to put the BOS token id first.
            append_eos: Whether to put the EOS token id last.

        Returns:
            The token ids for ``text``.

        Raises:
            KeyError: If a BPE symbol is missing from ``self.vocab``
                (assuming ``vocab`` is a plain dict).
        """
        if isinstance(text, list):
            return [self.tokenize(t, prepend_bos, append_eos) for t in text]

        # Lower case cleanup and split according to self.pat. Hugging Face does
        # a much more thorough job here but this should suffice for 95% of
        # cases.
        clean_text = regex.sub(r"\s+", " ", text.lower())
        tokens = regex.findall(self.pat, clean_text)

        # Split the tokens according to the byte-pair merge file
        bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]

        # Map to token ids and return
        tokens = [self.vocab[t] for t in bpe_tokens]
        if prepend_bos:
            tokens = [self.bos_token] + tokens
        if append_eos:
            tokens.append(self.eos_token)

        return tokens
|