fount-vlm-nell-02 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258) hide show
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/evals/mmmu.py ADDED
@@ -0,0 +1,528 @@
1
+ import argparse
2
+ import csv
3
+ import logging
4
+ import os
5
+ import random
6
+ import re
7
+ from json import dump
8
+
9
+ from datasets import load_dataset
10
+ from tqdm import tqdm
11
+
12
+ from mlx_vlm import load
13
+ from mlx_vlm.evals.utils import inference
14
+
15
# All 30 MMMU subjects (confirmed from dataset)
# These are the config names passed to datasets.load_dataset for MMMU/MMMU.
MMMU_SUBJECTS = [
    "Accounting",
    "Agriculture",
    "Architecture_and_Engineering",
    "Art",
    "Art_Theory",
    "Basic_Medical_Science",
    "Biology",
    "Chemistry",
    "Clinical_Medicine",
    "Computer_Science",
    "Design",
    "Diagnostics_and_Laboratory_Medicine",
    "Economics",
    "Electronics",
    "Energy_and_Power",
    "Finance",
    "Geography",
    "History",
    "Literature",
    "Manage",
    "Marketing",
    "Materials",
    "Math",
    "Mechanical_Engineering",
    "Music",
    "Pharmacy",
    "Physics",
    "Psychology",
    "Public_Health",
    "Sociology",
]

# The three MMMU-Pro dataset configurations; used in place of the subject
# list when the dataset name contains "pro" (see main()).
MMMU_PRO_SUBJECTS = [
    "vision",
    "standard (10 options)",
    "standard (4 options)",
]
54
+
55
+
56
def normalize_number(s):
    """Normalize numeric strings for comparison.

    Strips surrounding whitespace and thousands separators, then parses the
    value as a float.

    Args:
        s: any value; it is stringified before parsing.

    Returns:
        ``float`` when the cleaned string is numeric, otherwise the stripped
        string (without comma removal, matching the original fallback).
    """
    try:
        return float(str(s).strip().replace(",", ""))
    # Only ValueError can occur here (str() always succeeds); a bare except
    # would hide real bugs such as KeyboardInterrupt.
    except ValueError:
        return str(s).strip()
62
+
63
+
64
def _extract_choice_letter(predict_lower):
    """Return the most plausible multiple-choice letter (a-f) in a prediction.

    Patterns are tried with priority weights: explicit phrasings such as
    "the answer is B" beat a parenthesised "(b)", which beats any isolated
    letter. Returns the lowercase letter, or None when nothing matches.
    """
    patterns = [
        (r"option\s+([a-f])\b", 10),  # High priority
        (r"answer\s+is:?\s+([a-f])\b", 10),
        (r"choice\s+is:?\s+([a-f])\b", 10),
        (r"correct\s+answer\s+is:?\s+([a-f])\b", 10),
        (r"correct\s+option\s+is:?\s+\(?([a-f])\)?", 10),
        (r"\(([a-f])\)", 8),  # Medium priority
        (r"^([a-f])[.:\)]\s", 8),
        (r"\b([a-f])\b", 5),  # Low priority - isolated letters
    ]

    best_match = None
    best_priority = -1
    for pattern, priority in patterns:
        matches = re.findall(pattern, predict_lower, re.IGNORECASE)
        if matches and priority > best_priority:
            best_match = matches[0].lower()
            best_priority = priority
            if priority >= 10:  # high-confidence pattern -- stop early
                break
    return best_match


def _is_multiple_choice_correct(predict_lower, answer_lower):
    """Score a multiple-choice prediction against a single-letter answer."""
    extracted = _extract_choice_letter(predict_lower)
    if extracted:
        return extracted == answer_lower
    # Fallback: the prediction may simply start with the chosen letter.
    return (
        bool(predict_lower)
        and predict_lower[0] in "abcdef"
        and predict_lower[0] == answer_lower
    )


def _is_open_ended_correct(predict, answer, predict_lower, answer_lower):
    """Score an open-ended prediction against a free-form answer."""
    # Exact substring match (case-insensitive).
    if answer_lower in predict_lower:
        return True
    # Numeric answers: accept any number in the prediction within 0.01.
    if answer.replace(".", "").replace("-", "").replace(",", "").isdigit():
        answer_num = normalize_number(answer)
        for num_str in re.findall(r"-?\d+\.?\d*", predict):
            try:
                if abs(normalize_number(num_str) - answer_num) < 0.01:
                    return True
            except (TypeError, ValueError):
                # normalize_number fell back to a string (e.g. "1-2"); skip
                # this candidate instead of swallowing every exception.
                pass
        return False
    # Word-level match: every answer word must appear in the prediction.
    answer_words = set(answer_lower.split())
    return bool(answer_words) and answer_words.issubset(set(predict_lower.split()))


