nexaai-1.0.19rc7-cp310-cp310-macosx_14_0_universal2.whl → nexaai-1.0.19rc8-cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nexaai might be problematic.

Files changed (196)
  1. nexaai/_stub.cpython-310-darwin.so +0 -0
  2. nexaai/_version.py +1 -1
  3. nexaai/binds/libnexa_bridge.dylib +0 -0
  4. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/METADATA +1 -1
  5. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/RECORD +7 -196
  6. nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +0 -12
  7. nexaai/binds/nexa_mlx/py-lib/asr/interface.py +0 -122
  8. nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
  9. nexaai/binds/nexa_mlx/py-lib/common/utils.py +0 -25
  10. nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
  11. nexaai/binds/nexa_mlx/py-lib/cv/generate.py +0 -195
  12. nexaai/binds/nexa_mlx/py-lib/cv/interface.py +0 -151
  13. nexaai/binds/nexa_mlx/py-lib/cv/main.py +0 -81
  14. nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +0 -1736
  15. nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
  16. nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +0 -333
  17. nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +0 -617
  18. nexaai/binds/nexa_mlx/py-lib/embedding/main.py +0 -173
  19. nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
  20. nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +0 -399
  21. nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +0 -1
  22. nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +0 -244
  23. nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +0 -82
  24. nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +0 -281
  25. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +0 -306
  26. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +0 -116
  27. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +0 -65
  28. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +0 -386
  29. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +0 -105
  30. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +0 -100
  31. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +0 -460
  32. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +0 -274
  33. nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
  34. nexaai/binds/nexa_mlx/py-lib/llm/generate.py +0 -149
  35. nexaai/binds/nexa_mlx/py-lib/llm/interface.py +0 -764
  36. nexaai/binds/nexa_mlx/py-lib/llm/main.py +0 -68
  37. nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
  38. nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +0 -174
  39. nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +0 -287
  40. nexaai/binds/nexa_mlx/py-lib/rerank/main.py +0 -127
  41. nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
  42. nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +0 -330
  43. nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +0 -1
  44. nexaai/binds/nexa_mlx/py-lib/sd/interface.py +0 -362
  45. nexaai/binds/nexa_mlx/py-lib/sd/main.py +0 -286
  46. nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +0 -306
  47. nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +0 -116
  48. nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +0 -65
  49. nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +0 -385
  50. nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +0 -105
  51. nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +0 -100
  52. nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +0 -460
  53. nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +0 -274
  54. nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +0 -12
  55. nexaai/binds/nexa_mlx/py-lib/tts/interface.py +0 -276
  56. nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +0 -3
  57. nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +0 -572
  58. nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +0 -294
  59. nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +0 -276
  60. nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +0 -504
  61. nexaai/binds/nexa_mlx/py-lib/vlm/main.py +0 -320
  62. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
  63. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +0 -68
  64. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
  65. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +0 -8
  66. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +0 -193
  67. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +0 -186
  68. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +0 -233
  69. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +0 -503
  70. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +0 -202
  71. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +0 -230
  72. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +0 -10
  73. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +0 -264
  74. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +0 -472
  75. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +0 -591
  76. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +0 -526
  77. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +0 -356
  78. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +0 -8
  79. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +0 -366
  80. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +0 -488
  81. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +0 -591
  82. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +0 -8
  83. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +0 -213
  84. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +0 -315
  85. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +0 -238
  86. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +0 -2
  87. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +0 -1038
  88. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +0 -139
  89. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +0 -322
  90. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +0 -629
  91. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +0 -1022
  92. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +0 -9
  93. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +0 -294
  94. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +0 -191
  95. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +0 -267
  96. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +0 -8
  97. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +0 -175
  98. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +0 -192
  99. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +0 -233
  100. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +0 -9
  101. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +0 -140
  102. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +0 -220
  103. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +0 -393
  104. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +0 -293
  105. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +0 -307
  106. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +0 -8
  107. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +0 -143
  108. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +0 -509
  109. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +0 -522
  110. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +0 -8
  111. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +0 -386
  112. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +0 -138
  113. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +0 -560
  114. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +0 -8
  115. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +0 -240
  116. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +0 -153
  117. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +0 -259
  118. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +0 -9
  119. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +0 -236
  120. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +0 -256
  121. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +0 -303
  122. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +0 -8
  123. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +0 -230
  124. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +0 -160
  125. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +0 -243
  126. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +0 -8
  127. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +0 -283
  128. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +0 -8
  129. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +0 -416
  130. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +0 -172
  131. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +0 -499
  132. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +0 -8
  133. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +0 -243
  134. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +0 -133
  135. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +0 -465
  136. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +0 -10
  137. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +0 -230
  138. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +0 -385
  139. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +0 -557
  140. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +0 -526
  141. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +0 -8
  142. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +0 -282
  143. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +0 -160
  144. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +0 -242
  145. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +0 -8
  146. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +0 -21
  147. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +0 -243
  148. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +0 -71
  149. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +0 -324
  150. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +0 -8
  151. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +0 -229
  152. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +0 -161
  153. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +0 -320
  154. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +0 -2
  155. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +0 -108
  156. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +0 -490
  157. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +0 -168
  158. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +0 -414
  159. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +0 -2
  160. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +0 -104
  161. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +0 -490
  162. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +0 -167
  163. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +0 -312
  164. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  165. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +0 -117
  166. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +0 -531
  167. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +0 -701
  168. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +0 -255
  169. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +0 -303
  170. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +0 -407
  171. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +0 -476
  172. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +0 -1223
  173. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  174. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +0 -117
  175. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +0 -531
  176. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +0 -701
  177. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +0 -255
  178. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +0 -303
  179. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +0 -407
  180. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +0 -476
  181. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +0 -1309
  182. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +0 -210
  183. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +0 -8
  184. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +0 -62
  185. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +0 -209
  186. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +0 -215
  187. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +0 -474
  188. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +0 -39
  189. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +0 -344
  190. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +0 -9
  191. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +0 -70
  192. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +0 -296
  193. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +0 -160
  194. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +0 -928
  195. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/WHEEL +0 -0
  196. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/top_level.txt +0 -0
nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py
@@ -1,386 +0,0 @@
- # Copyright © 2023-2024 Apple Inc.
-
- import json
- import os
- from typing import Optional
-
- import mlx.core as mx
- from huggingface_hub import hf_hub_download
- from mlx.utils import tree_unflatten
-
- from .clip import CLIPTextModel
- from .config import AutoencoderConfig, CLIPTextModelConfig, DiffusionConfig, UNetConfig
- from .tokenizer import Tokenizer
- from .unet import UNetModel
- from .vae import Autoencoder
-
- _DEFAULT_MODEL = "stabilityai/stable-diffusion-2-1-base"
- _MODELS = {
-     # See https://huggingface.co/stabilityai/sdxl-turbo for the model details and license
-     "stabilityai/sdxl-turbo": {
-         "unet_config": "unet/config.json",
-         "unet": "unet/diffusion_pytorch_model.safetensors",
-         "text_encoder_config": "text_encoder/config.json",
-         "text_encoder": "text_encoder/model.safetensors",
-         "text_encoder_2_config": "text_encoder_2/config.json",
-         "text_encoder_2": "text_encoder_2/model.safetensors",
-         "vae_config": "vae/config.json",
-         "vae": "vae/diffusion_pytorch_model.safetensors",
-         "diffusion_config": "scheduler/scheduler_config.json",
-         "tokenizer_vocab": "tokenizer/vocab.json",
-         "tokenizer_merges": "tokenizer/merges.txt",
-         "tokenizer_2_vocab": "tokenizer_2/vocab.json",
-         "tokenizer_2_merges": "tokenizer_2/merges.txt",
-     },
-     # See https://huggingface.co/stabilityai/stable-diffusion-2-1-base for the model details and license
-     "stabilityai/stable-diffusion-2-1-base": {
-         "unet_config": "unet/config.json",
-         "unet": "unet/diffusion_pytorch_model.safetensors",
-         "text_encoder_config": "text_encoder/config.json",
-         "text_encoder": "text_encoder/model.safetensors",
-         "vae_config": "vae/config.json",
-         "vae": "vae/diffusion_pytorch_model.safetensors",
-         "diffusion_config": "scheduler/scheduler_config.json",
-         "tokenizer_vocab": "tokenizer/vocab.json",
-         "tokenizer_merges": "tokenizer/merges.txt",
-     },
- }
-
-
- def map_unet_weights(key, value):
-     # Map up/downsampling
-     if "downsamplers" in key:
-         key = key.replace("downsamplers.0.conv", "downsample")
-     if "upsamplers" in key:
-         key = key.replace("upsamplers.0.conv", "upsample")
-
-     # Map the mid block
-     if "mid_block.resnets.0" in key:
-         key = key.replace("mid_block.resnets.0", "mid_blocks.0")
-     if "mid_block.attentions.0" in key:
-         key = key.replace("mid_block.attentions.0", "mid_blocks.1")
-     if "mid_block.resnets.1" in key:
-         key = key.replace("mid_block.resnets.1", "mid_blocks.2")
-
-     # Map attention layers
-     if "to_k" in key:
-         key = key.replace("to_k", "key_proj")
-     if "to_out.0" in key:
-         key = key.replace("to_out.0", "out_proj")
-     if "to_q" in key:
-         key = key.replace("to_q", "query_proj")
-     if "to_v" in key:
-         key = key.replace("to_v", "value_proj")
-
-     # Map transformer ffn
-     if "ff.net.2" in key:
-         key = key.replace("ff.net.2", "linear3")
-     if "ff.net.0" in key:
-         k1 = key.replace("ff.net.0.proj", "linear1")
-         k2 = key.replace("ff.net.0.proj", "linear2")
-         v1, v2 = mx.split(value, 2)
-
-         return [(k1, v1), (k2, v2)]
-
-     if "conv_shortcut.weight" in key:
-         value = value.squeeze()
-
-     # Transform the weights from 1x1 convs to linear
-     if len(value.shape) == 4 and ("proj_in" in key or "proj_out" in key):
-         value = value.squeeze()
-
-     if len(value.shape) == 4:
-         value = value.transpose(0, 2, 3, 1)
-         value = value.reshape(-1).reshape(value.shape)
-
-     return [(key, value)]
-
-
- def map_clip_text_encoder_weights(key, value):
-     # Remove prefixes
-     if key.startswith("text_model."):
-         key = key[11:]
-     if key.startswith("embeddings."):
-         key = key[11:]
-     if key.startswith("encoder."):
-         key = key[8:]
-
-     # Map attention layers
-     if "self_attn." in key:
-         key = key.replace("self_attn.", "attention.")
-     if "q_proj." in key:
-         key = key.replace("q_proj.", "query_proj.")
-     if "k_proj." in key:
-         key = key.replace("k_proj.", "key_proj.")
-     if "v_proj." in key:
-         key = key.replace("v_proj.", "value_proj.")
-
-     # Map ffn layers
-     if "mlp.fc1" in key:
-         key = key.replace("mlp.fc1", "linear1")
-     if "mlp.fc2" in key:
-         key = key.replace("mlp.fc2", "linear2")
-
-     return [(key, value)]
-
-
- def map_vae_weights(key, value):
-     # Map up/downsampling
-     if "downsamplers" in key:
-         key = key.replace("downsamplers.0.conv", "downsample")
-     if "upsamplers" in key:
-         key = key.replace("upsamplers.0.conv", "upsample")
-
-     # Map attention layers
-     if "to_k" in key:
-         key = key.replace("to_k", "key_proj")
-     if "to_out.0" in key:
-         key = key.replace("to_out.0", "out_proj")
-     if "to_q" in key:
-         key = key.replace("to_q", "query_proj")
-     if "to_v" in key:
-         key = key.replace("to_v", "value_proj")
-
-     # Map the mid block
-     if "mid_block.resnets.0" in key:
-         key = key.replace("mid_block.resnets.0", "mid_blocks.0")
-     if "mid_block.attentions.0" in key:
-         key = key.replace("mid_block.attentions.0", "mid_blocks.1")
-     if "mid_block.resnets.1" in key:
-         key = key.replace("mid_block.resnets.1", "mid_blocks.2")
-
-     # Map the quant/post_quant layers
-     if "quant_conv" in key:
-         key = key.replace("quant_conv", "quant_proj")
-         value = value.squeeze()
-
-     # Map the conv_shortcut to linear
-     if "conv_shortcut.weight" in key:
-         value = value.squeeze()
-
-     if len(value.shape) == 4:
-         value = value.transpose(0, 2, 3, 1)
-         value = value.reshape(-1).reshape(value.shape)
-
-     return [(key, value)]
-
-
- def _flatten(params):
-     return [(k, v) for p in params for (k, v) in p]
-
-
- def _load_safetensor_weights(mapper, model, weight_file, float16: bool = False):
-     dtype = mx.float16 if float16 else mx.float32
-     weights = mx.load(weight_file)
-     weights = _flatten([mapper(k, v.astype(dtype)) for k, v in weights.items()])
-     model.update(tree_unflatten(weights))
-
-
- def _check_key(key: str, part: str):
-     # Check if it's a local path
-     if os.path.exists(key) or '/' in key or '\\' in key:
-         # For local paths, we'll use a default model structure
-         return
-     if key not in _MODELS:
-         raise ValueError(
-             f"[{part}] '{key}' model not found, choose one of {{{','.join(_MODELS.keys())}}}"
-         )
-
- def _get_model_path(key: str, file_path: str):
-     """Get the full path for a model file, supporting both local and HuggingFace paths"""
-     if os.path.exists(key) or '/' in key or '\\' in key:
-         # Local path
-         return os.path.join(key, file_path)
-     else:
-         # HuggingFace path
-         return hf_hub_download(key, file_path)
-
-
- def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
-     """Load the stable diffusion UNet from Hugging Face Hub."""
-     _check_key(key, "load_unet")
-
-     # Get the config path
-     if os.path.exists(key) or '/' in key or '\\' in key:
-         # Local path - use SDXL Turbo structure
-         unet_config = "unet/config.json"
-     else:
-         unet_config = _MODELS[key]["unet_config"]
-
-     with open(_get_model_path(key, unet_config)) as f:
-         config = json.load(f)
-
-     n_blocks = len(config["block_out_channels"])
-     model = UNetModel(
-         UNetConfig(
-             in_channels=config["in_channels"],
-             out_channels=config["out_channels"],
-             block_out_channels=config["block_out_channels"],
-             layers_per_block=[config["layers_per_block"]] * n_blocks,
-             transformer_layers_per_block=config.get(
-                 "transformer_layers_per_block", (1,) * 4
-             ),
-             num_attention_heads=(
-                 [config["attention_head_dim"]] * n_blocks
-                 if isinstance(config["attention_head_dim"], int)
-                 else config["attention_head_dim"]
-             ),
-             cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
-             norm_num_groups=config["norm_num_groups"],
-             down_block_types=config["down_block_types"],
-             up_block_types=config["up_block_types"][::-1],
-             addition_embed_type=config.get("addition_embed_type", None),
-             addition_time_embed_dim=config.get("addition_time_embed_dim", None),
-             projection_class_embeddings_input_dim=config.get(
-                 "projection_class_embeddings_input_dim", None
-             ),
-         )
-     )
-
-     # Download the weights and map them into the model
-     if os.path.exists(key) or '/' in key or '\\' in key:
-         # Local path - use SDXL Turbo structure
-         unet_weights = "unet/diffusion_pytorch_model.safetensors"
-     else:
-         unet_weights = _MODELS[key]["unet"]
-
-     weight_file = _get_model_path(key, unet_weights)
-     _load_safetensor_weights(map_unet_weights, model, weight_file, float16)
-
-     return model
-
-
- def load_text_encoder(
-     key: str = _DEFAULT_MODEL,
-     float16: bool = False,
-     model_key: str = "text_encoder",
-     config_key: Optional[str] = None,
- ):
-     """Load the stable diffusion text encoder from Hugging Face Hub."""
-     _check_key(key, "load_text_encoder")
-
-     config_key = config_key or (model_key + "_config")
-
-     # Download the config and create the model
-     if os.path.exists(key) or '/' in key or '\\' in key:
-         # Local path - use SDXL Turbo structure
-         text_encoder_config = f"{model_key}/config.json"
-     else:
-         text_encoder_config = _MODELS[key][config_key]
-
-     with open(_get_model_path(key, text_encoder_config)) as f:
-         config = json.load(f)
-
-     with_projection = "WithProjection" in config["architectures"][0]
-
-     model = CLIPTextModel(
-         CLIPTextModelConfig(
-             num_layers=config["num_hidden_layers"],
-             model_dims=config["hidden_size"],
-             num_heads=config["num_attention_heads"],
-             max_length=config["max_position_embeddings"],
-             vocab_size=config["vocab_size"],
-             projection_dim=config["projection_dim"] if with_projection else None,
-             hidden_act=config.get("hidden_act", "quick_gelu"),
-         )
-     )
-
-     # Download the weights and map them into the model
-     if os.path.exists(key) or '/' in key or '\\' in key:
-         # Local path - use SDXL Turbo structure
-         text_encoder_weights = f"{model_key}/model.safetensors"
-     else:
-         text_encoder_weights = _MODELS[key][model_key]
-
-     weight_file = _get_model_path(key, text_encoder_weights)
-     _load_safetensor_weights(map_clip_text_encoder_weights, model, weight_file, float16)
-
-     return model
-
-
- def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False):
-     """Load the stable diffusion autoencoder from Hugging Face Hub."""
-     _check_key(key, "load_autoencoder")
-
-     # Download the config and create the model
-     if os.path.exists(key) or '/' in key or '\\' in key:
-         # Local path - use SDXL Turbo structure
-         vae_config = "vae/config.json"
-     else:
-         vae_config = _MODELS[key]["vae_config"]
-
-     with open(_get_model_path(key, vae_config)) as f:
-         config = json.load(f)
-
-     model = Autoencoder(
-         AutoencoderConfig(
-             in_channels=config["in_channels"],
-             out_channels=config["out_channels"],
-             latent_channels_out=2 * config["latent_channels"],
-             latent_channels_in=config["latent_channels"],
-             block_out_channels=config["block_out_channels"],
-             layers_per_block=config["layers_per_block"],
-             norm_num_groups=config["norm_num_groups"],
-             scaling_factor=config.get("scaling_factor", 0.18215),
-         )
-     )
-
-     # Download the weights and map them into the model
-     if os.path.exists(key) or '/' in key or '\\' in key:
-         # Local path - use SDXL Turbo structure
-         vae_weights = "vae/diffusion_pytorch_model.safetensors"
-     else:
-         vae_weights = _MODELS[key]["vae"]
-
-     weight_file = _get_model_path(key, vae_weights)
-     _load_safetensor_weights(map_vae_weights, model, weight_file, float16)
-
-     return model
-
-
- def load_diffusion_config(key: str = _DEFAULT_MODEL):
-     """Load the stable diffusion config from Hugging Face Hub."""
-     _check_key(key, "load_diffusion_config")
-
-     if os.path.exists(key) or '/' in key or '\\' in key:
-         # Local path - use SDXL Turbo structure
-         diffusion_config = "scheduler/scheduler_config.json"
-     else:
-         diffusion_config = _MODELS[key]["diffusion_config"]
-
-     with open(_get_model_path(key, diffusion_config)) as f:
-         config = json.load(f)
-
-     return DiffusionConfig(
-         beta_start=config["beta_start"],
-         beta_end=config["beta_end"],
-         beta_schedule=config["beta_schedule"],
-         num_train_steps=config["num_train_timesteps"],
-     )
-
-
- def load_tokenizer(
-     key: str = _DEFAULT_MODEL,
-     vocab_key: str = "tokenizer_vocab",
-     merges_key: str = "tokenizer_merges",
- ):
-     _check_key(key, "load_tokenizer")
-
-     if os.path.exists(key) or '/' in key or '\\' in key:
-         # Local path - use SDXL Turbo structure
-         # For SDXL Turbo, we always use the main tokenizer files
-         vocab_file = _get_model_path(key, "tokenizer/vocab.json")
-         merges_file = _get_model_path(key, "tokenizer/merges.txt")
-     else:
-         vocab_file = _get_model_path(key, _MODELS[key][vocab_key])
-         merges_file = _get_model_path(key, _MODELS[key][merges_key])
-
-     with open(vocab_file, encoding="utf-8") as f:
-         vocab = json.load(f)
-
-     with open(merges_file, encoding="utf-8") as f:
-         bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
-         bpe_merges = [tuple(m.split()) for m in bpe_merges]
-         bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
-
-     return Tokenizer(bpe_ranks, vocab)
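
Aside (not part of the diff): the removed model_io.py relies on one pattern worth noting when reading the hunk above. Each map_* function returns a list of (key, value) pairs, so a single source tensor can be renamed, or fanned out into several target parameters, before mlx.utils.tree_unflatten rebuilds the parameter tree. A minimal standalone sketch of that pattern follows; the key name and shapes are made up for the demo.

```python
# Sketch of the mapper -> _flatten -> tree_unflatten pattern used by
# _load_safetensor_weights above. Key name and shapes are illustrative only.
import mlx.core as mx
from mlx.utils import tree_unflatten

def demo_mapper(key, value):
    # Split a fused feed-forward projection into two linears, as
    # map_unet_weights does for "ff.net.0.proj"; pass everything else through.
    if "ff.net.0.proj" in key:
        v1, v2 = mx.split(value, 2)
        return [(key.replace("ff.net.0.proj", "linear1"), v1),
                (key.replace("ff.net.0.proj", "linear2"), v2)]
    return [(key, value)]

weights = {"blocks.0.ff.net.0.proj.weight": mx.zeros((8, 4))}
flat = [kv for k, v in weights.items() for kv in demo_mapper(k, v)]
params = tree_unflatten(flat)  # one source tensor became two parameters
print(params["blocks"][0]["linear1"]["weight"].shape)  # (4, 4)
```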
nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py
@@ -1,105 +0,0 @@
- # Copyright © 2023 Apple Inc.
-
- import mlx.core as mx
-
- from .config import DiffusionConfig
-
-
- def _linspace(a, b, num):
-     x = mx.arange(0, num) / (num - 1)
-     return (b - a) * x + a
-
-
- def _interp(y, x_new):
-     """Interpolate the function defined by (arange(0, len(y)), y) at positions x_new."""
-     x_low = x_new.astype(mx.int32)
-     x_high = mx.minimum(x_low + 1, len(y) - 1)
-
-     y_low = y[x_low]
-     y_high = y[x_high]
-     delta_x = x_new - x_low
-     y_new = y_low * (1 - delta_x) + delta_x * y_high
-
-     return y_new
-
-
- class SimpleEulerSampler:
-     """A simple Euler integrator that can be used to sample from our diffusion models.
-
-     The method ``step()`` performs one Euler step from x_t to x_t_prev.
-     """
-
-     def __init__(self, config: DiffusionConfig):
-         # Compute the noise schedule
-         if config.beta_schedule == "linear":
-             betas = _linspace(
-                 config.beta_start, config.beta_end, config.num_train_steps
-             )
-         elif config.beta_schedule == "scaled_linear":
-             betas = _linspace(
-                 config.beta_start**0.5, config.beta_end**0.5, config.num_train_steps
-             ).square()
-         else:
-             raise NotImplementedError(f"{config.beta_schedule} is not implemented.")
-
-         alphas = 1 - betas
-         alphas_cumprod = mx.cumprod(alphas)
-
-         self._sigmas = mx.concatenate(
-             [mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()]
-         )
-
-     @property
-     def max_time(self):
-         return len(self._sigmas) - 1
-
-     def sample_prior(self, shape, dtype=mx.float32, key=None):
-         noise = mx.random.normal(shape, key=key)
-         return (
-             noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
-         ).astype(dtype)
-
-     def add_noise(self, x, t, key=None):
-         noise = mx.random.normal(x.shape, key=key)
-         s = self.sigmas(t)
-         return (x + noise * s) * (s.square() + 1).rsqrt()
-
-     def sigmas(self, t):
-         return _interp(self._sigmas, t)
-
-     def timesteps(self, num_steps: int, start_time=None, dtype=mx.float32):
-         start_time = start_time or (len(self._sigmas) - 1)
-         assert 0 < start_time <= (len(self._sigmas) - 1)
-         steps = _linspace(start_time, 0, num_steps + 1).astype(dtype)
-         return list(zip(steps, steps[1:]))
-
-     def step(self, eps_pred, x_t, t, t_prev):
-         sigma = self.sigmas(t).astype(eps_pred.dtype)
-         sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)
-
-         dt = sigma_prev - sigma
-         x_t_prev = (sigma.square() + 1).sqrt() * x_t + eps_pred * dt
-
-         x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt()
-
-         return x_t_prev
-
-
- class SimpleEulerAncestralSampler(SimpleEulerSampler):
-     def step(self, eps_pred, x_t, t, t_prev):
-         sigma = self.sigmas(t).astype(eps_pred.dtype)
-         sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)
-
-         sigma2 = sigma.square()
-         sigma_prev2 = sigma_prev.square()
-         sigma_up = (sigma_prev2 * (sigma2 - sigma_prev2) / sigma2).sqrt()
-         sigma_down = (sigma_prev2 - sigma_up**2).sqrt()
-
-         dt = sigma_down - sigma
-         x_t_prev = (sigma2 + 1).sqrt() * x_t + eps_pred * dt
-         noise = mx.random.normal(x_t_prev.shape).astype(x_t_prev.dtype)
-         x_t_prev = x_t_prev + noise * sigma_up
-
-         x_t_prev = x_t_prev * (sigma_prev2 + 1).rsqrt()
-
-         return x_t_prev
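
Aside (not part of the diff): numerically, SimpleEulerSampler.step above is one explicit Euler update in sigma space, applied to a latent kept in the rescaled (sigma^2 + 1)^(-1/2) parameterization. A standalone sketch follows; the 0.00085/0.012 beta constants are assumptions, typical of Stable Diffusion scheduler configs, not values taken from this diff.

```python
# Standalone illustration of the removed sampler's schedule and update rule.
import mlx.core as mx

# "scaled_linear" beta schedule -> cumulative alphas -> sigma table,
# mirroring SimpleEulerSampler.__init__ (constants are illustrative).
betas = mx.linspace(0.00085**0.5, 0.012**0.5, 1000).square()
alphas_cumprod = mx.cumprod(1 - betas)
sigmas = mx.concatenate([mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()])

x_t = mx.random.normal((1, 64, 64, 4))      # a latent at some timestep
eps_pred = mx.random.normal(x_t.shape)      # stand-in for the UNet's noise prediction
sigma, sigma_prev = sigmas[-1], sigmas[-500]

# One Euler step, exactly as in SimpleEulerSampler.step: un-rescale, move
# along the predicted noise direction by d(sigma), then re-rescale.
x_prev = (sigma.square() + 1).sqrt() * x_t + eps_pred * (sigma_prev - sigma)
x_prev = x_prev * (sigma_prev.square() + 1).rsqrt()
```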
nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py
@@ -1,100 +0,0 @@
- # Copyright © 2023 Apple Inc.
-
- import regex
-
-
- class Tokenizer:
-     """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
-
-     def __init__(self, bpe_ranks, vocab):
-         self.bpe_ranks = bpe_ranks
-         self.vocab = vocab
-         self.pat = regex.compile(
-             r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
-             regex.IGNORECASE,
-         )
-
-         self._cache = {self.bos: self.bos, self.eos: self.eos}
-
-     @property
-     def bos(self):
-         return "<|startoftext|>"
-
-     @property
-     def bos_token(self):
-         return self.vocab[self.bos]
-
-     @property
-     def eos(self):
-         return "<|endoftext|>"
-
-     @property
-     def eos_token(self):
-         return self.vocab[self.eos]
-
-     def bpe(self, text):
-         if text in self._cache:
-             return self._cache[text]
-
-         unigrams = list(text[:-1]) + [text[-1] + "</w>"]
-         unique_bigrams = set(zip(unigrams, unigrams[1:]))
-
-         if not unique_bigrams:
-             return unigrams
-
-         # In every iteration try to merge the two most likely bigrams. If none
-         # was merged we are done.
-         #
-         # Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
-         while unique_bigrams:
-             bigram = min(
-                 unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
-             )
-             if bigram not in self.bpe_ranks:
-                 break
-
-             new_unigrams = []
-             skip = False
-             for a, b in zip(unigrams, unigrams[1:]):
-                 if skip:
-                     skip = False
-                     continue
-
-                 if (a, b) == bigram:
-                     new_unigrams.append(a + b)
-                     skip = True
-
-                 else:
-                     new_unigrams.append(a)
-
-             if not skip:
-                 new_unigrams.append(b)
-
-             unigrams = new_unigrams
-             unique_bigrams = set(zip(unigrams, unigrams[1:]))
-
-         self._cache[text] = unigrams
-
-         return unigrams
-
-     def tokenize(self, text, prepend_bos=True, append_eos=True):
-         if isinstance(text, list):
-             return [self.tokenize(t, prepend_bos, append_eos) for t in text]
-
-         # Lower case cleanup and split according to self.pat. Hugging Face does
-         # a much more thorough job here but this should suffice for 95% of
-         # cases.
-         clean_text = regex.sub(r"\s+", " ", text.lower())
-         tokens = regex.findall(self.pat, clean_text)
-
-         # Split the tokens according to the byte-pair merge file
-         bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]
-
-         # Map to token ids and return
-         tokens = [self.vocab[t] for t in bpe_tokens]
-         if prepend_bos:
-             tokens = [self.bos_token] + tokens
-         if append_eos:
-             tokens.append(self.eos_token)
-
-         return tokens
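
Aside (not part of the diff): for orientation, the three removed modules compose into a text-to-image denoising loop roughly as sketched below. Only the loader functions, the sampler methods, and tokenize are taken from the hunks above; the import path and the text_encoder/unet/vae call signatures are assumptions about the companion clip.py, unet.py, and vae.py modules, which are not shown in this diff.

```python
# Hypothetical composition of the removed loaders, sampler, and tokenizer.
# Assumed: the modules are importable as `stable_diffusion`, and the
# text_encoder/unet/vae call signatures below (those files are not in this diff).
import mlx.core as mx
from stable_diffusion import model_io
from stable_diffusion.sampler import SimpleEulerSampler

tokenizer = model_io.load_tokenizer()
text_encoder = model_io.load_text_encoder()
unet = model_io.load_unet()
vae = model_io.load_autoencoder()
sampler = SimpleEulerSampler(model_io.load_diffusion_config())

tokens = mx.array([tokenizer.tokenize("a photo of an astronaut")])
cond = text_encoder(tokens)                 # assumed to return the conditioning

x_t = sampler.sample_prior((1, 64, 64, 4))  # NHWC latent for a 512x512 image
for t, t_prev in sampler.timesteps(50):
    eps_pred = unet(x_t, t, encoder_x=cond)  # assumed UNet signature
    x_t = sampler.step(eps_pred, x_t, t, t_prev)

image = vae.decode(x_t)                      # assumed VAE signature
```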