nexaai 1.0.19rc7-cp310-cp310-macosx_14_0_universal2.whl → 1.0.19rc9-cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of nexaai might be problematic.
Files changed (200)
  1. nexaai/_stub.cpython-310-darwin.so +0 -0
  2. nexaai/_version.py +1 -1
  3. nexaai/binds/libnexa_bridge.dylib +0 -0
  4. nexaai/mlx_backend/vlm/generate_qwen3_vl.py +14 -31
  5. nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +15 -32
  6. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +7 -23
  7. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +8 -24
  8. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/METADATA +1 -1
  9. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/RECORD +11 -200
  10. nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +0 -12
  11. nexaai/binds/nexa_mlx/py-lib/asr/interface.py +0 -122
  12. nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
  13. nexaai/binds/nexa_mlx/py-lib/common/utils.py +0 -25
  14. nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
  15. nexaai/binds/nexa_mlx/py-lib/cv/generate.py +0 -195
  16. nexaai/binds/nexa_mlx/py-lib/cv/interface.py +0 -151
  17. nexaai/binds/nexa_mlx/py-lib/cv/main.py +0 -81
  18. nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +0 -1736
  19. nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
  20. nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +0 -333
  21. nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +0 -617
  22. nexaai/binds/nexa_mlx/py-lib/embedding/main.py +0 -173
  23. nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
  24. nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +0 -399
  25. nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +0 -1
  26. nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +0 -244
  27. nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +0 -82
  28. nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +0 -281
  29. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +0 -306
  30. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +0 -116
  31. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +0 -65
  32. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +0 -386
  33. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +0 -105
  34. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +0 -100
  35. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +0 -460
  36. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +0 -274
  37. nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
  38. nexaai/binds/nexa_mlx/py-lib/llm/generate.py +0 -149
  39. nexaai/binds/nexa_mlx/py-lib/llm/interface.py +0 -764
  40. nexaai/binds/nexa_mlx/py-lib/llm/main.py +0 -68
  41. nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
  42. nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +0 -174
  43. nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +0 -287
  44. nexaai/binds/nexa_mlx/py-lib/rerank/main.py +0 -127
  45. nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
  46. nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +0 -330
  47. nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +0 -1
  48. nexaai/binds/nexa_mlx/py-lib/sd/interface.py +0 -362
  49. nexaai/binds/nexa_mlx/py-lib/sd/main.py +0 -286
  50. nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +0 -306
  51. nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +0 -116
  52. nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +0 -65
  53. nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +0 -385
  54. nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +0 -105
  55. nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +0 -100
  56. nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +0 -460
  57. nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +0 -274
  58. nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +0 -12
  59. nexaai/binds/nexa_mlx/py-lib/tts/interface.py +0 -276
  60. nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +0 -3
  61. nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +0 -572
  62. nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +0 -294
  63. nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +0 -276
  64. nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +0 -504
  65. nexaai/binds/nexa_mlx/py-lib/vlm/main.py +0 -320
  66. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
  67. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +0 -68
  68. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
  69. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +0 -8
  70. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +0 -193
  71. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +0 -186
  72. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +0 -233
  73. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +0 -503
  74. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +0 -202
  75. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +0 -230
  76. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +0 -10
  77. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +0 -264
  78. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +0 -472
  79. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +0 -591
  80. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +0 -526
  81. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +0 -356
  82. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +0 -8
  83. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +0 -366
  84. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +0 -488
  85. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +0 -591
  86. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +0 -8
  87. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +0 -213
  88. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +0 -315
  89. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +0 -238
  90. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +0 -2
  91. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +0 -1038
  92. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +0 -139
  93. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +0 -322
  94. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +0 -629
  95. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +0 -1022
  96. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +0 -9
  97. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +0 -294
  98. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +0 -191
  99. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +0 -267
  100. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +0 -8
  101. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +0 -175
  102. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +0 -192
  103. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +0 -233
  104. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +0 -9
  105. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +0 -140
  106. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +0 -220
  107. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +0 -393
  108. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +0 -293
  109. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +0 -307
  110. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +0 -8
  111. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +0 -143
  112. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +0 -509
  113. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +0 -522
  114. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +0 -8
  115. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +0 -386
  116. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +0 -138
  117. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +0 -560
  118. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +0 -8
  119. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +0 -240
  120. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +0 -153
  121. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +0 -259
  122. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +0 -9
  123. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +0 -236
  124. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +0 -256
  125. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +0 -303
  126. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +0 -8
  127. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +0 -230
  128. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +0 -160
  129. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +0 -243
  130. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +0 -8
  131. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +0 -283
  132. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +0 -8
  133. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +0 -416
  134. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +0 -172
  135. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +0 -499
  136. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +0 -8
  137. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +0 -243
  138. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +0 -133
  139. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +0 -465
  140. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +0 -10
  141. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +0 -230
  142. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +0 -385
  143. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +0 -557
  144. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +0 -526
  145. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +0 -8
  146. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +0 -282
  147. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +0 -160
  148. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +0 -242
  149. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +0 -8
  150. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +0 -21
  151. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +0 -243
  152. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +0 -71
  153. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +0 -324
  154. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +0 -8
  155. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +0 -229
  156. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +0 -161
  157. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +0 -320
  158. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +0 -2
  159. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +0 -108
  160. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +0 -490
  161. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +0 -168
  162. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +0 -414
  163. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +0 -2
  164. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +0 -104
  165. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +0 -490
  166. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +0 -167
  167. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +0 -312
  168. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  169. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +0 -117
  170. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +0 -531
  171. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +0 -701
  172. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +0 -255
  173. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +0 -303
  174. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +0 -407
  175. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +0 -476
  176. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +0 -1223
  177. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  178. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +0 -117
  179. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +0 -531
  180. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +0 -701
  181. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +0 -255
  182. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +0 -303
  183. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +0 -407
  184. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +0 -476
  185. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +0 -1309
  186. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +0 -210
  187. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +0 -8
  188. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +0 -62
  189. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +0 -209
  190. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +0 -215
  191. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +0 -474
  192. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +0 -39
  193. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +0 -344
  194. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +0 -9
  195. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +0 -70
  196. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +0 -296
  197. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +0 -160
  198. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +0 -928
  199. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/WHEEL +0 -0
  200. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/top_level.txt +0 -0
nexaai/binds/nexa_mlx/py-lib/vlm/generate.py (deleted, 572 lines; item 61 above)
@@ -1,572 +0,0 @@
- import argparse
- import codecs
- import contextlib
- import functools
- import os
- import re
- import time
- from dataclasses import dataclass
- from typing import Any, Dict, Generator, List, Optional, Tuple, Union
-
- import mlx.core as mx
- import mlx.nn as nn
- from mlx_lm.generate import maybe_quantize_kv_cache
- from transformers import PreTrainedTokenizer
-
- from .modeling.models import cache
- from .modeling.prompt_utils import apply_chat_template
- from .modeling.sample_utils import top_p_sampling
- from .modeling.utils import (
-     StoppingCriteria,
-     apply_repetition_penalty,
-     load,
-     prepare_inputs,
-     tree_reduce,
- )
-
- DEFAULT_MODEL_PATH = "mlx-community/gemma-3-4b-it-8bit"
-
-
- def parse_media_from_input(user_input):
-     """Parse quoted media files from user input and return prompt and media paths"""
-     # Find all quoted strings (both single and double quotes)
-     quoted_pattern = r'["\']([^"\']*)["\']'
-     quoted_matches = re.findall(quoted_pattern, user_input)
-
-     # Remove quoted strings from the input to get the actual prompt
-     prompt = re.sub(quoted_pattern, '', user_input).strip()
-
-     # Separate image and audio files based on extensions
-     image_extensions = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}
-     audio_extensions = {'.mp3', '.wav', '.flac', '.aac', '.ogg', '.m4a'}
-
-     image_paths = []
-     audio_paths = []
-
-     for quoted_file in quoted_matches:
-         if quoted_file:  # Skip empty quotes
-             # Expand user path if it starts with ~
-             if quoted_file.startswith('~'):
-                 quoted_file = os.path.expanduser(quoted_file)
-
-             # Check if file exists
-             if not os.path.exists(quoted_file):
-                 print(f"Warning: File '{quoted_file}' not found")
-                 continue
-
-             file_ext = os.path.splitext(quoted_file.lower())[1]
-             if file_ext in image_extensions:
-                 image_paths.append(quoted_file)
-             elif file_ext in audio_extensions:
-                 audio_paths.append(quoted_file)
-
-     return prompt, image_paths if image_paths else None, audio_paths if audio_paths else None
-
-
- def parse_arguments():
-     parser = argparse.ArgumentParser(
-         description="Generate text from an image using a model."
-     )
-     parser.add_argument(
-         "--model",
-         type=str,
-         default=DEFAULT_MODEL_PATH,
-         help="The path to the local model directory or Hugging Face repo.",
-     )
-     return parser.parse_args()
-
-
- # A stream on the default device just for generation
- generation_stream = mx.new_stream(mx.default_device())
-
-
- @contextlib.contextmanager
- def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
-     """
-     A context manager to temporarily change the wired limit.
-
-     Note: the wired limit should not be changed during an async eval. If an
-     async eval could be running, pass in the streams to synchronize with prior
-     to exiting the context manager.
-     """
-     model_bytes = tree_reduce(
-         lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
-     )
-     max_rec_size = mx.metal.device_info()["max_recommended_working_set_size"]
-     if model_bytes > 0.9 * max_rec_size:
-         model_mb = model_bytes // 2**20
-         max_rec_mb = max_rec_size // 2**20
-         print(
-             f"[WARNING] Generating with a model that requires {model_mb} MB "
-             f"which is close to the maximum recommended size of {max_rec_mb} "
-             "MB. This can be slow. See the documentation for possible work-arounds: "
-             "https://github.com/ml-explore/mlx-lm/tree/main#large-models"
-         )
-     old_limit = mx.set_wired_limit(max_rec_size)
-     try:
-         yield None
-     finally:
-         if streams is not None:
-             for s in streams:
-                 mx.synchronize(s)
-         else:
-             mx.synchronize()
-         mx.set_wired_limit(old_limit)
-
-
- @dataclass
- class GenerationResult:
-     text: str
-     token: Optional[int]
-     logprobs: Optional[List[float]]
-     prompt_tokens: int
-     generation_tokens: int
-     prompt_tps: float
-     generation_tps: float
-     peak_memory: float
-
-
- def generate_step(
-     input_ids: mx.array,
-     model: nn.Module,
-     pixel_values,
-     mask,
-     *,
-     max_tokens: int = 256,
-     temperature: float = 0.0,
-     repetition_penalty: Optional[float] = None,
-     repetition_context_size: Optional[int] = 20,
-     top_p: float = 1.0,
-     logit_bias: Optional[Dict[int, float]] = None,
-     prompt_cache: Optional[List[Any]] = None,
-     max_kv_size: Optional[int] = None,
-     kv_bits: Optional[int] = None,
-     kv_group_size: int = 64,
-     quantized_kv_start: int = 0,
-     **kwargs,
- ) -> Generator[Tuple[mx.array, mx.array], None, None]:
-     """
-     A generator producing token ids based on the given prompt from the model.
-
-     Args:
-         prompt (mx.array): The input prompt.
-         model (nn.Module): The model to use for generation.
-         temperature (float): The temperature for sampling, if 0 the argmax is used.
-           Default: ``0``.
-         repetition_penalty (float, optional): The penalty factor for repeating
-           tokens.
-         repetition_context_size (int, optional): The number of tokens to
-           consider for repetition penalty. Default: ``20``.
-         top_p (float, optional): Nucleus sampling, higher means the model considers
-           more less-likely words.
-         logit_bias (dictionary, optional): Additive logit bias.
-
-     Yields:
-         Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
-           one token and a vector of log probabilities.
-     """
-
-     quantize_cache_fn = functools.partial(
-         maybe_quantize_kv_cache,
-         quantized_kv_start=quantized_kv_start,
-         kv_group_size=kv_group_size,
-         kv_bits=kv_bits,
-     )
-
-     def sample(logits: mx.array) -> Tuple[mx.array, float]:
-         if logit_bias:
-             indices = mx.array(list(logit_bias.keys()))
-             values = mx.array(list(logit_bias.values()))
-             logits[:, indices] += values
-         logprobs = logits - mx.logsumexp(logits)
-
-         if temperature == 0:
-             token = mx.argmax(logits, axis=-1)
-         else:
-             if top_p > 0 and top_p < 1.0:
-                 token = top_p_sampling(logits, top_p, temperature)
-             else:
-                 token = mx.random.categorical(logits * (1 / temperature))
-
-         return token, logprobs
-
-     if repetition_penalty and (
-         repetition_penalty < 0 or not isinstance(repetition_penalty, float)
-     ):
-         raise ValueError(
-             f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
-         )
-
-     y = input_ids
-     # Create the KV cache for generation
-     if prompt_cache is None:
-         prompt_cache = cache.make_prompt_cache(
-             model.language_model,
-             max_kv_size=max_kv_size,
-         )
-
-     repetition_context = input_ids.reshape(-1).tolist()
-
-     if repetition_context_size:
-         repetition_context = repetition_context[-repetition_context_size:]
-
-     def _step(y, **kwargs):
-         with mx.stream(generation_stream):
-             nonlocal repetition_context
-             if "decoder_input_ids" in kwargs:
-                 outputs = model.language_model(
-                     cache=prompt_cache,
-                     **kwargs,
-                 )
-             else:
-                 outputs = model.language_model(
-                     y[None],
-                     cache=prompt_cache,
-                     **kwargs,
-                 )
-
-             logits = outputs.logits[:, -1, :]
-
-             if repetition_penalty:
-                 logits = apply_repetition_penalty(
-                     logits, repetition_context, repetition_penalty
-                 )
-                 y, logprobs = sample(logits)
-                 repetition_context.append(y.item())
-             else:
-                 y, logprobs = sample(logits)
-
-             if repetition_context_size:
-                 if len(repetition_context) > repetition_context_size:
-                     repetition_context = repetition_context[-repetition_context_size:]
-
-             quantize_cache_fn(prompt_cache)
-             return y, logprobs.squeeze(0)
-
-     outputs = model(input_ids, pixel_values, cache=prompt_cache, mask=mask, **kwargs)
-
-     logits = outputs.logits[:, -1, :]
-     quantize_cache_fn(prompt_cache)
-     y, logprobs = sample(logits)
-     mx.async_eval(y)
-
-     if outputs.cross_attention_states is not None:
-         kwargs = {
-             k: v
-             for k, v in zip(
-                 ["cross_attention_states"], [outputs.cross_attention_states]
-             )
-         }
-     elif outputs.encoder_outputs is not None:
-         kwargs = {
-             "decoder_input_ids": y[None],
-             "encoder_outputs": outputs.encoder_outputs,
-         }
-     else:
-         kwargs = {}
-
-     n = 0
-     while True:
-         if n != max_tokens:
-             next_y, next_logprobs = _step(y, **kwargs)
-             mx.async_eval(next_y)
-             if "decoder_input_ids" in kwargs:
-                 kwargs["decoder_input_ids"] = next_y[None]
-             yield y.item(), logprobs
-             y, logprobs = next_y, next_logprobs
-         if n == max_tokens:
-             break
-
-         n += 1
-
-         # Periodically clear cache to prevent memory accumulation
-         if n % 256 == 0:  # Clear cache every 256 tokens
-             mx.clear_cache()
-
-
- def stream_generate(
-     model: nn.Module,
-     processor: PreTrainedTokenizer,
-     prompt: str,
-     image: Union[str, List[str]] = None,
-     audio: Union[str, List[str]] = None,
-     **kwargs,
- ) -> Union[str, Generator[str, None, None]]:
-     """
-     A generator producing text based on the given prompt from the model.
-
-     Args:
-         prompt (mx.array): The input prompt.
-         model (nn.Module): The model to use for generation.
-         max_tokens (int): The maximum number of tokens to generate.
-         kwargs: The remaining options get passed to :func:`generate_step`.
-           See :func:`generate_step` for more details.
-
-     Yields:
-         Generator[Tuple[mx.array, mx.array]]: A generator producing text.
-     """
-     tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
-
-     # Skip special tokens
-     skip_special_tokens = kwargs.pop("skip_special_tokens", False)
-     skip_special_token_ids = (
-         set(tokenizer.all_special_ids)
-         if skip_special_tokens and hasattr(tokenizer, "all_special_ids")
-         else []
-     )
-
-     add_special_tokens = (
-         not hasattr(processor, "chat_template")
-         if model.config.model_type in ["gemma3", "gemma3n"]
-         else True
-     )
-
-     resize_shape = kwargs.pop("resize_shape", None)
-     image_token_index = getattr(model.config, "image_token_index", None)
-
-     if kwargs.get("input_ids", None) is not None:
-         input_ids = kwargs.pop("input_ids")
-         pixel_values = kwargs.pop("pixel_values", None)
-         mask = kwargs.pop("mask", None)
-     else:
-         inputs = prepare_inputs(
-             processor,
-             images=image,
-             audio=audio,
-             prompts=prompt,
-             image_token_index=image_token_index,
-             resize_shape=resize_shape,
-             add_special_tokens=add_special_tokens,
-         )
-         input_ids = inputs.get("input_ids", None)
-         pixel_values = inputs.get("pixel_values", None)
-         mask = inputs.get("attention_mask", None)
-         data_kwargs = {
-             k: v
-             for k, v in inputs.items()
-             if k not in ["input_ids", "pixel_values", "attention_mask"]
-         }
-         kwargs.update(data_kwargs)
-
-     with wired_limit(model, [generation_stream]):
-         detokenizer = processor.detokenizer
-         detokenizer.reset()
-         tic = time.perf_counter()
-         for n, (token, logprobs) in enumerate(
-             generate_step(input_ids, model, pixel_values, mask, **kwargs)
-         ):
-             if n == 0:
-                 prompt_time = time.perf_counter() - tic
-                 prompt_tps = input_ids.size / prompt_time
-                 tic = time.perf_counter()
-
-             # Stop generation if the token is in the eos_token_ids
-             if tokenizer.stopping_criteria(token):
-                 break
-
-             detokenizer.add_token(token, skip_special_token_ids=skip_special_token_ids)
-
-             # Yield the last segment if streaming
-             yield GenerationResult(
-                 text=detokenizer.last_segment,
-                 token=token,
-                 logprobs=logprobs,
-                 prompt_tokens=input_ids.size,
-                 generation_tokens=n + 1,
-                 prompt_tps=prompt_tps,
-                 generation_tps=(n + 1) / (time.perf_counter() - tic),
-                 peak_memory=mx.get_peak_memory() / 1e9,
-             )
-
-         detokenizer.finalize()
-         yield GenerationResult(
-             text=detokenizer.last_segment,
-             token=token,
-             logprobs=logprobs,
-             prompt_tokens=input_ids.size,
-             generation_tokens=n + 1,
-             prompt_tps=prompt_tps,
-             generation_tps=(n + 1) / (time.perf_counter() - tic),
-             peak_memory=mx.get_peak_memory() / 1e9,
-         )
-
-     # Cleanup after generation
-     mx.clear_cache()
-
-
- def generate(
-     model: nn.Module,
-     processor: PreTrainedTokenizer,
-     prompt: str,
-     image: Union[str, List[str]] = None,
-     audio: Union[str, List[str]] = None,
-     verbose: bool = False,
-     **kwargs,
- ) -> str:
-     """
-     Generate text from the model.
-
-     Args:
-         model (nn.Module): The language model.
-         tokenizer (PreTrainedTokenizer): The tokenizer.
-         prompt (str): The string prompt.
-         temperature (float): The temperature for sampling (default 0).
-         max_tokens (int): The maximum number of tokens (default 100).
-         verbose (bool): If ``True``, print tokens and timing information
-           (default ``False``).
-         formatter (Optional[Callable]): A function which takes a token and a
-           probability and displays it.
-         repetition_penalty (float, optional): The penalty factor for repeating tokens.
-         repetition_context_size (int, optional): The number of tokens to consider for repetition penalty.
-     """
-
-     if verbose:
-         print("=" * 10)
-         files = []
-         if image is not None:
-             files.extend(image)
-         if audio is not None:
-             files.extend(audio)
-         if kwargs.get("video") is not None:
-             files.extend(kwargs.get("video"))
-
-         print(f"Files: {files}", "\n")
-
-         print("Prompt:", prompt)
-
-     text = ""
-     last_response = None
-
-     eos_tokens = kwargs.get("eos_tokens", None)
-     stopping_criteria = kwargs.get("stopping_criteria", None)
-
-     # Get the tokenizer
-     tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
-
-     # Add custom EOS tokens to the stopping criteria
-     if eos_tokens is not None:
-         tokenizer.stopping_criteria.add_eos_token_ids(eos_tokens)
-
-     # Use custom stopping criteria
-     elif stopping_criteria is not None:
-         if isinstance(stopping_criteria, StoppingCriteria) or callable(
-             stopping_criteria
-         ):
-             tokenizer.stopping_criteria = stopping_criteria
-         else:
-             raise ValueError(
-                 "stopping_criteria must be an instance of StoppingCriteria or a callable"
-             )
-     else:
-         tokenizer.stopping_criteria.reset(model.config.eos_token_id)
-
-     for response in stream_generate(model, processor, prompt, image, audio, **kwargs):
-         if verbose:
-             print(response.text, end="", flush=True)
-         text += response.text
-         last_response = response
-
-     if verbose:
-         print("\n" + "=" * 10)
-         if len(text) == 0:
-             print("No text generated for this prompt")
-             return
-         print(
-             f"Prompt: {last_response.prompt_tokens} tokens, "
-             f"{last_response.prompt_tps:.3f} tokens-per-sec"
-         )
-         print(
-             f"Generation: {last_response.generation_tokens} tokens, "
-             f"{last_response.generation_tps:.3f} tokens-per-sec"
-         )
-         print(f"Peak memory: {last_response.peak_memory:.3f} GB")
-
-     usage_stats = {
-         "input_tokens": last_response.prompt_tokens,
-         "output_tokens": last_response.generation_tokens,
-         "total_tokens": last_response.prompt_tokens + last_response.generation_tokens,
-         "prompt_tps": last_response.prompt_tps,
-         "generation_tps": last_response.generation_tps,
-         "peak_memory": last_response.peak_memory,
-     }
-
-     return text, usage_stats
-
-
- def main():
-     args = parse_arguments()
-
-     # Load model and processor
-     model, processor = load(args.model, None)
-     config = model.config
-
-     # Initialize chat history
-     chat = []
-
-     print("Multi-round conversation started. Type 'exit' or 'quit' to stop.")
-     print("You can include image/audio files in quotes, e.g.: 'what does this image mean \"/path/to/image.jpg\"'")
-     print("=" * 50)
-
-     # Main chat loop
-     while True:
-         try:
-             user_input = input("User: ").strip()
-
-             # Exit conditions
-             if user_input.lower() in ['exit', 'quit', '']:
-                 break
-
-             # Parse media files from user input
-             prompt_text, image_paths, audio_paths = parse_media_from_input(user_input)
-
-             # If no text prompt after parsing, use the original input
-             if not prompt_text.strip():
-                 prompt_text = user_input
-                 image_paths = None
-                 audio_paths = None
-
-             # Add user message to chat history
-             chat.append({"role": "user", "content": prompt_text})
-
-             # Calculate number of images for chat template
-             num_images = len(image_paths) if image_paths else 0
-             num_audios = len(audio_paths) if audio_paths else 0
-
-             # Apply chat template
-             formatted_prompt = apply_chat_template(
-                 processor, config, chat, num_images=num_images, num_audios=num_audios
-             )
-
-             # Generate response
-             response = ""
-             print("Assistant: ", end="", flush=True)
-
-             for chunk in stream_generate(
-                 model,
-                 processor,
-                 formatted_prompt,
-                 image_paths,
-                 audio_paths,
-                 max_tokens=100,
-                 temperature=0.7,
-                 top_p=0.9,
-                 verbose=True,
-             ):
-                 response += chunk.text
-                 print(chunk.text, end="", flush=True)
-
-             print()  # New line after response
-
-             # Add assistant response to chat history
-             chat.append({"role": "assistant", "content": response})
-
-         except KeyboardInterrupt:
-             print("\nConversation interrupted by user.")
-             break
-         except Exception as e:
-             print(f"Error: {e}")
-             continue
-
-
- if __name__ == "__main__":
-     main()
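
For reference, the public entry points of the deleted module were load, stream_generate, and generate. The sketch below is reconstructed only from the signatures visible in the hunk above; the import path, model repo, and media file are placeholders and have not been verified against the released wheel.

# Hypothetical usage of the removed vlm/generate.py API; import path and
# file paths are placeholders inferred from the diff above.
from vlm.generate import generate, stream_generate
from vlm.modeling.utils import load

# load() returns a (model, processor) pair; the removed main() passed None
# as the second argument (an optional adapter path).
model, processor = load("mlx-community/gemma-3-4b-it-8bit", None)

# One-shot generation: returns the generated text plus a usage-stats dict
# (token counts, tokens-per-second, peak memory).
text, usage_stats = generate(
    model,
    processor,
    prompt="Describe this image.",
    image=["/path/to/image.jpg"],  # placeholder media file
    max_tokens=100,
    temperature=0.7,
)

# Streaming generation: yields GenerationResult chunks with incremental text.
for chunk in stream_generate(
    model, processor, "Describe this image.", image=["/path/to/image.jpg"]
):
    print(chunk.text, end="", flush=True)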