def MMMU_eval(data: list, eval_file: str):
    """
    Evaluate MMMU results by subject.

    Handles both multiple choice (A-F) and open-ended questions. Each row in
    ``data`` gains a 0/1 ``score`` field; a JSON summary is written next to
    ``eval_file`` (``*_score.json``) and ``eval_file`` itself is rewritten
    with the per-row scores.

    Args:
        data: rows containing at least ``prediction``, ``answer`` and
            ``subject`` keys.
        eval_file: path of the predictions CSV (filename must end in .csv).
    """
    # Track by subject
    subject_scores = {}
    subject_counters = {}

    total_correct = 0
    total_questions = 0

    for line in data:
        predict = str(line["prediction"])
        answer = str(line["answer"])
        subject = str(line.get("subject", "Unknown"))

        # Initialize subject tracking if needed
        if subject not in subject_scores:
            subject_scores[subject] = 0
            subject_counters[subject] = 0

        # Count this question
        subject_counters[subject] += 1
        total_questions += 1

        # Normalize for comparison
        predict_lower = predict.lower().strip()
        answer_lower = answer.lower().strip()

        # Single capital letters mark multiple-choice questions ("I" occurs
        # in a few MMMU answer keys even though patterns only extract a-f).
        if answer in ["A", "B", "C", "D", "E", "F", "I"]:
            is_correct = _is_multiple_choice_correct(predict_lower, answer_lower)
        else:
            is_correct = _is_open_ended_correct(
                predict, answer, predict_lower, answer_lower
            )

        if is_correct:
            total_correct += 1
            subject_scores[subject] += 1
        line["score"] = 1 if is_correct else 0

    # Calculate final scores
    results = {}
    results["overall_accuracy"] = (
        float(total_correct) / float(total_questions) if total_questions > 0 else 0.0
    )
    results["total_correct"] = total_correct
    results["total_questions"] = total_questions

    # Calculate subject scores
    for subject in sorted(subject_scores.keys()):
        if subject_counters[subject] > 0:
            results[f"subject_{subject}_accuracy"] = float(
                subject_scores[subject]
            ) / float(subject_counters[subject])
            results[f"subject_{subject}_correct"] = subject_scores[subject]
            results[f"subject_{subject}_total"] = subject_counters[subject]

    # Print scores
    print("\nMMMU Evaluation Results:")
    print("=" * 80)
    print(f"Model: {eval_file.split('/')[-1].split('_MMMU_')[0]}")
    print(f"Total Questions: {total_questions}")
    print(f"Total Correct: {total_correct}")
    print(
        f"Overall Accuracy: {results['overall_accuracy']:.4f} ({total_correct}/{total_questions})"
    )
    print("=" * 80)
    print("Subject Breakdown:")
    for subject in sorted(subject_scores.keys()):
        acc = results.get(f"subject_{subject}_accuracy", 0.0)
        correct = results.get(f"subject_{subject}_correct", 0)
        total = results.get(f"subject_{subject}_total", 0)
        print(f" {subject}: {acc:.4f} ({correct}/{total})")
    print("=" * 80)

    # Save results
    score_pth = eval_file.replace(".csv", "_score.json")
    with open(score_pth, "w") as f:
        dump(results, f, indent=2)

    with open(eval_file, "w", newline="", encoding="utf-8") as f:
        if data:
            writer = csv.DictWriter(f, fieldnames=data[0].keys())
            writer.writeheader()
            writer.writerows(data)

    logging.info(
        f"MMMU_eval successfully finished evaluating {eval_file}, results saved in {score_pth}"
    )
215
+
216
+
217
def process_question(example):
    """
    Format an MMMU example into a plain-text question.

    Appends lettered options ("A. ...", "B. ...") when present and strips
    the ``<image n>`` placeholder tags embedded in MMMU questions.

    Args:
        example: dataset row with ``question`` and optionally ``options``
            (either a list of strings or the string repr of such a list,
            which is how the MMMU dataset stores it).

    Returns:
        The formatted question string.
    """
    question = example.get("question", "")

    options = example.get("options", None)
    if isinstance(options, str):
        # MMMU stores options as the string repr of a list; strip the
        # brackets/quotes and split on ", ". (Fragile for options that
        # themselves contain commas, but matches the dataset format.)
        options = re.sub(r'[\[\]"\']', "", options).split(", ") if options else None
    # A real list (e.g. already-decoded options) passes through unchanged;
    # previously this crashed with TypeError inside re.sub.

    if options and isinstance(options, list):
        question += "\n\nOptions:"
        for i, option in enumerate(options):
            letter = chr(65 + i)  # A, B, C, D, ...
            question += f"\n{letter}. {option}"

    # Remove <image n> tags from the question
    question = re.sub(r"<image \d+>", "", question).strip()

    return question
