fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/evals/mmstar.py (new file)
@@ -0,0 +1,343 @@
+ import argparse
+ import csv
+ import logging
+ import os
+ import random
+ import re
+ from copy import deepcopy
+ from json import dump
+
+ from datasets import load_dataset
+ from tqdm import tqdm
+
+ from mlx_vlm import load
+ from mlx_vlm.evals.utils import inference
+
+
+ def extract_answer(predict, answer):
+     """
+     Extracts the answer choice from the model's prediction.
+     predict: Model prediction text
+     answer: Ground truth answer (A, B, C, D, or E)
+     Returns: bool: True if the extracted choice matches, False otherwise
+     """
+     text = predict.lower().replace("\n", " ").strip()
+     answer_lower = answer.lower()
+
+     general_templates = [
+         r"^{0}\b",
+         r"^\({0}",
+         r"^option {0}\b",
+         r"\b{0}\s*[:\.\)]",
+         r"(?:^|\.|\s)\s*{0}\.",
+         r"\({0}\)",
+         r"option\s+{0}\b",
+         r"choice\s+{0}\b",
+     ]
+
+     concluding_templates = [
+         r"^the answer is {0}\b",
+         r"answer:\s*{0}\b",
+         r"answer\s+is\s+{0}\b",
+         r"correct\s+(?:answer|option|choice)\s+is:?\s+{0}\b",
+         r"the\s+answer\s+is\s+{0}\b",
+         r"is\s+{0}\s*:",
+         r"(?:therefore|thus|hence)[,\s]+(?:the\s+)?(?:answer\s+is\s+)?{0}\b",
+         r"(?:select|choose)\s+{0}\b",
+         r"it\s+is\s+{0}\b",
+         r"would\s+be\s+{0}\b",
+         r"\*\*(?:revised\s+)?answer\*\*:\s*{0}\b",
+         r"(?:correct\s+)?category\s+(?:for\s+this\s+image\s+)?is\s+\*\*{0}[:\s]",
+     ]
+
+     possible_answers = ["a", "b", "c", "d", "e"]
+     matches = []
+
+     for ans in possible_answers:
+         for pri, template_list in [(2, concluding_templates), (1, general_templates)]:
+             for template in template_list:
+                 pattern = template.format(ans)
+                 for match in re.finditer(pattern, text):
+                     matches.append((match.end(), ans, pri))
+
+     if not matches:
+         return False
+
+     # Sort by (-priority, -end) so concluding matches outrank general ones, with ties going to the latest match
+     matches.sort(key=lambda m: (-m[2], -m[0]))
+     latest_ans = matches[0][1]
+
+     return latest_ans == answer_lower
+
+
+ def MMStar_eval(data: list, eval_file: str):
+     MMStar_score_l2 = {
+         "coarse perception": {
+             "image scene and topic": 0,
+             "image style & quality": 0,
+             "image emotion": 0,
+         },
+         "fine-grained perception": {
+             "object counting": 0,
+             "recognition": 0,
+             "localization": 0,
+         },
+         "instance reasoning": {
+             "single-instance reasoning": 0,
+             "cross-instance attribute reasoning": 0,
+             "cross-instance relation reasoning": 0,
+         },
+         "logical reasoning": {
+             "code & sequence reasoning": 0,
+             "diagram reasoning": 0,
+             "common reasoning": 0,
+         },
+         "science & technology": {
+             "biology & chemistry & physics": 0,
+             "electronics & energy & mechanical eng.": 0,
+             "geography & earth science & agriculture": 0,
+         },
+         "math": {
+             "geometry": 0,
+             "numeric commonsense and calculation": 0,
+             "statistical reasoning": 0,
+         },
+     }
+     MMStar_counter = deepcopy(MMStar_score_l2)
+
+     for line in tqdm(data, desc="Evaluating"):
+         predict = str(line["prediction"])
+         answers = str(line["answer"])
+         category = str(line["category"])
+         l2_category = str(line["l2_category"])
+
+         MMStar_counter[category][l2_category] += 1
+
+         # Use comprehensive extraction
+         if extract_answer(predict, answers):
+             MMStar_score_l2[category][l2_category] += 1
+
+             line["score"] = 1
+         else:
+             line["score"] = 0
+
+     # Calculate scores
+     MMStar_score = {}
+     MMStar_score["final score"] = 0
+     total_correct = 0
+
+     for k, v in MMStar_score_l2.items():
+         cat_total = sum(MMStar_counter[k].values())
+         cat_correct = 0
+         for l2_k, l2_v in v.items():
+             count = MMStar_counter[k][l2_k]
+             if count > 0:
+                 MMStar_score[f"{k}({l2_k})"] = float(l2_v) / float(count)
+             else:
+                 MMStar_score[f"{k}({l2_k})"] = 0.0
+             cat_correct += l2_v
+             total_correct += l2_v
+         MMStar_score[k] = float(cat_correct) / cat_total if cat_total > 0 else 0.0
+         MMStar_score["final score"] += cat_correct
+
+     if len(data) > 0:
+         MMStar_score["final score"] = float(MMStar_score["final score"]) / float(
+             len(data)
+         )
+
+     # Print results
+     print("\n" + "=" * 80)
+     print("MMStar Evaluation Results")
+     print("=" * 80)
+     print(
+         f"\nFinal Score: {total_correct}/{len(data)} = {MMStar_score['final score']*100:.2f}%\n"
+     )
+
+     print("-" * 80)
+     print("Category Scores:")
+     print("-" * 80)
+     for category in [
+         "coarse perception",
+         "fine-grained perception",
+         "instance reasoning",
+         "logical reasoning",
+         "science & technology",
+         "math",
+     ]:
+         if category in MMStar_score:
+             cat_total = sum(MMStar_counter[category].values())
+             cat_correct = sum(MMStar_score_l2[category].values())
+             print(
+                 f"{category:30s}: {cat_correct:4d}/{cat_total:4d} = {MMStar_score[category]*100:6.2f}%"
+             )
+
+     print("\n" + "-" * 80)
+     print("Subcategory Scores:")
+     print("-" * 80)
+     for category in [
+         "coarse perception",
+         "fine-grained perception",
+         "instance reasoning",
+         "logical reasoning",
+         "science & technology",
+         "math",
+     ]:
+         print(f"\n{category.upper()}:")
+         for l2_cat, score in MMStar_score_l2[category].items():
+             count = MMStar_counter[category][l2_cat]
+             pct = (score / count * 100) if count > 0 else 0
+             print(f"  {l2_cat:55s}: {score:4d}/{count:4d} = {pct:6.2f}%")
+
+     print("\n" + "=" * 80)
+
+     # Save scores
+     score_pth = eval_file.replace(".csv", "_score.json")
+     with open(score_pth, "w") as f:
+         dump(MMStar_score, f, indent=2)
+
+     with open(eval_file, "w", newline="", encoding="utf-8") as f:
+         if data:
+             writer = csv.DictWriter(f, fieldnames=data[0].keys())
+             writer.writeheader()
+             writer.writerows(data)
+
+
+ def parse_arguments():
+     parser = argparse.ArgumentParser(description="MMStar Evaluation")
+     parser.add_argument(
+         "--model",
+         type=str,
+         default="mlx-community/Qwen2-VL-2B-Instruct-bf16",
+         help="Model path",
+     )
+     parser.add_argument("--adapter-path", type=str, default=None, help="Adapter path")
+     parser.add_argument(
+         "--dataset", type=str, default="Lin-Chen/MMStar", help="Dataset path"
+     )
+     parser.add_argument(
+         "--split", type=str, default="val", help="Split to use for evaluation"
+     )
+     parser.add_argument(
+         "--streaming", action="store_true", help="Use streaming dataset loading"
+     )
+     parser.add_argument(
+         "--max-samples",
+         type=int,
+         default=None,
+         help="Maximum number of samples to evaluate (for debugging)",
+     )
+     parser.add_argument(
+         "--max-tokens",
+         type=int,
+         default=3000,
+         help="Maximum number of tokens to generate",
+     )
+     parser.add_argument(
+         "--temperature", type=float, default=0.7, help="Temperature for sampling"
+     )
+     parser.add_argument(
+         "--resize-shape",
+         type=int,
+         nargs=2,
+         default=None,
+         help="Resize shape for the image",
+     )
+     parser.add_argument("--verbose", action="store_true", help="Verbose output")
+     parser.add_argument(
+         "--prediction-file", type=str, default=None, help="Path to the prediction file"
+     )
+     parser.add_argument(
+         "--output-dir",
+         type=str,
+         default="results/mmstar",
+         help="Directory to save evaluation results",
+     )
+     parser.add_argument("--seed", type=int, default=42, help="Random seed")
+
+     return parser.parse_args()
+
+
+ def main():
+     args = parse_arguments()
+
+     random.seed(args.seed)
+
+     # Setup logging
+     logging.basicConfig(
+         level=logging.INFO if args.verbose else logging.WARNING,
+         format="%(asctime)s - %(levelname)s - %(message)s",
+     )
+
+     logging.info("\033[32mStarting MMStar Evaluation\033[0m")
+     if args.prediction_file:
+         logging.info(
+             f"\033[32mLoading predictions from {args.prediction_file} for evaluation\033[0m"
+         )
+         results = []
+         with open(args.prediction_file, "r", encoding="utf-8") as f:
+             reader = csv.DictReader(f)
+             results = [row for row in reader]
+         MMStar_eval(results, args.prediction_file)
+         logging.info(f"\033[32mEvaluation complete\033[0m")
+         return
+     logging.info(f"\033[32mLoading dataset from {args.dataset}\033[0m")
+     dataset = load_dataset(args.dataset, split=args.split, streaming=args.streaming)
+     if args.max_samples:
+         dataset = dataset.take(args.max_samples)
+
+     logging.info(f"\033[32mLoading model from {args.model}\033[0m")
+     model, processor = load(
+         args.model, adapter_path=args.adapter_path, trust_remote_code=True
+     )
+     config = model.config
+     logging.info(f"\033[32mConfig: {config}\033[0m")
+
+     result_file = f'{args.output_dir}/{args.model.split("/")[-1]}_{args.dataset.split("/")[-1]}_{args.split}_predictions.csv'
+     os.makedirs(args.output_dir, exist_ok=True)
+
+     results = []
+     for example in tqdm(dataset, desc="Running inference"):
+         question = example["question"]
+         image = example["image"].convert("RGB")
+         prediction = inference(
+             model,
+             processor,
+             question,
+             image,
+             args.max_tokens,
+             args.temperature,
+             args.resize_shape,
+             args.verbose,
+         )
+
+         results.append(
+             {
+                 "question": question,
+                 "answer": example["answer"],
+                 "category": example["category"],
+                 "l2_category": example["l2_category"],
+                 "meta_info": example["meta_info"],
+                 "prediction": prediction,
+             }
+         )
+
+     print("\nFirst 5 results:")
+     for i, result in enumerate(results[:5]):
+         print(
+             f"{i+1}. Question: {result['question'][:50]}... | Answer: {result['answer']} | Prediction: {result['prediction'][:50]}..."
+         )
+
+     with open(result_file, "w", newline="", encoding="utf-8") as f:
+         if results:
+             writer = csv.DictWriter(f, fieldnames=results[0].keys())
+             writer.writeheader()
+             writer.writerows(results)
+
+     MMStar_eval(results, result_file)
+
+     logging.info(f"\033[32mSaving results to {result_file}\033[0m")
+     logging.info(f"\033[32mEvaluation complete\033[0m")
+
+
+ if __name__ == "__main__":
+     main()
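
Usage note: the module ends in an if __name__ == "__main__" guard, so once the wheel is installed it can be run as a module (the module path is taken from the file listing above and the flags from parse_arguments; the exact invocation is an assumption, since entry_points.txt may also expose console scripts):

    python -m mlx_vlm.evals.mmstar \
        --model mlx-community/Qwen2-VL-2B-Instruct-bf16 \
        --split val --max-samples 50 --output-dir results/mmstar

The answer matching in extract_answer prefers "concluding" phrases (priority 2) over generic option mentions (priority 1), breaking ties by the latest match position in the text. A minimal sketch of that behavior, with made-up prediction strings for illustration:

    from mlx_vlm.evals.mmstar import extract_answer

    # A concluding template ("the answer is ...") matches at priority 2.
    extract_answer("Therefore, the answer is B.", "B")  # True

    # "Option C" hits only a general template (priority 1); the concluding
    # "the answer is A" (priority 2) outranks it, so "a" is extracted.
    extract_answer("Option C looks plausible, but the answer is A.", "A")  # True

    # No template matches any choice letter, so the prediction scores 0.
    extract_answer("I am not sure.", "D")  # False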