fount-vlm-nell-02 0.3.11 (py3-none-any.whl)

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
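
The bulk of the added files are model ports under mlx_vlm/models/. The evaluation scripts under mlx_vlm/evals/ (items 13-18 above) load a model with mlx_vlm.load and query it through the shared mlx_vlm.evals.utils.inference helper; the MathVista script reproduced below does exactly that. A minimal sketch of that load/inference pattern, mirroring the calls made in the script below; the model identifier, image path, and prompt are placeholders, not values taken from the package:

# Sketch of the load/inference pattern used by the eval scripts.
# The model identifier, image path, and prompt below are placeholders.
from PIL import Image

from mlx_vlm import load
from mlx_vlm.evals.utils import inference

model, processor = load(
    "mlx-community/Qwen2-VL-2B-Instruct-4bit", trust_remote_code=True
)
image = Image.open("example.png").convert("RGB")  # placeholder image

answer = inference(
    model,
    processor,
    "What is the value of the largest bar in the chart?",
    image=image,
    max_tokens=512,
    temperature=0.0,
)
print(answer)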
mlx_vlm/evals/math_vista.py
@@ -0,0 +1,565 @@
+ import argparse
+ import csv
+ import json
+ import logging
+ import os
+ import random
+ import re
+ from pathlib import Path
+ from typing import Optional
+
+ from datasets import load_dataset
+ from PIL import Image
+ from tqdm import tqdm
+
+ from mlx_vlm import load
+ from mlx_vlm.evals.utils import inference
+
+
+ def process_question(sample: dict) -> str:
+     """Format the question with choices if it's multiple choice."""
+     question = sample["query"]
+
+     if sample["question_type"] == "multi_choice" and sample["choices"]:
+         choices_text = "\n".join(
+             [f"({chr(65+i)}) {choice}" for i, choice in enumerate(sample["choices"])]
+         )
+         question = f"{question}\n{choices_text}"
+     return question
+
+
+ def normalize_answer(response: str, problem: dict) -> Optional[str]:
+     """Normalize the model's response to extract the answer."""
+     response = response.strip()
+
+     if not response:
+         return None
+
+     question_type = problem["question_type"]
+     answer_type = problem["answer_type"]
+     choices = problem.get("choices", [])
+
+     # For multiple choice, try to extract the letter
+     if question_type == "multi_choice":
+         # First, try to find boxed answers
+         boxed_match = re.search(r"\\boxed\{([^}]+)\}", response)
+         if boxed_match:
+             boxed_content = boxed_match.group(1)
+             # Check if it's a choice letter
+             letter_match = re.match(
+                 r"^\(?([A-Z])\)?\.?$", boxed_content.strip().upper()
+             )
+             if letter_match:
+                 letter = letter_match.group(1)
+                 idx = ord(letter) - ord("A")
+                 if 0 <= idx < len(choices):
+                     return choices[idx]
+             # Check if it's directly one of the choices
+             if boxed_content.strip() in choices:
+                 return boxed_content.strip()
+
+         # Try to find Chinese answer pattern "故选:X" or "故选X"
+         chinese_match = re.search(r"故选[::]\s*([A-Z])", response.upper())
+         if not chinese_match:
+             chinese_match = re.search(r"故选\s*([A-Z])", response.upper())
+         if chinese_match:
+             letter = chinese_match.group(1)
+             idx = ord(letter) - ord("A")
+             if 0 <= idx < len(choices):
+                 return choices[idx]
+
+         # Try to find "the answer is X" or "answer: X" patterns near the end
+         answer_patterns = [
+             r"(?:the\s+)?answer\s+is\s+\(?([A-Z])\)?",
+             r"answer:\s*\(?([A-Z])\)?",
+             r"choose\s+\(?([A-Z])\)?",
+             r"option\s+\(?([A-Z])\)?",
+         ]
+
+         # Search from the end of the response (last 500 chars)
+         end_section = response[-500:] if len(response) > 500 else response
+         for pattern in answer_patterns:
+             matches = list(re.finditer(pattern, end_section, re.IGNORECASE))
+             if matches:
+                 # Take the last match
+                 letter = matches[-1].group(1).upper()
+                 idx = ord(letter) - ord("A")
+                 if 0 <= idx < len(choices):
+                     return choices[idx]
+
+         # Look for patterns like "(A)", "A)", "A.", "A" - prioritize from the end
+         matches = list(re.finditer(r"\(?([A-Z])\)?\.?", response.upper()))
+         if matches:
+             # Try the last few matches first
+             for match in reversed(matches[-5:]):
+                 letter = match.group(1)
+                 idx = ord(letter) - ord("A")
+                 if 0 <= idx < len(choices):
+                     return choices[idx]
+
+         # If the response is exactly one of the choices
+         if response in choices:
+             return response
+
+         # Try to find the most similar choice using edit distance
+         def distance(s1, s2):
+             if len(s1) < len(s2):
+                 return distance(s2, s1)
+             if len(s2) == 0:
+                 return len(s1)
+
+             previous_row = range(len(s2) + 1)
+             for i, c1 in enumerate(s1):
+                 current_row = [i + 1]
+                 for j, c2 in enumerate(s2):
+                     insertions = previous_row[j + 1] + 1
+                     deletions = current_row[j] + 1
+                     substitutions = previous_row[j] + (c1 != c2)
+                     current_row.append(min(insertions, deletions, substitutions))
+                 previous_row = current_row
+
+             return previous_row[-1]
+
+         if choices:
+             distances = [
+                 distance(response.lower(), choice.lower()) for choice in choices
+             ]
+             return choices[distances.index(min(distances))]
+
+     # For integer answers
+     elif answer_type == "integer":
+         # First try to find boxed answer
+         boxed_match = re.search(r"\\boxed\{([^}]+)\}", response)
+         if boxed_match:
+             boxed_content = boxed_match.group(1)
+             # Remove commas from numbers
+             boxed_content = boxed_content.replace(",", "")
+             # Try scientific notation first
+             sci_numbers = re.findall(r"-?\d+\.?\d*[eE][+-]?\d+", boxed_content)
+             if sci_numbers:
+                 try:
+                     return str(int(float(sci_numbers[0])))
+                 except:
+                     pass
+             # Then regular numbers
+             numbers = re.findall(r"-?\d+", boxed_content)
+             if numbers:
+                 try:
+                     return str(int(numbers[0]))
+                 except:
+                     pass
+
+         # Try common answer patterns near the end
+         end_section = response[-500:] if len(response) > 500 else response
+         answer_patterns = [
+             r"(?:the\s+)?answer\s+is\s+(-?[\d,]+\.?\d*[eE][+-]?\d+|-?[\d,]+)",
+             r"answer:\s*(-?[\d,]+\.?\d*[eE][+-]?\d+|-?[\d,]+)",
+             r"(?:total|result|left|remaining)(?:\s+is|\s+are|:)\s*(-?[\d,]+\.?\d*[eE][+-]?\d+|-?[\d,]+)",
+         ]
+
+         for pattern in answer_patterns:
+             matches = list(re.finditer(pattern, end_section, re.IGNORECASE))
+             if matches:
+                 try:
+                     # Remove commas before converting
+                     num_str = matches[-1].group(1).replace(",", "")
+                     return str(int(float(num_str)))
+                 except:
+                     pass
+
+         # Look for scientific notation anywhere in response
+         sci_numbers = re.findall(r"-?\d+\.?\d*[eE][+-]?\d+", response)
+         if sci_numbers:
+             try:
+                 return str(int(float(sci_numbers[-1])))
+             except:
+                 pass
+
+         # Fall back to finding all numbers (including comma-formatted) and taking the last one
+         # Match numbers with optional commas: 7,518 or 7518
+         numbers = re.findall(r"-?[\d,]+", response)
+         if numbers:
+             try:
+                 # Remove commas and try the last number first
+                 return str(int(numbers[-1].replace(",", "")))
+             except:
+                 pass
+
+     # For float answers
+     elif answer_type == "float":
+         precision = int(problem.get("precision", 2))
+
+         # First try to find boxed answer
+         boxed_match = re.search(r"\\boxed\{([^}]+)\}", response)
+         if boxed_match:
+             boxed_content = boxed_match.group(1)
+             # Try scientific notation first
+             sci_numbers = re.findall(r"-?\d+\.?\d*[eE][+-]?\d+", boxed_content)
+             if sci_numbers:
+                 try:
+                     return str(round(float(sci_numbers[0]), precision))
+                 except:
+                     pass
+             # Then regular numbers
+             numbers = re.findall(r"-?\d+\.?\d*", boxed_content)
+             if numbers:
+                 try:
+                     return str(round(float(numbers[0]), precision))
+                 except:
+                     pass
+
+         # Try common answer patterns near the end
+         end_section = response[-500:] if len(response) > 500 else response
+         answer_patterns = [
+             r"(?:the\s+)?answer\s+is\s+(-?\d+\.?\d*[eE][+-]?\d+|-?\d+\.?\d*)",
+             r"answer:\s*(-?\d+\.?\d*[eE][+-]?\d+|-?\d+\.?\d*)",
+             r"d\s*=\s*(-?\d+\.?\d*[eE][+-]?\d+|-?\d+\.?\d*)",  # For physics problems with d=
+         ]
+
+         for pattern in answer_patterns:
+             matches = list(re.finditer(pattern, end_section, re.IGNORECASE))
+             if matches:
+                 try:
+                     return str(round(float(matches[-1].group(1)), precision))
+                 except:
+                     pass
+
+         # Look for scientific notation anywhere in response
+         sci_numbers = re.findall(r"-?\d+\.?\d*[eE][+-]?\d+", response)
+         if sci_numbers:
+             try:
+                 return str(round(float(sci_numbers[-1]), precision))
+             except:
+                 pass
+
+         # Fall back to finding all numbers and taking the last one
+         numbers = re.findall(r"-?\d+\.?\d*", response)
+         if numbers:
+             try:
+                 # Try the last number first
+                 return str(round(float(numbers[-1]), precision))
+             except:
+                 pass
+
+     return response
+
+
+ def evaluate_answer(prediction: Optional[str], ground_truth: str) -> bool:
+     """Check if the prediction matches the ground truth."""
+     if prediction is None:
+         return False
+     try:
+         # First check exact match
+         if str(prediction).strip() == str(ground_truth).strip():
+             return True
+
+         # Handle numeric word representations
+         word_to_num = {
+             "zero": "0",
+             "one": "1",
+             "two": "2",
+             "three": "3",
+             "four": "4",
+             "five": "5",
+             "six": "6",
+             "seven": "7",
+             "eight": "8",
+             "nine": "9",
+             "ten": "10",
+             "eleven": "11",
+             "twelve": "12",
+             "thirteen": "13",
+             "fourteen": "14",
+             "fifteen": "15",
+             "sixteen": "16",
+             "seventeen": "17",
+             "eighteen": "18",
+             "nineteen": "19",
+             "twenty": "20",
+         }
+
+         pred_normalized = str(prediction).strip().lower()
+         gt_normalized = str(ground_truth).strip().lower()
+
+         # Convert words to numbers
+         if pred_normalized in word_to_num:
+             pred_normalized = word_to_num[pred_normalized]
+         if gt_normalized in word_to_num:
+             gt_normalized = word_to_num[gt_normalized]
+
+         return pred_normalized == gt_normalized
+     except:
+         return False
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(
+         description="Evaluate models on MathVista benchmark"
+     )
+     parser.add_argument(
+         "--model",
+         type=str,
+         required=True,
+         help="The path to the MLX VLM model",
+     )
+     parser.add_argument(
+         "--adapter-path",
+         type=str,
+         help="Optional path for the trained adapter weights and config",
+     )
+     parser.add_argument(
+         "--dataset",
+         type=str,
+         default="AI4Math/MathVista",
+         help="Hugging Face dataset name",
+     )
+     parser.add_argument(
+         "--split",
+         type=str,
+         default="testmini",
+         choices=["testmini", "test"],
+         help="Dataset split to evaluate on",
+     )
+     parser.add_argument(
+         "--streaming",
+         action="store_true",
+         help="Use streaming dataset loading",
+     )
+     parser.add_argument(
+         "--max-samples",
+         type=int,
+         default=None,
+         help="Maximum number of samples to evaluate (for debugging)",
+     )
+     parser.add_argument(
+         "--output-dir",
+         type=str,
+         default="results/mathvista",
+         help="Directory to save results",
+     )
+     parser.add_argument(
+         "--max-tokens",
+         type=int,
+         default=512,
+         help="Maximum number of tokens to generate",
+     )
+     parser.add_argument(
+         "--temperature",
+         type=float,
+         default=0.0,
+         help="Temperature for generation",
+     )
+     parser.add_argument(
+         "--verbose",
+         action="store_true",
+         help="Print detailed output for debugging",
+     )
+     parser.add_argument("--seed", type=int, default=42, help="Random seed")
+     return parser.parse_args()
+
+
+ def main():
+     args = parse_args()
+
+     random.seed(args.seed)
+
+     # Setup logging
+     logging.basicConfig(
+         level=logging.INFO if args.verbose else logging.WARNING,
+         format="%(asctime)s - %(levelname)s - %(message)s",
+     )
+
+     logging.info(f"Loading model from {args.model}")
+     model, processor = load(
+         args.model, adapter_path=args.adapter_path, trust_remote_code=True
+     )
+
+     # Load dataset
+     logging.info(f"Loading dataset {args.dataset}, split {args.split}")
+     dataset = load_dataset(args.dataset, split=args.split, streaming=args.streaming)
+
+     if args.max_samples:
+         dataset = dataset.select(range(min(args.max_samples, len(dataset))))
+
+     # Create output directory
+     output_dir = Path(args.output_dir)
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     results = {}
+     category_scores = {}
+     correct = 0
+     total = 0
+
+     # Evaluate each sample
+     for idx, sample in enumerate(tqdm(dataset, desc="Evaluating")):
+         pid = sample["pid"]
+
+         try:
+             # Load and process image
+             if "decoded_image" in sample and sample["decoded_image"]:
+                 if isinstance(sample["decoded_image"], str):
+                     image_path = sample["decoded_image"]
+                     if os.path.exists(image_path):
+                         image = Image.open(image_path).convert("RGB")
+                     else:
+                         logging.warning(
+                             f"Image not found: {image_path}, skipping sample {pid}"
+                         )
+                         continue
+                 else:
+                     # Image is already loaded
+                     image = sample["decoded_image"].convert("RGB")
+             else:
+                 logging.warning(f"No image for sample {pid}, skipping")
+                 continue
+
+             # Create prompt
+             prompt = process_question(sample)
+
+             # Generate response
+             output = inference(
+                 model,
+                 processor,
+                 prompt,
+                 image=image,
+                 max_tokens=args.max_tokens,
+                 temperature=args.temperature,
+             )
+
+             response = output.strip()
+
+             # Normalize answer
+             prediction = normalize_answer(response, sample)
+
+             # Evaluate
+             ground_truth = sample.get("answer", "")
+             if args.split == "testmini" and ground_truth:
+                 is_correct = evaluate_answer(prediction, ground_truth)
+                 if is_correct:
+                     correct += 1
+             else:
+                 is_correct = None
+
+             total += 1
+
+             # Store results
+             results[pid] = {
+                 "pid": pid,
+                 "question": sample["question"],
+                 "query": sample["query"],
+                 "question_type": sample["question_type"],
+                 "answer_type": sample["answer_type"],
+                 "choices": sample.get("choices", []),
+                 "unit": sample.get("unit", ""),
+                 "precision": sample.get("precision", 0),
+                 "ground_truth": ground_truth,
+                 "response": response,
+                 "prediction": prediction,
+                 "correct": is_correct,
+                 "metadata": sample.get("metadata", {}),
+             }
+             # Track category-wise performance
+             category = sample.get("metadata", {}).get("category", "unknown")
+             if category not in category_scores:
+                 category_scores[category] = {"correct": 0, "total": 0}
+
+             category_scores[category]["total"] += 1
+             if is_correct:
+                 category_scores[category]["correct"] += 1
+
+             if args.verbose:
+                 logging.info(f"\nSample {pid}:")
+                 logging.info(f"Question: {sample['question']}")
+                 logging.info(f"Response: {response}")
+                 logging.info(f"Prediction: {prediction}")
+                 logging.info(f"Ground Truth: {ground_truth}")
+                 logging.info(f"Correct: {is_correct}")
+
+         except Exception as e:
+             logging.error(f"Error processing sample {pid}: {e}")
+             continue
+
+     # Calculate accuracy if applicable
+     if args.split == "testmini":
+         accuracy = correct / total if total > 0 else 0
+     else:
+         accuracy = None
+         correct = None
+
+     # Save results
+     model_name = args.model.split("/")[-1]
+     results_file = output_dir / f"{model_name}_MathVista_{args.split}.csv"
+
+     # Convert results to list of dictionaries for CSV writing
+     fieldnames = [
+         "pid",
+         "question",
+         "query",
+         "question_type",
+         "answer_type",
+         "choices",
+         "unit",
+         "precision",
+         "ground_truth",
+         "response",
+         "prediction",
+         "correct",
+         "metadata",
+     ]
+
+     with open(results_file, "w", newline="", encoding="utf-8") as f:
+         writer = csv.DictWriter(f, fieldnames=fieldnames)
+         writer.writeheader()
+
+         for result in results.values():
+             # Convert list and dict fields to strings for CSV
+             row = result.copy()
+             if isinstance(row.get("choices"), list):
+                 row["choices"] = "; ".join(row["choices"])
+             if isinstance(row.get("metadata"), dict):
+                 row["metadata"] = json.dumps(row["metadata"])
+             writer.writerow(row)
+
+     # Save summary
+     summary = {
+         "model": args.model,
+         "dataset": args.dataset,
+         "split": args.split,
+         "total_samples": total,
+         "category_scores": category_scores,
+     }
+
+     if accuracy is not None:
+         summary["correct"] = correct
+         summary["accuracy"] = accuracy
+
+     summary_file = output_dir / f"{model_name}_MathVista_{args.split}.json"
+     with open(summary_file, "w") as f:
+         json.dump(summary, f, indent=2)
+
+     print(f"\n{'='*80}")
+     print("MathVista Evaluation Results")
+     print(f"{'='*80}")
+     print(f"Model: {args.model}")
+     print(f"Split: {args.split}")
+     print(f"Total Samples: {total}")
+     if accuracy is not None:
+         print(f"Correct: {correct}")
+         print(f"Accuracy: {accuracy*100:.2f}%")
+     else:
+         print("Accuracy not computed for this split (no ground truth labels)")
+
+     print("\n" + "-" * 80)
+     print(f"Subcategory Scores:")
+     print(f"{'-'*80}")
+     for category, scores in category_scores.items():
+         cat_total = scores["total"]
+         cat_correct = scores["correct"]
+         cat_accuracy = cat_correct / cat_total if cat_total > 0 else 0
+         print(f" {category}: {cat_correct}/{cat_total} ({cat_accuracy*100:.2f}%)")
+     print(f"{'='*80}")
+     print(f"\nResults saved to {results_file} and {summary_file}")
+
+
+ if __name__ == "__main__":
+     main()
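
The normalization helpers above can also be exercised on their own, which is useful when debugging answer-extraction failures. A small illustrative sketch with a hand-made multiple-choice sample (not an actual MathVista record), assuming the package and its dependencies (mlx, datasets, Pillow) are installed so the module imports cleanly:

# Illustrative only: a hand-made sample whose field names follow the schema used above.
from mlx_vlm.evals.math_vista import (
    evaluate_answer,
    normalize_answer,
    process_question,
)

sample = {
    "query": "Which animal is the largest?",
    "question_type": "multi_choice",
    "answer_type": "text",
    "choices": ["cat", "dog", "elephant"],
}

print(process_question(sample))
# Which animal is the largest?
# (A) cat
# (B) dog
# (C) elephant

# The normalizer maps a letter answer back to the corresponding choice text.
pred = normalize_answer("Looking at the image, the answer is (C).", sample)
print(pred)                               # elephant
print(evaluate_answer(pred, "elephant"))  # True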