238
+
239
+
240
def get_images(example):
    """
    Collect every image attached to an MMMU example, converted to RGB.

    A row either carries a single ``image`` field or a set of numbered
    ``image_0`` .. ``image_7`` fields; images that fail to convert are
    skipped with a warning.
    """
    collected = []

    single = example.get("image")
    if single is not None:
        try:
            collected.append(single.convert("RGB"))
        except Exception as e:
            print(f"Warning: Could not process image - {e}")
        return collected

    # No single image field: probe the numbered keys instead.
    for idx in range(8):
        key = f"image_{idx}"
        candidate = example.get(key)
        if candidate is None:
            continue
        try:
            collected.append(candidate.convert("RGB"))
        except Exception as e:
            print(f"Warning: Could not process image for key {key} - {e}")
    return collected
265
+
266
+
267
def list_subjects():
    """Print every available MMMU-Pro and MMMU subject, numbered."""
    bar = "=" * 80
    print("\n" + bar)
    print("MMMU Pro Subjects (3 total)")
    print(bar)
    for number, name in enumerate(MMMU_PRO_SUBJECTS, 1):
        print(f"{number:2d}. {name}")
    print("\n" + bar)
    print("MMMU Available Subjects (30 total)")
    print(bar)
    for number, name in enumerate(MMMU_SUBJECTS, 1):
        print(f"{number:2d}. {name}")
    print(bar + "\n")
