fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/evals/ocrbench.py ADDED
@@ -0,0 +1,453 @@
+ import argparse
+ import csv
+ import json
+ import logging
+ import random
+ import traceback
+ from pathlib import Path
+ from typing import Optional
+
+ import mlx.core as mx
+ from datasets import load_dataset
+ from tqdm import tqdm
+
+ from mlx_vlm import load
+ from mlx_vlm.evals.utils import inference
+ from mlx_vlm.generate import batch_generate
+ from mlx_vlm.sample_utils import top_p_sampling
+
+
+ def process_question(sample: dict) -> str:
+     """Format the question."""
+     return sample["question"]
+
+
+ def normalize_answer(response: str, problem: dict) -> Optional[str]:
+     """Normalize the model's response to extract the answer."""
+     if not response:
+         return None
+     return response.strip()
+
+
+ def evaluate_answer(prediction: Optional[str], ground_truth: list) -> bool:
+     """Check if any ground truth answer is contained in the prediction."""
+     if prediction is None:
+         return False
+     pred = prediction.strip().lower()
+     return any(str(a).strip().lower() in pred for a in ground_truth)
+
+
+ def OCRBench_val(results_list, args, model_name, dataset="OCRBench"):
+     correct = 0
+     total = len(results_list)
+     category_scores = {}
+     for row in results_list:
+         ground_truth = row["ground_truth"]
+         if isinstance(ground_truth, str):
+
+             ground_truth = [a.strip() for a in ground_truth.split(";")]
+         prediction = row["prediction"]
+
+         is_correct = evaluate_answer(prediction, ground_truth)
+         row["correct"] = is_correct
+         if is_correct:
+             correct += 1
+         category = row["type"]
+         if category not in category_scores:
+             category_scores[category] = {"correct": 0, "total": 0}
+         category_scores[category]["total"] += 1
+         if is_correct:
+             category_scores[category]["correct"] += 1
+
+     accuracy = correct / total if total > 0 else 0
+
+     output_dir = Path(args.output_dir)
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     results_file = output_dir / f"{model_name}_{dataset}_{args.split}.csv"
+
+     fieldnames = [
+         "id",
+         "question",
+         "dataset",
+         "type",
+         "ground_truth",
+         "response",
+         "prediction",
+         "correct",
+     ]
+
+     with open(results_file, "w", newline="", encoding="utf-8") as f:
+         writer = csv.DictWriter(f, fieldnames=fieldnames)
+         writer.writeheader()
+         for row in results_list:
+             out_row = row.copy()
+             if isinstance(out_row["ground_truth"], list):
+                 out_row["ground_truth"] = "; ".join(map(str, out_row["ground_truth"]))
+             writer.writerow(out_row)
+
+     summary = {
+         "model": model_name,
+         "dataset": args.dataset,
+         "split": args.split,
+         "total_samples": total,
+         "correct": correct,
+         "accuracy": accuracy,
+         "category_scores": category_scores,
+     }
+
+     summary_file = output_dir / f"{model_name}_{dataset}_{args.split}.json"
+     with open(summary_file, "w") as f:
+         json.dump(summary, f, indent=2)
+
+     print(f"\n{'='*80}")
+     print(f"{dataset} Evaluation Results")
+     print(f"{'='*80}")
+     print(f"Model: {summary['model']}")
+     print(f"Split: {args.split}")
+     print(f"Total Samples: {total}")
+     print(f"Correct: {correct}")
+     print(f"Accuracy: {accuracy*100:.2f}%")
+
+     if len(category_scores.items()) > 1:
+         print("\n" + "-" * 80)
+         print(f"Subcategory Scores:")
+         print(f"{'-'*80}")
+         for category, scores in category_scores.items():
+             cat_total = scores["total"]
+             cat_correct = scores["correct"]
+             cat_accuracy = cat_correct / cat_total if cat_total > 0 else 0
+             print(f" {category}: {cat_correct}/{cat_total} ({cat_accuracy*100:.2f}%)")
+     print(f"{'='*80}")
+     print(f"\nResults saved to {results_file} and {summary_file}")
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(
+         description="Evaluate models on OCRBench benchmark"
+     )
+     parser.add_argument(
+         "--model",
+         type=str,
+         required=True,
+         help="The path to the MLX VLM model",
+     )
+     parser.add_argument(
+         "--adapter-path",
+         type=str,
+         help="Optional path for the trained adapter weights and config",
+     )
+     parser.add_argument(
+         "--dataset",
+         type=str,
+         default="echo840/OCRBench",
+         help="Hugging Face dataset name",
+     )
+     parser.add_argument(
+         "--split",
+         type=str,
+         default="test",
+         choices=["test"],
+         help="Dataset split to evaluate on",
+     )
+     parser.add_argument(
+         "--streaming",
+         action="store_true",
+         help="Use streaming dataset loading",
+     )
+     parser.add_argument(
+         "--max-samples",
+         type=int,
+         default=None,
+         help="Maximum number of samples to evaluate (for debugging)",
+     )
+     parser.add_argument(
+         "--predictions-file",
+         type=str,
+         default=None,
+         help="File with predictions",
+     )
+     parser.add_argument(
+         "--output-dir",
+         type=str,
+         default="results/ocrbench",
+         help="Directory to save results",
+     )
+     parser.add_argument(
+         "--max-tokens",
+         type=int,
+         default=512,
+         help="Maximum number of tokens to generate",
+     )
+     parser.add_argument(
+         "--temperature",
+         type=float,
+         default=0.0,
+         help="Temperature for generation",
+     )
+     parser.add_argument(
+         "--verbose",
+         action="store_true",
+         help="Print detailed output for debugging",
+     )
+     parser.add_argument("--seed", type=int, default=42, help="Random seed")
+     parser.add_argument(
+         "--batch-size",
+         type=int,
+         default=1,
+         help="Batch size for generation (1 = sequential, >1 = batch generation)",
+     )
+     return parser.parse_args()
+
+
+ def create_sampler(temperature: float, top_p: float = 1.0):
+     """Create a sampler function for batch generation.
+
+     For accuracy consistency across batch sizes, we use deterministic sampling
+     (temperature=0) by default. This ensures the same outputs regardless of batch size.
+     """
+
+     def sampler(logits: mx.array) -> mx.array:
+         if temperature == 0:
+             return mx.argmax(logits, axis=-1)
+         else:
+             if top_p > 0 and top_p < 1.0:
+                 return top_p_sampling(logits, top_p, temperature)
+             else:
+                 return mx.random.categorical(logits * (1 / temperature))
+
+     return sampler
+
+
+ def process_batch(
+     model,
+     processor,
+     batch_samples,
+     args,
+ ):
+     """Process a batch of samples using batch_generate.
+
+     batch_generate now handles image size sorting internally to minimize
+     padding effects and maintain accuracy.
+     """
+     prompts = []
+     images = []
+     sample_metadata = []
+
+     for sample in batch_samples:
+         pid = sample.get("id", str(sample.get("_idx", 0)))
+
+         # Load and process image
+         if "image" in sample and sample["image"]:
+             image = sample["image"].convert("RGB")
+         else:
+             logging.warning(f"No image for sample {pid}, skipping")
+             continue
+
+         images.append(image)
+
+         # Create prompt
+         prompt = process_question(sample)
+         prompts.append(prompt)
+
+         # Store metadata for results
+         sample_metadata.append(
+             {
+                 "id": pid,
+                 "question": sample["question"],
+                 "dataset": sample.get("dataset", ""),
+                 "type": sample.get("type", ""),
+                 "ground_truth": (
+                     sample.get("answers", [])
+                     if hasattr(sample, "answers")
+                     else sample.get("answer", [])
+                 ),
+             }
+         )
+
+     if not prompts:
+         return []
+
+     # Create sampler for deterministic output (temperature=0 by default)
+     sampler = create_sampler(args.temperature)
+
+     # Use batch_generate for processing
+     # batch_generate now handles image size sorting internally to avoid padding issues
+     batch_response = batch_generate(
+         model,
+         processor,
+         images=images,
+         prompts=prompts,
+         max_tokens=args.max_tokens,
+         sampler=sampler,
+         verbose=args.verbose,
+     )
+
+     # Process results
+     results = []
+     for text, metadata in zip(batch_response.texts, sample_metadata):
+         response = text.strip()
+         prediction = normalize_answer(response, {"question": metadata["question"]})
+
+         result = {
+             **metadata,
+             "response": response,
+             "prediction": prediction,
+             "correct": False,
+         }
+         results.append(result)
+
+         if args.verbose:
+             logging.info(f"\nSample {metadata['id']}:")
+             logging.info(f"Question: {metadata['question']}")
+             logging.info(f"Response: {response}")
+             logging.info(f"Prediction: {prediction}")
+             logging.info(f"Ground Truth: {metadata['ground_truth']}")
+
+     return results
+
+
+ def main():
+     args = parse_args()
+
+     random.seed(args.seed)
+
+     # Setup logging
+     logging.basicConfig(
+         level=logging.INFO if args.verbose else logging.WARNING,
+         format="%(asctime)s - %(levelname)s - %(message)s",
+     )
+
+     if args.predictions_file:
+         logging.info(
+             f"\033[32mLoading predictions from {args.predictions_file} for evaluation\033[0m"
+         )
+         with open(args.predictions_file, "r", encoding="utf-8") as f:
+             reader = csv.DictReader(f)
+             loaded_results = list(reader)
+         model_name = Path(args.predictions_file).stem.split("_OCRBench")[0]
+         dataset = (
+             "OCRBench-v2" if "OCRBench-v2" in args.predictions_file else "OCRBench"
+         )
+         OCRBench_val(loaded_results, args, model_name, dataset)
+         logging.info(f"\033[32mEvaluation complete\033[0m")
+         return
+
+     logging.info(f"Loading model from {args.model}")
+     model, processor = load(
+         args.model, adapter_path=args.adapter_path, trust_remote_code=True
+     )
+
+     # Load dataset
+     logging.info(f"Loading dataset {args.dataset}, split {args.split}")
+     dataset = load_dataset(args.dataset, split=args.split, streaming=args.streaming)
+
+     if args.max_samples:
+         dataset = dataset.take(args.max_samples)
+
+     # Convert to list for batching if streaming
+     if args.streaming:
+         dataset = list(dataset)
+
+     results = {}
+     batch_size = args.batch_size
+
+     if batch_size > 1:
+         # Batch generation mode
+         logging.info(f"Using batch generation with batch_size={batch_size}")
+
+         # Collect samples into batches
+         batch = []
+         all_samples = list(dataset) if hasattr(dataset, "__iter__") else dataset
+
+         # Add index to samples for tracking
+         for idx, sample in enumerate(all_samples):
+             sample["_idx"] = idx
+
+         for idx, sample in enumerate(
+             tqdm(all_samples, desc=f"Evaluating (batch_size={batch_size})")
+         ):
+             batch.append(sample)
+
+             # Process batch when full or at the end
+             if len(batch) >= batch_size or idx == len(all_samples) - 1:
+                 try:
+                     batch_results = process_batch(model, processor, batch, args)
+                     for result in batch_results:
+                         results[result["id"]] = result
+                 except Exception as e:
+                     logging.error(f"Error processing batch: {e}")
+                     traceback.print_exc()
+
+                 batch = []
+
+                 # Clear memory after each batch
+                 mx.clear_cache()
+
+     else:
+         # Sequential generation mode (original behavior)
+         for idx, sample in enumerate(tqdm(dataset, desc="Evaluating")):
+             pid = sample.get("id", str(idx))
+
+             try:
+                 # Load and process image
+                 if "image" in sample and sample["image"]:
+                     image = sample["image"].convert("RGB")
+                 else:
+                     logging.warning(f"No image for sample {pid}, skipping")
+                     continue
+
+                 # Create prompt
+                 prompt = process_question(sample)
+
+                 # Generate response
+                 output = inference(
+                     model,
+                     processor,
+                     prompt,
+                     image=image,
+                     max_tokens=args.max_tokens,
+                     temperature=args.temperature,
+                 )
+
+                 response = output.strip()
+
+                 # Normalize answer
+                 prediction = normalize_answer(response, sample)
+
+                 # Store results (evaluation happens later)
+                 results[pid] = {
+                     "id": pid,
+                     "question": sample["question"],
+                     "dataset": sample.get("dataset", ""),
+                     "type": sample.get("type", ""),
+                     "ground_truth": (
+                         sample.get("answers", [])
+                         if hasattr(sample, "answers")
+                         else sample.get("answer", [])
+                     ),
+                     "response": response,
+                     "prediction": prediction,
+                     "correct": False,
+                 }
+
+                 if args.verbose:
+                     logging.info(f"\nSample {pid}:")
+                     logging.info(f"Question: {sample['question']}")
+                     logging.info(f"Response: {response}")
+                     logging.info(f"Prediction: {prediction}")
+                     logging.info(f"Ground Truth: {sample.get('answers', [])}")
+
+             except Exception as e:
+                 traceback.print_exc()
+                 logging.error(f"Error processing sample {pid}: {e}")
+                 continue
+
+     results_list = list(results.values())
+     model_name = args.model.split("/")[-1]
+     dataset = args.dataset.split("/")[-1]
+     OCRBench_val(results_list, args, model_name, dataset)
+
+
+ if __name__ == "__main__":
+     main()
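
For orientation only (this note is not part of the packaged file): the script above loads a model, prompts it with the raw OCRBench question, and scores a sample as correct when any ground-truth string appears, case-insensitively, in the response. A minimal single-sample sketch of that flow, using a placeholder model id and the dataset field names assumed by the code above:

    from datasets import load_dataset
    from mlx_vlm import load
    from mlx_vlm.evals.utils import inference

    # Placeholder repo id; substitute any MLX-format VLM available locally or on the Hub.
    model, processor = load("mlx-community/some-vlm-4bit")
    sample = next(iter(load_dataset("echo840/OCRBench", split="test", streaming=True)))

    response = inference(
        model, processor, sample["question"],
        image=sample["image"].convert("RGB"), max_tokens=512,
    )

    # Same scoring rule as evaluate_answer(): case-insensitive substring match.
    answers = sample.get("answer", [])
    if isinstance(answers, str):
        answers = [a.strip() for a in answers.split(";")]
    print(any(str(a).strip().lower() in response.strip().lower() for a in answers))
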
mlx_vlm/evals/utils.py ADDED
@@ -0,0 +1,37 @@
+ from mlx_vlm import generate
+ from mlx_vlm.prompt_utils import apply_chat_template
+
+
+ def inference(
+     model,
+     processor,
+     question,
+     image,
+     max_tokens=3000,
+     temperature=0.0,
+     resize_shape=None,
+     verbose=False,
+ ):
+     """Run inference on a single question."""
+     if image is None:
+         num_images = 0
+     elif isinstance(image, list):
+         num_images = len(image)
+     else:
+         num_images = 1
+
+     prompt = apply_chat_template(
+         processor, model.config, question, num_images=num_images
+     )
+
+     response = generate(
+         model,
+         processor,
+         prompt,
+         image=image,
+         max_tokens=max_tokens,
+         temperature=temperature,
+         resize_shape=resize_shape,
+         verbose=verbose,
+     )
+     return response.text
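
The helper above mostly exists to count images so that apply_chat_template can insert the matching number of image tokens; it accepts a single image, a list of images, or None for text-only prompts. A short usage sketch (not part of the package), with a placeholder model id and placeholder file paths:

    from PIL import Image
    from mlx_vlm import load
    from mlx_vlm.evals.utils import inference

    model, processor = load("mlx-community/some-vlm-4bit")  # placeholder repo id
    pages = [Image.open(p) for p in ("page1.png", "page2.png")]  # placeholder files

    # num_images is taken from the list length, so the chat template gets two image slots.
    print(inference(model, processor, "Transcribe the text on both pages.", image=pages))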