fount-vlm-nell-02 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/generate.py ADDED
@@ -0,0 +1,1457 @@
1
+ import argparse
2
+ import codecs
3
+ import contextlib
4
+ import functools
5
+ import json
6
+ import time
7
+ from dataclasses import dataclass
8
+ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
9
+
10
+ import mlx.core as mx
11
+ import mlx.nn as nn
12
+ from mlx.utils import tree_reduce
13
+ from mlx_lm.generate import maybe_quantize_kv_cache
14
+ from tqdm import tqdm
15
+ from transformers import PreTrainedTokenizer
16
+
17
+ from .models import cache
18
+ from .prompt_utils import apply_chat_template
19
+ from .sample_utils import top_p_sampling
20
+ from .utils import (
21
+ StoppingCriteria,
22
+ apply_repetition_penalty,
23
+ group_images_by_shape,
24
+ load,
25
+ prepare_inputs,
26
+ )
27
+
28
+ DEFAULT_MODEL_PATH = "mlx-community/nanoLLaVA-1.5-8bit"
29
+ DEFAULT_IMAGE = None
30
+ DEFAULT_AUDIO = None
31
+ DEFAULT_PROMPT = "What are these?"
32
+ DEFAULT_MAX_TOKENS = 256
33
+ DEFAULT_TEMPERATURE = 0.5
34
+ DEFAULT_TOP_P = 1.0
35
+ DEFAULT_SEED = 0
36
+ DEFAULT_QUANTIZED_KV_START = 5000
37
+
38
+
39
+ def parse_arguments():
40
+ parser = argparse.ArgumentParser(
41
+ description="Generate text from an image using a model."
42
+ )
43
+ parser.add_argument(
44
+ "--model",
45
+ type=str,
46
+ default=DEFAULT_MODEL_PATH,
47
+ help="The path to the local model directory or Hugging Face repo.",
48
+ )
49
+ parser.add_argument(
50
+ "--adapter-path",
51
+ type=str,
52
+ default=None,
53
+ help="The path to the adapter weights.",
54
+ )
55
+ parser.add_argument(
56
+ "--image",
57
+ type=str,
58
+ nargs="+",
59
+ default=DEFAULT_IMAGE,
60
+ help="URL or path of the image to process.",
61
+ )
62
+ parser.add_argument(
63
+ "--audio",
64
+ type=str,
65
+ nargs="+",
66
+ default=DEFAULT_AUDIO,
67
+ help="URL or path of the audio to process.",
68
+ )
69
+ parser.add_argument(
70
+ "--resize-shape",
71
+ type=int,
72
+ nargs="+",
73
+ default=None,
74
+ help="Resize shape for the image.",
75
+ )
76
+ parser.add_argument(
77
+ "--prompt",
78
+ type=str,
79
+ nargs="+",
80
+ default=DEFAULT_PROMPT,
81
+ help="Message to be processed by the model.",
82
+ )
83
+ parser.add_argument(
84
+ "--system",
85
+ type=str,
86
+ default=None,
87
+ help="System message for the model.",
88
+ )
89
+ parser.add_argument(
90
+ "--max-tokens",
91
+ type=int,
92
+ default=DEFAULT_MAX_TOKENS,
93
+ help="Maximum number of tokens to generate.",
94
+ )
95
+ parser.add_argument(
96
+ "--temperature",
97
+ type=float,
98
+ default=DEFAULT_TEMPERATURE,
99
+ help="Temperature for sampling.",
100
+ )
101
+ parser.add_argument("--chat", action="store_true", help="Chat in multi-turn style.")
102
+ parser.add_argument("--verbose", action="store_false", help="Detailed output.")
103
+ parser.add_argument(
104
+ "--eos-tokens",
105
+ type=str,
106
+ nargs="+",
107
+ default=None,
108
+ help="EOS tokens to add to the tokenizer.",
109
+ )
110
+ parser.add_argument(
111
+ "--max-kv-size",
112
+ type=int,
113
+ default=None,
114
+ help="Maximum KV size for the prompt cache.",
115
+ )
116
+ parser.add_argument(
117
+ "--kv-bits",
118
+ type=int,
119
+ default=None,
120
+ help="Number of bits to quantize the KV cache to.",
121
+ )
122
+ parser.add_argument(
123
+ "--kv-group-size",
124
+ type=int,
125
+ default=64,
126
+ help="Group size for the KV cache.",
127
+ )
128
+ parser.add_argument(
129
+ "--quantized-kv-start",
130
+ type=int,
131
+ default=DEFAULT_QUANTIZED_KV_START,
132
+ help="Start index for the quantized KV cache.",
133
+ )
134
+ parser.add_argument(
135
+ "--skip-special-tokens",
136
+ action="store_true",
137
+ help="Skip special tokens in the detokenizer.",
138
+ )
139
+ parser.add_argument(
140
+ "--force-download",
141
+ action="store_true",
142
+ help="Force download the model from Hugging Face.",
143
+ )
144
+ parser.add_argument(
145
+ "--revision",
146
+ type=str,
147
+ default="main",
148
+ help="The specific model version to use (branch, tag, commit).",
149
+ )
150
+ parser.add_argument(
151
+ "--trust-remote-code",
152
+ action="store_true",
153
+ help="Trust remote code when loading the model.",
154
+ )
155
+ parser.add_argument(
156
+ "--processor-kwargs",
157
+ type=json.loads,
158
+ default={},
159
+ help="Extra processor kwargs as JSON. "
160
+ 'Example: --processor-kwargs \'{"cropping": false, "max_patches": 3}\'',
161
+ )
162
+ parser.add_argument(
163
+ "--prefill-step-size",
164
+ type=int,
165
+ default=None,
166
+ help="Number of tokens to process per prefill step. "
167
+ "Lower values reduce peak memory usage but may be slower. "
168
+ "Try 512 or 256 if you hit GPU memory errors during prefill.",
169
+ )
170
+
171
+ return parser.parse_args()
172
+
173
+
174
+ # A stream on the default device just for generation
175
+ generation_stream = mx.new_stream(mx.default_device())
176
+
177
+
178
+ @contextlib.contextmanager
179
+ def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
180
+ """
181
+ A context manager to temporarily change the wired limit.
182
+
183
+ Note, the wired limit should not be changed during an async eval. If an
184
+ async eval could be running pass in the streams to synchronize with prior
185
+ to exiting the context manager.
186
+ """
187
+ if not mx.metal.is_available():
188
+ yield
189
+ return
190
+
191
+ model_bytes = tree_reduce(
192
+ lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
193
+ )
194
+ max_rec_size = mx.device_info()["max_recommended_working_set_size"]
195
+ if model_bytes > 0.9 * max_rec_size:
196
+ model_mb = model_bytes // 2**20
197
+ max_rec_mb = max_rec_size // 2**20
198
+ print(
199
+ f"[WARNING] Generating with a model that requires {model_mb} MB "
200
+ f"which is close to the maximum recommended size of {max_rec_mb} "
201
+ "MB. This can be slow. See the documentation for possible work-arounds: "
202
+ "https://github.com/ml-explore/mlx-lm/tree/main#large-models"
203
+ )
204
+ old_limit = mx.set_wired_limit(max_rec_size)
205
+ try:
206
+ yield
207
+ finally:
208
+ if streams is not None:
209
+ for s in streams:
210
+ mx.synchronize(s)
211
+ else:
212
+ mx.synchronize()
213
+ mx.set_wired_limit(old_limit)
214
+
215
+
216
+ @dataclass
217
+ class GenerationResult:
218
+ text: str = ""
219
+ token: Optional[int] = None
220
+ logprobs: Optional[List[float]] = None
221
+ prompt_tokens: int = 0
222
+ generation_tokens: int = 0
223
+ total_tokens: int = 0
224
+ prompt_tps: float = 0.0
225
+ generation_tps: float = 0.0
226
+ peak_memory: float = 0.0
227
+
228
+
229
+ def generate_step(
230
+ input_ids: mx.array,
231
+ model: nn.Module,
232
+ pixel_values,
233
+ mask,
234
+ *,
235
+ max_tokens: int = 256,
236
+ temperature: float = 0.0,
237
+ repetition_penalty: Optional[float] = None,
238
+ repetition_context_size: Optional[int] = 20,
239
+ top_p: float = 1.0,
240
+ logit_bias: Optional[Dict[int, float]] = None,
241
+ prompt_cache: Optional[List[Any]] = None,
242
+ max_kv_size: Optional[int] = None,
243
+ kv_bits: Optional[int] = None,
244
+ kv_group_size: int = 64,
245
+ quantized_kv_start: int = 0,
246
+ logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
247
+ prefill_step_size: Optional[int] = 2048,
248
+ **kwargs,
249
+ ) -> Generator[Tuple[mx.array, mx.array], None, None]:
250
+ """
251
+ A generator producing token ids based on the given prompt from the model.
252
+
253
+ Args:
254
+ input_ids (mx.array): The input prompt token ids.
255
+ model (nn.Module): The model to use for generation.
256
+ pixel_values: The pixel values for vision models (optional).
257
+ mask: The attention mask (optional).
258
+ max_tokens (int): Maximum number of tokens to generate. Default: ``256``.
259
+ temperature (float): The temperature for sampling, if 0 the argmax is used.
260
+ Default: ``0``.
261
+ repetition_penalty (float, optional): The penalty factor for repeating
262
+ tokens.
263
+ repetition_context_size (int, optional): The number of tokens to
264
+ consider for repetition penalty. Default: ``20``.
265
+ top_p (float, optional): Nucleus sampling, higher means model considers
266
+ more less likely words.
267
+ logit_bias (dictionary, optional): Additive logit bias.
268
+ prompt_cache (list, optional): Pre-existing KV cache for the prompt.
269
+ max_kv_size (int, optional): Maximum KV cache size.
270
+ kv_bits (int, optional): Number of bits for KV cache quantization.
271
+ kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
272
+ quantized_kv_start (int): Start index for quantized KV cache. Default: ``0``.
273
+ logits_processors (list, optional): List of logits processor functions.
274
+ prefill_step_size (int): Number of tokens to process per prefill step.
275
+ Chunked prefill processes prompts in smaller chunks to reduce peak
276
+ memory usage. Default: ``2048``.
277
+
278
+ Yields:
279
+ Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
280
+ one token and a vector of log probabilities.
281
+ """
282
+
283
+ quantize_cache_fn = functools.partial(
284
+ maybe_quantize_kv_cache,
285
+ quantized_kv_start=quantized_kv_start,
286
+ kv_group_size=kv_group_size,
287
+ kv_bits=kv_bits,
288
+ )
289
+
290
+ def sample(logits: mx.array) -> Tuple[mx.array, float]:
291
+ if logit_bias:
292
+ indices = mx.array(list(logit_bias.keys()))
293
+ values = mx.array(list(logit_bias.values()))
294
+ logits[:, indices] += values
295
+ logprobs = logits - mx.logsumexp(logits)
296
+
297
+ if temperature == 0:
298
+ token = mx.argmax(logits, axis=-1)
299
+ else:
300
+ if top_p > 0 and top_p < 1.0:
301
+ token = top_p_sampling(logits, top_p, temperature)
302
+ else:
303
+ token = mx.random.categorical(logits * (1 / temperature))
304
+
305
+ return token, logprobs
306
+
307
+ if repetition_penalty and (
308
+ repetition_penalty < 0 or not isinstance(repetition_penalty, float)
309
+ ):
310
+ raise ValueError(
311
+ f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
312
+ )
313
+
314
+ y = input_ids
315
+ tokens = None # Track tokens for logits processors
316
+
317
+ # Create the KV cache for generation
318
+ if prompt_cache is None:
319
+ prompt_cache = cache.make_prompt_cache(
320
+ model.language_model,
321
+ max_kv_size=max_kv_size,
322
+ )
323
+
324
+ repetition_context = input_ids.reshape(-1).tolist()
325
+
326
+ if repetition_context_size:
327
+ repetition_context = repetition_context[-repetition_context_size:]
328
+
329
+ def _step(y, inputs_embeds=None):
330
+ nonlocal tokens, repetition_context, kwargs
331
+ with mx.stream(generation_stream):
332
+ if "decoder_input_ids" in kwargs:
333
+ outputs = model.language_model(
334
+ cache=prompt_cache,
335
+ **kwargs,
336
+ )
337
+ else:
338
+ outputs = model.language_model(
339
+ y,
340
+ inputs_embeds=inputs_embeds,
341
+ cache=prompt_cache,
342
+ **kwargs,
343
+ )
344
+
345
+ logits = outputs.logits[:, -1, :]
346
+
347
+ # Apply logits processors before repetition penalty
348
+ if logits_processors:
349
+ # Efficiently update tokens by concatenating only the new token
350
+ tokens = mx.concat([tokens, y])
351
+ for processor in logits_processors:
352
+ logits = processor(tokens, logits)
353
+
354
+ if repetition_penalty:
355
+ logits = apply_repetition_penalty(
356
+ logits, repetition_context, repetition_penalty
357
+ )
358
+ y, logprobs = sample(logits)
359
+ repetition_context.append(y.item())
360
+ else:
361
+ y, logprobs = sample(logits)
362
+
363
+ if repetition_context_size:
364
+ if len(repetition_context) > repetition_context_size:
365
+ repetition_context = repetition_context[-repetition_context_size:]
366
+
367
+ quantize_cache_fn(prompt_cache)
368
+
369
+ if outputs.cross_attention_states is not None:
370
+ kwargs = {"cross_attention_states": outputs.cross_attention_states}
371
+ elif outputs.encoder_outputs is not None:
372
+ kwargs = {
373
+ "decoder_input_ids": y[None],
374
+ "encoder_outputs": outputs.encoder_outputs,
375
+ }
376
+ else:
377
+ kwargs = {}
378
+
379
+ return y, logprobs.squeeze(0)
380
+
381
+ with mx.stream(generation_stream):
382
+
383
+ # Get input embeddings (handles both multimodal and text-only)
384
+ embedding_output = model.get_input_embeddings(
385
+ input_ids, pixel_values, mask=mask, **kwargs
386
+ )
387
+
388
+ inputs_embeds = embedding_output.inputs_embeds
389
+
390
+ kwargs.update(
391
+ {
392
+ k: v
393
+ for k, v in embedding_output.to_dict().items()
394
+ if k != "inputs_embeds" and v is not None
395
+ }
396
+ )
397
+ if prefill_step_size is not None and inputs_embeds.shape[1] > prefill_step_size:
398
+ # Chunked prefill with embeddings
399
+ total_tokens = inputs_embeds.shape[1]
400
+ with tqdm(total=total_tokens, desc="Prefill", unit="tok") as pbar:
401
+ while inputs_embeds.shape[1] > 1:
402
+ n_to_process = min(prefill_step_size, inputs_embeds.shape[1] - 1)
403
+ model.language_model(
404
+ inputs=input_ids[:, :n_to_process],
405
+ inputs_embeds=inputs_embeds[:, :n_to_process],
406
+ cache=prompt_cache,
407
+ **kwargs,
408
+ )
409
+ quantize_cache_fn(prompt_cache)
410
+ mx.eval([c.state for c in prompt_cache])
411
+ inputs_embeds = inputs_embeds[:, n_to_process:]
412
+ input_ids = input_ids[:, n_to_process:]
413
+ mx.clear_cache()
414
+ pbar.update(n_to_process)
415
+
416
+ input_ids = input_ids[:, -1:]
417
+
418
+ y, logprobs = _step(input_ids, inputs_embeds=inputs_embeds)
419
+
420
+ mx.async_eval(y)
421
+
422
+ n = 0
423
+ while True:
424
+ if n != max_tokens:
425
+ next_y, next_logprobs = _step(y[None])
426
+ mx.async_eval(next_y)
427
+ if n == 0:
428
+ mx.eval(y)
429
+ if n == max_tokens:
430
+ break
431
+
432
+ yield y.item(), logprobs
433
+ if n % 256 == 0:
434
+ mx.clear_cache()
435
+ y, logprobs = next_y, next_logprobs
436
+ n += 1
437
+
438
+
439
+ def stream_generate(
440
+ model: nn.Module,
441
+ processor: PreTrainedTokenizer,
442
+ prompt: str,
443
+ image: Union[str, List[str]] = None,
444
+ audio: Union[str, List[str]] = None,
445
+ **kwargs,
446
+ ) -> Union[str, Generator[str, None, None]]:
447
+ """
448
+ A generator producing text based on the given prompt from the model.
449
+
450
+ Args:
451
+ model (nn.Module): The model to use for generation.
452
+ processor (PreTrainedTokenizer): The tokenizer/processor.
453
+ prompt (str): The input prompt text.
454
+ image (Union[str, List[str]], optional): Image path(s) or URL(s).
455
+ audio (Union[str, List[str]], optional): Audio file path(s).
456
+ prefill_step_size (int, optional): Number of tokens to process per prefill
457
+ step. When set, enables chunked prefill which processes long prompts in
458
+ smaller chunks to reduce peak memory usage.
459
+ kwargs: Additional options passed to :func:`generate_step`.
460
+ See :func:`generate_step` for more details.
461
+
462
+ Yields:
463
+ Generator[GenerationResult]: A generator producing GenerationResult objects
464
+ containing the generated text, tokens, and statistics.
465
+ """
466
+ tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
467
+
468
+ # Skip special tokens
469
+ skip_special_tokens = kwargs.pop("skip_special_tokens", False)
470
+ skip_special_token_ids = (
471
+ set(tokenizer.all_special_ids)
472
+ if skip_special_tokens and hasattr(tokenizer, "all_special_ids")
473
+ else []
474
+ )
475
+
476
+ add_special_tokens = (
477
+ not hasattr(processor, "chat_template")
478
+ if model.config.model_type in ["gemma3", "gemma3n"]
479
+ else True
480
+ )
481
+
482
+ resize_shape = kwargs.pop("resize_shape", None)
483
+ image_token_index = getattr(model.config, "image_token_index", None)
484
+
485
+ if kwargs.get("input_ids", None) is not None:
486
+ input_ids = kwargs.pop("input_ids")
487
+ pixel_values = kwargs.pop("pixel_values", None)
488
+ mask = kwargs.pop("mask", None)
489
+ else:
490
+ inputs = prepare_inputs(
491
+ processor,
492
+ images=image,
493
+ audio=audio,
494
+ prompts=prompt,
495
+ image_token_index=image_token_index,
496
+ resize_shape=resize_shape,
497
+ add_special_tokens=add_special_tokens,
498
+ **kwargs,
499
+ )
500
+ input_ids = inputs.get("input_ids", None)
501
+ pixel_values = inputs.get("pixel_values", None)
502
+ mask = inputs.get("attention_mask", None)
503
+ data_kwargs = {
504
+ k: v
505
+ for k, v in inputs.items()
506
+ if k not in ["input_ids", "pixel_values", "attention_mask"]
507
+ }
508
+ kwargs.update(data_kwargs)
509
+
510
+ with wired_limit(model, [generation_stream]):
511
+ detokenizer = processor.detokenizer
512
+ detokenizer.reset()
513
+ tic = time.perf_counter()
514
+
515
+ # #region agent log
516
+ import json
517
+ log_file = "/Users/zekieldee/Desktop/code/mlx-vlm/.cursor/debug.log"
518
+ def log_debug(location, message, data, hypothesis_id):
519
+ try:
520
+ with open(log_file, "a") as f:
521
+ f.write(json.dumps({"sessionId": "debug-session", "runId": "generation", "hypothesisId": hypothesis_id, "location": location, "message": message, "data": data, "timestamp": __import__("time").time_ns() // 1000000}) + "\n")
522
+ except: pass
523
+
524
+ log_debug("generate.py:stream_generate_start", "Tokenizer and model info", {
525
+ "model_type": model.config.model_type if hasattr(model.config, "model_type") else "unknown",
526
+ "tokenizer_class": tokenizer.__class__.__name__,
527
+ "vocab_size": tokenizer.vocab_size if hasattr(tokenizer, "vocab_size") else "unknown",
528
+ "eos_token_id": tokenizer.eos_token_id if hasattr(tokenizer, "eos_token_id") else "unknown",
529
+ "bos_token_id": tokenizer.bos_token_id if hasattr(tokenizer, "bos_token_id") else "unknown",
530
+ "pad_token_id": tokenizer.pad_token_id if hasattr(tokenizer, "pad_token_id") else "unknown",
531
+ }, "H2,H3,H4")
532
+ # #endregion
533
+
534
+ try:
535
+ for n, (token, logprobs) in enumerate(
536
+ generate_step(input_ids, model, pixel_values, mask, **kwargs)
537
+ ):
538
+ if n == 0:
539
+ prompt_time = time.perf_counter() - tic
540
+ prompt_tps = input_ids.size / prompt_time
541
+ tic = time.perf_counter()
542
+
543
+ # #region agent log
544
+ top5_indices = mx.argsort(logprobs)[-5:].tolist()
545
+ top5_values = mx.sort(logprobs)[-5:].tolist()
546
+ log_debug("generate.py:first_token", "First token generated", {
547
+ "token_id": int(token),
548
+ "token_str": tokenizer.decode([token]) if hasattr(tokenizer, "decode") else "N/A",
549
+ "logprobs_shape": str(logprobs.shape),
550
+ "logprobs_top5_indices": top5_indices,
551
+ "logprobs_top5_values": top5_values,
552
+ }, "H2,H4")
553
+ # #endregion
554
+
555
+ # Stop generation if the token is in the eos_token_ids
556
+ if tokenizer.stopping_criteria(token):
557
+ # #region agent log
558
+ log_debug("generate.py:eos_detected", "EOS token detected", {"token_id": int(token), "iteration": n}, "H4")
559
+ # #endregion
560
+ break
561
+
562
+ # #region agent log
563
+ if n < 5: # Log first 5 tokens
564
+ decoded_token = tokenizer.decode([token]) if hasattr(tokenizer, "decode") else "N/A"
565
+ log_debug("generate.py:token_decode", f"Token {n} decode", {
566
+ "iteration": n,
567
+ "token_id": int(token),
568
+ "decoded_single": decoded_token,
569
+ "detokenizer_segment": detokenizer.text if hasattr(detokenizer, "text") else "N/A",
570
+ }, "H2,H4")
571
+ # #endregion
572
+
573
+ detokenizer.add_token(
574
+ token, skip_special_token_ids=skip_special_token_ids
575
+ )
576
+
577
+ # Yield the last segment if streaming
578
+ yield GenerationResult(
579
+ text=detokenizer.last_segment,
580
+ token=token,
581
+ logprobs=logprobs,
582
+ prompt_tokens=input_ids.size,
583
+ generation_tokens=n + 1,
584
+ total_tokens=input_ids.size + n + 1,
585
+ prompt_tps=prompt_tps,
586
+ generation_tps=(n + 1) / (time.perf_counter() - tic),
587
+ peak_memory=mx.get_peak_memory() / 1e9,
588
+ )
589
+
590
+ detokenizer.finalize()
591
+
592
+ yield GenerationResult(
593
+ text=detokenizer.last_segment,
594
+ token=token,
595
+ logprobs=logprobs,
596
+ prompt_tokens=input_ids.size,
597
+ generation_tokens=n + 1,
598
+ total_tokens=input_ids.size + n + 1,
599
+ prompt_tps=prompt_tps,
600
+ generation_tps=(n + 1) / (time.perf_counter() - tic),
601
+ peak_memory=mx.get_peak_memory() / 1e9,
602
+ )
603
+ except Exception as e:
604
+ raise
605
+
606
+ # Cleanup after generation
607
+ mx.clear_cache()
608
+
609
+
610
+ def generate(
611
+ model: nn.Module,
612
+ processor: PreTrainedTokenizer,
613
+ prompt: str,
614
+ image: Union[str, List[str]] = None,
615
+ audio: Union[str, List[str]] = None,
616
+ verbose: bool = False,
617
+ **kwargs,
618
+ ) -> GenerationResult:
619
+ """
620
+ Generate text from the model.
621
+
622
+ Args:
623
+ model (nn.Module): The language model.
624
+ tokenizer (PreTrainedTokenizer): The tokenizer.
625
+ prompt (str): The string prompt.
626
+ temperature (float): The temperature for sampling (default 0).
627
+ max_tokens (int): The maximum number of tokens (default 100).
628
+ verbose (bool): If ``True``, print tokens and timing information
629
+ (default ``False``).
630
+ formatter (Optional[Callable]): A function which takes a token and a
631
+ probability and displays it.
632
+ repetition_penalty (float, optional): The penalty factor for repeating tokens.
633
+ repetition_context_size (int, optional): The number of tokens to consider for repetition penalty.
634
+ """
635
+
636
+ if verbose:
637
+ print("=" * 10)
638
+ files = []
639
+ if image is not None:
640
+ files.extend(image)
641
+ if audio is not None:
642
+ files.extend(audio)
643
+ if kwargs.get("video") is not None:
644
+ files.extend(kwargs.get("video"))
645
+
646
+ print(f"Files: {files}", "\n")
647
+
648
+ print("Prompt:", prompt)
649
+
650
+ text = ""
651
+ last_response = None
652
+
653
+ eos_tokens = kwargs.get("eos_tokens", None)
654
+ stopping_criteria = kwargs.get("stopping_criteria", None)
655
+
656
+ # Get the tokenizer
657
+ tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
658
+
659
+ # Add custom EOS tokens to the stopping criteria
660
+ if eos_tokens is not None:
661
+ tokenizer.stopping_criteria.add_eos_token_ids(eos_tokens)
662
+
663
+ # Use custom stopping criteria
664
+ elif stopping_criteria is not None:
665
+ if isinstance(stopping_criteria, StoppingCriteria) or callable(
666
+ stopping_criteria
667
+ ):
668
+ tokenizer.stopping_criteria = stopping_criteria
669
+ else:
670
+ raise ValueError(
671
+ "stopping_criteria must be an instance of StoppingCriteria or a callable"
672
+ )
673
+ else:
674
+ tokenizer.stopping_criteria.reset(model.config.eos_token_id)
675
+
676
+ for response in stream_generate(model, processor, prompt, image, audio, **kwargs):
677
+ if verbose:
678
+ print(response.text, end="", flush=True)
679
+ text += response.text
680
+ last_response = response
681
+
682
+ if verbose:
683
+ print("\n" + "=" * 10)
684
+ if len(text) == 0:
685
+ print("No text generated for this prompt")
686
+ return GenerationResult(
687
+ text=text,
688
+ token=None,
689
+ logprobs=None,
690
+ prompt_tokens=0,
691
+ generation_tokens=0,
692
+ total_tokens=0,
693
+ prompt_tps=0.0,
694
+ generation_tps=0.0,
695
+ peak_memory=mx.get_peak_memory() / 1e9,
696
+ )
697
+ print(
698
+ f"Prompt: {last_response.prompt_tokens} tokens, "
699
+ f"{last_response.prompt_tps:.3f} tokens-per-sec"
700
+ )
701
+ print(
702
+ f"Generation: {last_response.generation_tokens} tokens, "
703
+ f"{last_response.generation_tps:.3f} tokens-per-sec"
704
+ )
705
+ print(f"Peak memory: {last_response.peak_memory:.3f} GB")
706
+
707
+ return GenerationResult(
708
+ text=text,
709
+ token=last_response.token,
710
+ logprobs=last_response.logprobs,
711
+ prompt_tokens=last_response.prompt_tokens,
712
+ generation_tokens=last_response.generation_tokens,
713
+ total_tokens=last_response.total_tokens,
714
+ prompt_tps=last_response.prompt_tps,
715
+ generation_tps=last_response.generation_tps,
716
+ peak_memory=last_response.peak_memory,
717
+ )
718
+
719
+
720
+ @dataclass
721
+ class BatchGenerationResult:
722
+ """
723
+ Result of batch generation with optional image size tracking.
724
+
725
+ Attributes:
726
+ texts: Generated text for each sample
727
+ tokens: Last generated token for each sample
728
+ logprobs: Log probabilities for each sample
729
+ prompt_tokens: Number of prompt tokens per sample
730
+ generation_tokens: Number of generated tokens per sample
731
+ total_tokens: Total tokens (prompt + generation) per sample
732
+ prompt_tps: Prompt tokens per second per sample
733
+ generation_tps: Generation tokens per second per sample
734
+ peak_memory: Peak memory usage in GB
735
+ image_sizes: Original (height, width) for each image (for tracking)
736
+ """
737
+
738
+ texts: List[str]
739
+ tokens: List[Optional[int]]
740
+ logprobs: List[Optional[List[float]]]
741
+ prompt_tokens: List[int]
742
+ generation_tokens: List[int]
743
+ total_tokens: List[int]
744
+ prompt_tps: List[float]
745
+ generation_tps: List[float]
746
+ peak_memory: float = 0.0
747
+ image_sizes: Optional[List[Tuple[int, int]]] = None
748
+
749
+
750
+ def _left_pad_prompts(prompts, max_length=None):
751
+ if max_length is None:
752
+ max_length = max(len(p) for p in prompts)
753
+
754
+ return mx.array([[0] * (max_length - len(p)) + p for p in prompts])
755
+
756
+
757
+ def _make_cache(model, left_padding):
758
+ """
759
+ Convert a list of regular caches into their corresponding
760
+ batch-aware caches.
761
+ """
762
+
763
+ def to_batch_cache(c):
764
+ if isinstance(c, cache.KVCache):
765
+ return cache.BatchKVCache(left_padding)
766
+ elif isinstance(c, cache.ArraysCache):
767
+ c.left_padding = mx.array(left_padding)
768
+ return c
769
+ elif isinstance(c, cache.RotatingKVCache):
770
+ if c.keep > 0:
771
+ raise ValueError("RotatingKVCache with keep tokens is not supported.")
772
+ return cache.BatchRotatingKVCache(c.max_size, left_padding)
773
+ elif isinstance(c, cache.CacheList):
774
+ return cache.BatchCacheList(*(to_batch_cache(sub_c) for sub_c in c.caches))
775
+ else:
776
+ raise ValueError(f"{type(c)} does not yet support batching")
777
+
778
+ if hasattr(model, "make_cache"):
779
+ model_cache = model.make_cache()
780
+ return [to_batch_cache(c) for c in model_cache]
781
+ else:
782
+ return [cache.BatchKVCache(left_padding) for _ in model.layers]
783
+
784
+
785
+ @dataclass
786
+ class BatchStats:
787
+ """
788
+ A data object to hold generation stats.
789
+
790
+ Args:
791
+ prompt_tokens (int): The number of prompt tokens processed.
792
+ prompt_tps (float): The prompt processing tokens-per-second.
793
+ prompt_time (float): The time in seconds spent in prompt processing.
794
+ generation_tokens (int): The number of generated tokens.
795
+ generation_tps (float): The tokens-per-second for generation.
796
+ generation_time (float): The time in seconds spent in generation.
797
+ peak_memory (float): The peak memory used so far in GB.
798
+ """
799
+
800
+ prompt_tokens: int = 0
801
+ prompt_tps: float = 0
802
+ prompt_time: float = 0
803
+ generation_tokens: int = 0
804
+ generation_tps: float = 0
805
+ generation_time: float = 0
806
+ peak_memory: float = 0
807
+
808
+
809
+ @dataclass
810
+ class BatchResponse:
811
+ """
812
+ A data object to hold a batch generation response.
813
+
814
+ Args:
815
+ texts: (List[str]): The generated text for each prompt.
816
+ stats (BatchStats): Statistics about the generation.
817
+ image_sizes: (Optional[List[Tuple[int, int]]]): Original (height, width)
818
+ for each image. Useful for tracking which images produced which responses
819
+ and for debugging padding/batching behavior.
820
+ """
821
+
822
+ texts: List[str]
823
+ stats: BatchStats
824
+ image_sizes: Optional[List[Tuple[int, int]]] = None
825
+
826
+
827
+ @dataclass
828
+ class Batch:
829
+ uids: List[int]
830
+ y: mx.array
831
+ logprobs: mx.array
832
+ max_tokens: List[int]
833
+ num_tokens: List[int]
834
+ cache: List[Any]
835
+
836
+ def __len__(self):
837
+ return len(self.uids)
838
+
839
+ def filter(self, keep_idx: List[int]):
840
+ self.uids = [self.uids[k] for k in keep_idx]
841
+ self.max_tokens = [self.max_tokens[k] for k in keep_idx]
842
+ self.num_tokens = [self.num_tokens[k] for k in keep_idx]
843
+ keep_idx = mx.array(keep_idx, mx.int32)
844
+ self.y = self.y[keep_idx]
845
+ self.logprobs = self.logprobs[keep_idx]
846
+ for c in self.cache:
847
+ c.filter(keep_idx)
848
+
849
+ def extend(self, other):
850
+ self.uids.extend(other.uids)
851
+ self.y = mx.concatenate([self.y, other.y])
852
+ self.logprobs = mx.concatenate([self.logprobs, other.logprobs])
853
+ self.num_tokens.extend(other.num_tokens)
854
+ self.max_tokens.extend(other.max_tokens)
855
+ for c, o in zip(self.cache, other.cache):
856
+ c.extend(o)
857
+
858
+
859
+ class BatchGenerator:
860
+
861
+ @dataclass
862
+ class Response:
863
+ uid: int
864
+ token: int
865
+ logprobs: mx.array
866
+ finish_reason: Optional[str]
867
+
868
+ def __init__(
869
+ self,
870
+ model,
871
+ processor,
872
+ max_tokens: int = 128,
873
+ stop_tokens: Optional[set] = None,
874
+ sampler: Optional[Callable[[mx.array], mx.array]] = None,
875
+ completion_batch_size: int = 32,
876
+ prefill_batch_size: int = 8,
877
+ prefill_step_size: int = 2048,
878
+ prompt_cache=None,
879
+ ):
880
+ self.model = model
881
+ self.unprocessed_prompts = []
882
+ self.max_tokens = max_tokens
883
+ self.processor = processor
884
+ self.tokenizer = (
885
+ processor.tokenizer if hasattr(processor, "tokenizer") else processor
886
+ )
887
+ self.sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
888
+ self.uid_count = 0
889
+ self.prefill_step_size = prefill_step_size
890
+ self.prefill_batch_size = prefill_batch_size
891
+ self.completion_batch_size = completion_batch_size
892
+ self.prompt_cache = prompt_cache
893
+ self._stats = BatchStats()
894
+
895
+ self.tokenizer.stopping_criteria.add_eos_token_ids(stop_tokens)
896
+
897
+ self.active_batch = None
898
+
899
+ def insert(self, prompts, max_tokens: Union[List[int], int, None] = None):
900
+ uids = []
901
+
902
+ if max_tokens is None or isinstance(max_tokens, int):
903
+ max_tokens = [max_tokens or self.max_tokens] * len(prompts)
904
+
905
+ for p, m in zip(prompts, max_tokens):
906
+ self.unprocessed_prompts.append((self.uid_count, p, m))
907
+ uids.append(self.uid_count)
908
+ self.uid_count += 1
909
+ # Sort in ascending order of length
910
+ self.unprocessed_prompts = sorted(
911
+ self.unprocessed_prompts, key=lambda x: len(x[1])
912
+ )
913
+ return uids
914
+
915
+ def _process_prompts(self, prompts, **kwargs) -> Batch:
916
+ uids, inputs, max_tokens = zip(*prompts)
917
+ lengths = [len(p) for p in inputs]
918
+ max_length = max(lengths)
919
+
920
+ self._stats.prompt_tokens += sum(lengths)
921
+ left_padding = [max_length - l for l in lengths]
922
+ inputs = _left_pad_prompts(inputs, max_length=max_length)
923
+
924
+ prompt_cache = (
925
+ _make_cache(self.model, left_padding)
926
+ if self.prompt_cache is None
927
+ else self.prompt_cache
928
+ )
929
+
930
+ # Slice batch data in kwargs to match current batch size
931
+ batch_size = len(uids)
932
+ for key, value in kwargs.items():
933
+ if isinstance(value, mx.array) and value.ndim > 0:
934
+ kwargs[key] = value[:batch_size]
935
+
936
+ inputs_embeds = kwargs.pop("inputs_embeds", None)
937
+
938
+ if inputs_embeds is not None:
939
+ # Multimodal prefill
940
+ while inputs_embeds.shape[1] > 1:
941
+ n_to_process = min(self.prefill_step_size, inputs_embeds.shape[1] - 1)
942
+ self.model(
943
+ inputs[:, :n_to_process],
944
+ cache=prompt_cache,
945
+ inputs_embeds=inputs_embeds[:, :n_to_process],
946
+ n_to_process=n_to_process,
947
+ **kwargs,
948
+ )
949
+ mx.eval([c.state for c in prompt_cache])
950
+ inputs_embeds = inputs_embeds[:, n_to_process:]
951
+ inputs = inputs[:, n_to_process:]
952
+ mx.clear_cache()
953
+
954
+ kwargs = {"inputs_embeds": inputs_embeds}
955
+
956
+ else:
957
+ # Text-only prefill
958
+ while inputs.shape[1] > 1 and inputs_embeds is None:
959
+ n_to_process = min(self.prefill_step_size, inputs.shape[1] - 1)
960
+ self.model(inputs[:, :n_to_process], cache=prompt_cache)
961
+ mx.eval([c.state for c in prompt_cache])
962
+ inputs = inputs[:, n_to_process:]
963
+ mx.clear_cache()
964
+
965
+ y, logprobs = self._step(inputs, prompt_cache, **kwargs)
966
+ mx.async_eval(y, logprobs)
967
+ mx.clear_cache()
968
+ return Batch(
969
+ list(uids), y, logprobs, list(max_tokens), [0] * len(uids), prompt_cache
970
+ )
971
+
972
+ def _step(self, input_tokens: mx.array, prompt_cache: List[Any], **kwargs):
973
+ output = self.model(input_tokens, cache=prompt_cache, **kwargs)
974
+ logits = output.logits[:, -1, :]
975
+ logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
976
+ sampled = self.sampler(logprobs)
977
+
978
+ # TODO: Add KV cache quantization if specified
979
+ return sampled, logprobs
980
+
981
+ def stats(self):
982
+ self._stats.prompt_tps = self._stats.prompt_tokens / self._stats.prompt_time
983
+ self._stats.generation_tps = (
984
+ self._stats.generation_tokens / self._stats.generation_time
985
+ )
986
+ self._stats.peak_memory = mx.get_peak_memory() / 1e9
987
+ return self._stats
988
+
989
+ def _next(self, **kwargs):
990
+ tic = time.perf_counter()
991
+
992
+ prompt_processing = False
993
+ batch = self.active_batch
994
+ num_active = len(batch) if batch else 0
995
+ num_to_add = self.completion_batch_size - num_active
996
+ while num_to_add >= self.prefill_batch_size:
997
+ prompts = self.unprocessed_prompts[: self.prefill_batch_size]
998
+ # Finish processing the last examples of the last batch
999
+ if len(prompts) == 0 and num_active > 0:
1000
+ break
1001
+ # No more prompts and no more completions, all done
1002
+ elif len(prompts) == 0:
1003
+ self.active_batch = None
1004
+ return []
1005
+ # Process prompts
1006
+ if batch is not None and not prompt_processing:
1007
+ # Finish any active completion tokens
1008
+ mx.eval(batch.y, batch.logprobs)
1009
+ self._stats.generation_time += time.perf_counter() - tic
1010
+ tic = time.perf_counter()
1011
+
1012
+ batch = self._process_prompts(prompts, **kwargs)
1013
+ self.unprocessed_prompts = self.unprocessed_prompts[
1014
+ self.prefill_batch_size :
1015
+ ]
1016
+ prompt_processing = True
1017
+ # If there was no active batch, set it
1018
+ if self.active_batch is None:
1019
+ self.active_batch = batch
1020
+ else:
1021
+ self.active_batch.extend(batch)
1022
+
1023
+ num_active = len(self.active_batch)
1024
+ num_to_add -= len(batch)
1025
+
1026
+ batch = self.active_batch
1027
+ y, logprobs = batch.y, batch.logprobs
1028
+ batch.y, batch.logprobs = self._step(y[:, None], batch.cache)
1029
+ mx.async_eval(batch.y, batch.logprobs)
1030
+
1031
+ y = y.tolist()
1032
+ toc = time.perf_counter()
1033
+ if prompt_processing:
1034
+ self._stats.prompt_time += toc - tic
1035
+ else:
1036
+ self._stats.generation_time += toc - tic
1037
+ keep_idx = []
1038
+ end_idx = []
1039
+ responses = []
1040
+
1041
+ for e, (t, uid, num_tok, max_tok) in enumerate(
1042
+ zip(y, batch.uids, batch.num_tokens, batch.max_tokens)
1043
+ ):
1044
+ num_tok += 1
1045
+ batch.num_tokens[e] = num_tok
1046
+ if self.tokenizer.stopping_criteria(t):
1047
+ finish_reason = "stop"
1048
+ end_idx.append(e)
1049
+ elif num_tok >= max_tok:
1050
+ finish_reason = "length"
1051
+ end_idx.append(e)
1052
+ else:
1053
+ finish_reason = None
1054
+ keep_idx.append(e)
1055
+ responses.append(self.Response(uid, t, logprobs[e], finish_reason))
1056
+
1057
+ # Remove any finished completions
1058
+ if len(end_idx):
1059
+ if len(keep_idx) > 0:
1060
+ batch.filter(keep_idx)
1061
+ else:
1062
+ self.active_batch = None
1063
+
1064
+ self._stats.generation_tokens += len(responses)
1065
+
1066
+ if len(responses) > 0 and self._stats.generation_tokens % 100 == 0:
1067
+ mx.clear_cache()
1068
+
1069
+ return responses
1070
+
1071
+ def next(self, **kwargs):
1072
+ with mx.stream(generation_stream):
1073
+ return self._next(**kwargs)
1074
+
1075
+
1076
+ def batch_generate(
1077
+ model,
1078
+ processor,
1079
+ images: Union[str, List[str]] = None,
1080
+ audios: Union[str, List[str]] = None,
1081
+ prompts: List[str] = None,
1082
+ max_tokens: Union[int, List[int]] = 128,
1083
+ verbose: bool = False,
1084
+ group_by_shape: bool = True,
1085
+ track_image_sizes: bool = True,
1086
+ **kwargs,
1087
+ ):
1088
+ """
1089
+ Generate responses for the given batch of prompts with variable-sized images.
1090
+
1091
+ This function implements the transformers-style approach to batching:
1092
+ 1. Group images with the same shape for efficient batch processing
1093
+ 2. Process each group as a batch (no padding waste within groups)
1094
+ 3. Track original image sizes for proper attention masking
1095
+ 4. Restore results to original batch order
1096
+
1097
+ Key insight: Instead of padding all images to the same spatial dimensions
1098
+ (which wastes computation and may hurt accuracy), we group same-sized
1099
+ images together so there's zero padding within each group.
1100
+
1101
+ Args:
1102
+ model (nn.Module): The language model.
1103
+ processor (PreTrainedTokenizer): The tokenizer/processor.
1104
+ images (Union[str, List[str]]): Images (paths, URLs, or PIL images).
1105
+ audios (Union[str, List[str]]): Audio files (not yet supported for batching).
1106
+ prompts (List[str]): The input prompts.
1107
+ max_tokens (Union[int, List[int]]): Maximum number of output tokens. This
1108
+ can be per prompt if a list is provided.
1109
+ verbose (bool): If ``True``, print tokens and timing information.
1110
+ Default: ``False``.
1111
+ group_by_shape (bool): If ``True``, group same-shaped images for efficient
1112
+ batch processing. Default: ``True``.
1113
+ track_image_sizes (bool): If ``True``, track and return original image sizes.
1114
+ Default: ``True``.
1115
+ kwargs: The remaining options get passed to :obj:`BatchGenerator`.
1116
+ See :obj:`BatchGenerator` for more details.
1117
+
1118
+ Returns:
1119
+ BatchResponse with generated texts, statistics, and optionally image_sizes.
1120
+ """
1121
+ from PIL import Image
1122
+
1123
+ from .utils import process_image
1124
+
1125
+ processor.detokenizer.reset()
1126
+ tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
1127
+
1128
+ # Handle single image case
1129
+ if isinstance(images, str):
1130
+ images = [images]
1131
+
1132
+ # Handle no images case
1133
+ if images is None:
1134
+ texts, stats = _generate_batch(
1135
+ model, processor, prompts, None, max_tokens, verbose, **kwargs
1136
+ )
1137
+ return BatchResponse(texts, stats)
1138
+
1139
+ # Load and preprocess images
1140
+ image_processor = (
1141
+ processor.image_processor if hasattr(processor, "image_processor") else None
1142
+ )
1143
+
1144
+ processed_images = []
1145
+ image_sizes_original = []
1146
+ for img in images:
1147
+ if isinstance(img, str):
1148
+ pil_img = process_image(img, None, image_processor)
1149
+ elif isinstance(img, Image.Image):
1150
+ pil_img = img
1151
+ else:
1152
+ pil_img = img
1153
+ processed_images.append(pil_img)
1154
+ # Track original size
1155
+ if hasattr(pil_img, "height"):
1156
+ image_sizes_original.append((pil_img.height, pil_img.width))
1157
+ else:
1158
+ image_sizes_original.append((0, 0))
1159
+
1160
+ # Group images by shape for efficient processing (no padding within groups)
1161
+ if group_by_shape and len(processed_images) > 1:
1162
+ grouped_images, grouped_indices = group_images_by_shape(processed_images)
1163
+
1164
+ if verbose:
1165
+ print(f"[batch_generate] Found {len(grouped_images)} unique image shapes")
1166
+ else:
1167
+ # Single image or grouping disabled - treat as one group
1168
+ shape = (
1169
+ (processed_images[0].height, processed_images[0].width)
1170
+ if processed_images
1171
+ else (0, 0)
1172
+ )
1173
+ grouped_images = {shape: processed_images}
1174
+ grouped_indices = {shape: list(range(len(processed_images)))}
1175
+
1176
+ # Process each shape group
1177
+ all_texts = [None] * len(prompts)
1178
+ all_image_sizes = [None] * len(prompts)
1179
+ total_stats = BatchStats()
1180
+
1181
+ for shape, indices in grouped_indices.items():
1182
+ # Get images and prompts for this shape group
1183
+ group_images = [processed_images[i] for i in indices]
1184
+ group_prompts = [prompts[i] for i in indices]
1185
+ group_sizes = [image_sizes_original[i] for i in indices]
1186
+
1187
+ # Handle per-sample max_tokens
1188
+ if isinstance(max_tokens, list):
1189
+ group_max_tokens = [max_tokens[i] for i in indices]
1190
+ else:
1191
+ group_max_tokens = max_tokens
1192
+
1193
+ # Process the entire group at once (same shape = no padding needed)
1194
+ chunk_texts, chunk_stats = _generate_batch(
1195
+ model,
1196
+ processor,
1197
+ group_prompts,
1198
+ group_images,
1199
+ group_max_tokens,
1200
+ **kwargs,
1201
+ )
1202
+
1203
+ # Store results in original order
1204
+ for j, orig_idx in enumerate(indices):
1205
+ all_texts[orig_idx] = chunk_texts[j]
1206
+ all_image_sizes[orig_idx] = group_sizes[j]
1207
+
1208
+ # Accumulate stats
1209
+ total_stats.prompt_tokens += chunk_stats.prompt_tokens
1210
+ total_stats.prompt_time += chunk_stats.prompt_time
1211
+ total_stats.generation_tokens += chunk_stats.generation_tokens
1212
+ total_stats.generation_time += chunk_stats.generation_time
1213
+
1214
+ mx.clear_cache()
1215
+
1216
+ # Compute final stats
1217
+ if total_stats.prompt_time > 0:
1218
+ total_stats.prompt_tps = total_stats.prompt_tokens / total_stats.prompt_time
1219
+ if total_stats.generation_time > 0:
1220
+ total_stats.generation_tps = (
1221
+ total_stats.generation_tokens / total_stats.generation_time
1222
+ )
1223
+ total_stats.peak_memory = mx.get_peak_memory() / 1e9
1224
+
1225
+ if verbose:
1226
+ print(f"[batch_generate] Finished processing {len(prompts)} samples")
1227
+ print(
1228
+ f"[batch_generate] Prompt: {total_stats.prompt_tokens} tokens, {total_stats.prompt_tps:.3f} tokens-per-sec"
1229
+ )
1230
+ print(
1231
+ f"[batch_generate] Generation: {total_stats.generation_tokens} tokens, "
1232
+ f"{total_stats.generation_tps:.3f} tokens-per-sec"
1233
+ )
1234
+ print(f"[batch_generate] Peak memory: {total_stats.peak_memory:.3f} GB")
1235
+
1236
+ response = BatchResponse(all_texts, total_stats)
1237
+ if track_image_sizes:
1238
+ response.image_sizes = all_image_sizes
1239
+ return response
1240
+
1241
+
1242
+ def _generate_batch(
1243
+ model,
1244
+ processor,
1245
+ prompts: List[str],
1246
+ images: List = None,
1247
+ max_tokens: Union[int, List[int]] = 100,
1248
+ verbose: bool = False,
1249
+ **kwargs,
1250
+ ) -> Tuple[List[str], BatchStats]:
1251
+
1252
+ tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
1253
+ batch_size = len(prompts)
1254
+
1255
+ num_images_list = [
1256
+ 1 if i < (len(images) if images is not None else 0) else 0
1257
+ for i in range(len(prompts))
1258
+ ]
1259
+ formatted_prompts = [
1260
+ apply_chat_template(
1261
+ processor,
1262
+ model.config,
1263
+ p,
1264
+ num_images=num_images_list[i],
1265
+ )
1266
+ for i, p in enumerate(prompts)
1267
+ ]
1268
+
1269
+ add_special_tokens = (
1270
+ not hasattr(processor, "chat_template")
1271
+ if model.config.model_type in ["gemma3", "gemma3n"]
1272
+ else True
1273
+ )
1274
+
1275
+ resize_shape = kwargs.pop("resize_shape", None)
1276
+ image_token_index = getattr(model.config, "image_token_index", None)
1277
+
1278
+ inputs = prepare_inputs(
1279
+ processor,
1280
+ images=images,
1281
+ audio=None,
1282
+ prompts=formatted_prompts,
1283
+ image_token_index=image_token_index,
1284
+ resize_shape=resize_shape,
1285
+ add_special_tokens=add_special_tokens,
1286
+ pad_to_uniform_size=False, # Since images are pre-grouped by shape, they're already uniform size
1287
+ )
1288
+ input_ids = inputs.get("input_ids", None)
1289
+ pixel_values = inputs.get("pixel_values", None)
1290
+
1291
+ data_kwargs = {
1292
+ k: v
1293
+ for k, v in inputs.items()
1294
+ if k not in ["input_ids", "pixel_values", "attention_mask"]
1295
+ }
1296
+
1297
+ # Use batch_size for prefill and completion to ensure consistent processing
1298
+ gen = BatchGenerator(
1299
+ model.language_model,
1300
+ processor,
1301
+ prefill_batch_size=batch_size,
1302
+ completion_batch_size=batch_size,
1303
+ **kwargs,
1304
+ )
1305
+
1306
+ with wired_limit(model, [generation_stream]):
1307
+ if pixel_values is not None:
1308
+ embedding_output = model.get_input_embeddings(
1309
+ input_ids, pixel_values, **data_kwargs
1310
+ )
1311
+
1312
+ # Normalize embedding output to a kwargs dict expected by BatchGenerator
1313
+ if isinstance(embedding_output, dict):
1314
+ embed_kwargs = embedding_output
1315
+ elif hasattr(embedding_output, "to_dict"):
1316
+ # Convert to dict and keep non-None fields
1317
+ embed_kwargs = {
1318
+ k: v for k, v in embedding_output.to_dict().items() if v is not None
1319
+ }
1320
+ else:
1321
+ # Assume it's directly an inputs_embeds array
1322
+ embed_kwargs = {"inputs_embeds": embedding_output}
1323
+
1324
+ gen_kwargs = {
1325
+ "pixel_values": pixel_values,
1326
+ **data_kwargs,
1327
+ **embed_kwargs,
1328
+ }
1329
+ else:
1330
+ input_ids = mx.squeeze(input_ids, axis=0)
1331
+ gen_kwargs = {}
1332
+
1333
+ uids = gen.insert(input_ids.tolist(), max_tokens)
1334
+ results = {uid: [] for uid in uids}
1335
+ while responses := gen.next(**gen_kwargs):
1336
+ for r in responses:
1337
+ if r.finish_reason != "stop":
1338
+ results[r.uid].append(r.token)
1339
+
1340
+ texts = [tokenizer.decode(results[uid]) for uid in uids]
1341
+ return texts, gen.stats()
1342
+
1343
+
1344
+ def main():
1345
+ args = parse_arguments()
1346
+ if isinstance(args.image, str):
1347
+ args.image = [args.image]
1348
+
1349
+ model, processor = load(
1350
+ args.model,
1351
+ args.adapter_path,
1352
+ revision=args.revision,
1353
+ trust_remote_code=args.trust_remote_code,
1354
+ )
1355
+ config = model.config
1356
+
1357
+ prompt = args.prompt
1358
+
1359
+ num_images = len(args.image) if args.image is not None else 0
1360
+ num_audios = (
1361
+ 1 if args.audio is not None else 0
1362
+ ) # TODO: Support multiple audio files
1363
+ prompt = apply_chat_template(
1364
+ processor, config, prompt, num_images=num_images, num_audios=num_audios
1365
+ )
1366
+
1367
+ kwargs = {}
1368
+
1369
+ if args.resize_shape is not None:
1370
+ if len(args.resize_shape) not in [1, 2]:
1371
+ raise ValueError("Resize shape must be 1 or 2 integers")
1372
+ kwargs["resize_shape"] = (
1373
+ (args.resize_shape[0],) * 2
1374
+ if len(args.resize_shape) == 1
1375
+ else tuple(args.resize_shape)
1376
+ )
1377
+
1378
+ if args.eos_tokens is not None:
1379
+ eos_tokens = []
1380
+ for token in args.eos_tokens:
1381
+ try:
1382
+ decoded_token = codecs.decode(token, "unicode_escape")
1383
+ eos_tokens.append(decoded_token)
1384
+ except (UnicodeDecodeError, UnicodeError):
1385
+ eos_tokens.append(token)
1386
+ kwargs["eos_tokens"] = eos_tokens
1387
+
1388
+ if args.skip_special_tokens:
1389
+ kwargs["skip_special_tokens"] = args.skip_special_tokens
1390
+
1391
+ # Add processor kwargs from JSON
1392
+ if args.processor_kwargs:
1393
+ kwargs.update(args.processor_kwargs)
1394
+
1395
+ if args.chat:
1396
+ chat = []
1397
+ if args.system:
1398
+ chat.append({"role": "system", "content": args.system})
1399
+ while user := input("User:"):
1400
+ chat.append({"role": "user", "content": user})
1401
+ prompt = apply_chat_template(processor, config, chat, num_images=num_images)
1402
+ response = ""
1403
+ print("Assistant:", end="")
1404
+ stream_kwargs = {
1405
+ "max_tokens": args.max_tokens,
1406
+ "temperature": args.temperature,
1407
+ **kwargs,
1408
+ }
1409
+ if args.prefill_step_size is not None:
1410
+ stream_kwargs["prefill_step_size"] = args.prefill_step_size
1411
+
1412
+ for chunk in stream_generate(
1413
+ model,
1414
+ processor,
1415
+ prompt,
1416
+ args.image,
1417
+ args.audio,
1418
+ **stream_kwargs,
1419
+ ):
1420
+ response += chunk.text
1421
+ print(chunk.text, end="")
1422
+
1423
+ chat.append({"role": "assistant", "content": response})
1424
+ print()
1425
+
1426
+ else:
1427
+ gen_kwargs = {
1428
+ "image": args.image,
1429
+ "audio": args.audio,
1430
+ "temperature": args.temperature,
1431
+ "max_tokens": args.max_tokens,
1432
+ "verbose": args.verbose,
1433
+ "max_kv_size": args.max_kv_size,
1434
+ "kv_bits": args.kv_bits,
1435
+ "kv_group_size": args.kv_group_size,
1436
+ "quantized_kv_start": args.quantized_kv_start,
1437
+ **kwargs,
1438
+ }
1439
+ if args.prefill_step_size is not None:
1440
+ gen_kwargs["prefill_step_size"] = args.prefill_step_size
1441
+
1442
+ result = generate(
1443
+ model,
1444
+ processor,
1445
+ prompt,
1446
+ **gen_kwargs,
1447
+ )
1448
+ if not args.verbose:
1449
+ print(result.text)
1450
+
1451
+
1452
+ if __name__ == "__main__":
1453
+ print(
1454
+ "Calling `python -m mlx_vlm.generate ...` directly is deprecated."
1455
+ " Use `mlx_vlm generate` or `python -m mlx_vlm generate` instead."
1456
+ )
1457
+ main()
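
A minimal usage sketch of the generate.py module added in this release, based only on the signatures visible in the listing above (load from mlx_vlm.utils, apply_chat_template from mlx_vlm.prompt_utils, and generate from mlx_vlm.generate). The image filename is a placeholder; the model path and prompt are the defaults defined at the top of the file, and whether the package re-exports these names from mlx_vlm/__init__.py is not confirmed by this listing.

    from mlx_vlm.generate import generate
    from mlx_vlm.prompt_utils import apply_chat_template
    from mlx_vlm.utils import load

    # Load model and processor; the second argument mirrors main()'s adapter_path=None.
    model, processor = load("mlx-community/nanoLLaVA-1.5-8bit", None)

    # Format the prompt with the model's chat template for a single image.
    prompt = apply_chat_template(processor, model.config, "What are these?", num_images=1)

    # example.jpg is a placeholder image path; temperature and max_tokens match the CLI defaults.
    result = generate(
        model,
        processor,
        prompt,
        image=["example.jpg"],
        max_tokens=256,
        temperature=0.5,
        verbose=True,
    )
    print(result.text)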