280
+
281
+
282
def parse_arguments():
    """Build and parse the command-line arguments for the MMMU evaluation CLI.

    Returns:
        argparse.Namespace with model/dataset selection, sampling settings
        and output options. Reads sys.argv.
    """
    parser = argparse.ArgumentParser(
        description="MMMU Evaluation - Massive Multi-discipline Multimodal Understanding",
        epilog="Use --subset to evaluate a specific subject, or omit to evaluate all 30 subjects.",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="mlx-community/Qwen2-VL-2B-Instruct-bf16",
        help="Model path",
    )
    parser.add_argument("--adapter_path", type=str, default=None, help="Adapter path")
    parser.add_argument("--dataset", type=str, default="MMMU/MMMU", help="Dataset path")
    parser.add_argument(
        "--split", type=str, default="validation", help="Split to use for evaluation"
    )
    parser.add_argument(
        "--subset",
        type=str,
        default=None,
        help=f"Subset to use - one of 30 subjects: {', '.join(MMMU_SUBJECTS[:5])}... (see SUBJECTS.md for full list)",
    )
    parser.add_argument(
        "--streaming", action="store_true", help="Use streaming dataset loading"
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=3000,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Temperature for sampling (0.0 for greedy)",
    )
    # NOTE(review): --top-p and --repetition-penalty are parsed but main()
    # does not forward them to inference() — confirm whether intentional.
    parser.add_argument(
        "--top-p",
        type=float,
        default=0.9,
        help="Top-p sampling parameter",
    )
    parser.add_argument(
        "--repetition-penalty",
        type=float,
        default=1.0,
        help="Repetition penalty parameter",
    )
    parser.add_argument(
        "--resize-shape",
        type=int,
        nargs=2,
        default=None,
        help="Resize shape for the image",
    )
    parser.add_argument("--verbose", action="store_true", help="Verbose output")
    parser.add_argument(
        "--max-samples",
        type=int,
        default=None,
        help="Maximum number of samples to evaluate (for testing)",
    )
    parser.add_argument(
        "--list-subjects",
        action="store_true",
        help="List all 30 available subjects and exit",
    )
    parser.add_argument(
        "--prediction-file",
        type=str,
        default=None,
        help="Path to the prediction file",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="results/mmmu",
        help="Directory to save evaluation results",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    return parser.parse_args()
364
+
365
+
366
def main():
    """CLI entry point: run MMMU inference and scoring end to end.

    Loads the requested dataset subset(s), runs the model on every example,
    writes predictions to a CSV under --output-dir, then scores them with
    MMMU_eval. With --prediction-file, inference is skipped and the existing
    predictions are only re-scored.
    """
    args = parse_arguments()

    random.seed(args.seed)

    # Setup logging
    logging.basicConfig(
        level=logging.INFO if args.verbose else logging.WARNING,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )

    # MMMU-Pro ships three configs instead of 30 subjects.
    if "pro" in args.dataset.lower():
        subjects = MMMU_PRO_SUBJECTS
    else:
        subjects = MMMU_SUBJECTS

    if args.prediction_file:
        # Load predictions from file (re-scoring path; no model needed)
        logging.info(f"\033[32mLoading predictions from {args.prediction_file}\033[0m")
        results = []
        with open(args.prediction_file, "r", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            for row in reader:
                results.append(row)

        # Evaluate loaded predictions
        MMMU_eval(results, args.prediction_file)
        logging.info(f"\033[32mEvaluation complete\033[0m")
        return

    # Handle --list-subjects flag
    # NOTE(review): this check runs after the --prediction-file branch, so
    # --list-subjects is silently ignored when a prediction file is given.
    if args.list_subjects:
        list_subjects()
        return

    logging.info("\033[32mStarting MMMU Evaluation\033[0m")

    # Validate subset if provided
    if args.subset and args.subset not in subjects:
        logging.error(f"\033[31mError: Invalid subset '{args.subset}'\033[0m")
        logging.error(f"\033[31mValid subjects are: {', '.join(subjects)}\033[0m")
        logging.error(f"\033[31mSee SUBJECTS.md for more details\033[0m")
        return

    logging.info(f"\033[32mLoading dataset from {args.dataset}\033[0m")

    # Load dataset

    if args.subset:
        logging.info(f"\033[32mUsing subset: {args.subset}\033[0m")
        datasets = {
            args.subset: load_dataset(
                args.dataset, args.subset, split=args.split, streaming=args.streaming
            )
        }
        subset_name = args.subset
    else:
        logging.info(f"\033[32mEvaluating all 30 subjects\033[0m")
        datasets = {}

        # Subjects that fail to load are skipped rather than aborting the run.
        for subject in subjects:
            try:
                datasets[subject] = load_dataset(
                    args.dataset,
                    name=subject,
                    split=args.split,
                    streaming=args.streaming,
                )
            except Exception as e:
                logging.error(
                    f"\033[31mError loading dataset for {subject}: {e}\033[0m"
                )
                continue

        subset_name = "all"

    # Limit samples if specified
    # NOTE(review): streaming datasets expose neither len() nor .select(),
    # so --max-samples combined with --streaming likely raises here — confirm.
    # NOTE(review): the log below reports the number of subjects, not samples.
    if args.max_samples:
        datasets = {
            k: v.select(range(min(args.max_samples, len(v))))
            for k, v in datasets.items()
        }
        logging.info(f"\033[33mLimited to {len(datasets)} samples for testing\033[0m")

    logging.info(f"\033[32mDataset subset size: {len(datasets.keys())}\033[0m")
    logging.info(f"\033[32mLoading model from {args.model}\033[0m")

    model, processor = load(
        args.model, adapter_path=args.adapter_path, trust_remote_code=True
    )
    config = model.config
    logging.info(f"\033[32mConfig: {config}\033[0m")

    # Create results directory
    model_name = args.model.split("/")[-1]
    result_file = f"{args.output_dir}/{model_name}_MMMU_{subset_name}_{args.split}_predictions.csv"
    os.makedirs(args.output_dir, exist_ok=True)

    results = []
    for subject, dataset in tqdm(datasets.items(), desc="Processing subjects"):
        for idx, example in enumerate(tqdm(dataset, desc=f"Processing {subject}")):
            question = process_question(example)

            images = get_images(example)
            try:
                # Get prediction
                # NOTE(review): args.top_p and args.repetition_penalty are
                # parsed but not passed here — confirm against inference().
                prediction = inference(
                    model,
                    processor,
                    question,
                    images,
                    args.max_tokens,
                    args.temperature,
                    args.resize_shape,
                    args.verbose,
                )
            except Exception as e:
                print(f"Error during inference:", question, images, "error message:", e)
                prediction = ""

            # Store result
            result = {
                "id": example.get("id", idx),
                "question": question,
                "answer": example.get("answer", ""),
                "subfield": example.get("subfield", "Unknown"),
                "topic_difficulty": example.get("topic_difficulty", "Unknown"),
                "question_type": example.get("question_type", "Unknown"),
                "prediction": prediction,
                "subject": example.get("subject", None) or subject,
            }
            results.append(result)

            # Show progress
            if (idx + 1) % 10 == 0 or idx < 5:
                logging.info(
                    f"Sample {idx + 1}: Answer={result['answer']}, Prediction={prediction[:50]}..."
                )

    # Print first few results
    print("\nFirst 5 results:")
    for i, result in enumerate(results[:5]):
        print(
            f"{i+1}. Question: {result['question'][:50]}... | Answer: {result['answer']} | Prediction: {result['prediction'][:50]}..."
        )

    # Save results to CSV
    with open(result_file, "w", newline="", encoding="utf-8") as f:
        if results:
            writer = csv.DictWriter(f, fieldnames=results[0].keys())
            writer.writeheader()
            writer.writerows(results)

    logging.info(f"\033[32mSaved results to {result_file}\033[0m")

    # Evaluate results
    MMMU_eval(results, result_file)

    logging.info(f"\033[32mEvaluation complete\033[0m")
525
+
526
+
527
+ if __name__ == "__main__":
528
+ main()