fount-vlm-nell-02 0.3.11 (py3-none-any.whl)
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in those registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
mlx_vlm/evals/mmmu.py
ADDED
@@ -0,0 +1,528 @@
import argparse
import csv
import logging
import os
import random
import re
from json import dump

from datasets import load_dataset
from tqdm import tqdm

from mlx_vlm import load
from mlx_vlm.evals.utils import inference

# All 30 MMMU subjects (confirmed from dataset)
MMMU_SUBJECTS = [
    "Accounting",
    "Agriculture",
    "Architecture_and_Engineering",
    "Art",
    "Art_Theory",
    "Basic_Medical_Science",
    "Biology",
    "Chemistry",
    "Clinical_Medicine",
    "Computer_Science",
    "Design",
    "Diagnostics_and_Laboratory_Medicine",
    "Economics",
    "Electronics",
    "Energy_and_Power",
    "Finance",
    "Geography",
    "History",
    "Literature",
    "Manage",
    "Marketing",
    "Materials",
    "Math",
    "Mechanical_Engineering",
    "Music",
    "Pharmacy",
    "Physics",
    "Psychology",
    "Public_Health",
    "Sociology",
]

MMMU_PRO_SUBJECTS = [
    "vision",
    "standard (10 options)",
    "standard (4 options)",
]

def normalize_number(s):
    """Normalize numeric strings for comparison."""
    try:
        return float(str(s).strip().replace(",", ""))
    except (TypeError, ValueError):
        return str(s).strip()


def MMMU_eval(data: list, eval_file: str):
    """
    Evaluate MMMU results by subject.
    Handles both multiple choice (A-F) and open-ended questions.
    """

    # Track by subject
    subject_scores = {}
    subject_counters = {}

    total_correct = 0
    total_questions = 0

    for line in data:
        predict = str(line["prediction"])
        answer = str(line["answer"])
        subject = str(line.get("subject", "Unknown"))

        # Initialize subject tracking if needed
        if subject not in subject_scores:
            subject_scores[subject] = 0
            subject_counters[subject] = 0

        # Count this question
        subject_counters[subject] += 1
        total_questions += 1

        # Normalize for comparison
        predict_lower = predict.lower().strip()
        answer_lower = answer.lower().strip()

        is_correct = False

        # Check if this is a multiple choice question (answer is A-F or I)
        if answer in ["A", "B", "C", "D", "E", "F", "I"]:
            # Multiple choice extraction with prioritized patterns
            patterns = [
                (r"option\s+([a-f])\b", 10),  # High priority
                (r"answer\s+is:?\s+([a-f])\b", 10),
                (r"choice\s+is:?\s+([a-f])\b", 10),
                (r"correct\s+answer\s+is:?\s+([a-f])\b", 10),
                (r"correct\s+option\s+is:?\s+\(?([a-f])\)?", 10),
                (r"\(([a-f])\)", 8),  # Medium priority
                (r"^([a-f])[.:\)]\s", 8),
                (r"\b([a-f])\b", 5),  # Low priority - isolated letters
            ]

            best_match = None
            best_priority = -1

            # Try each pattern, keeping the highest-priority match
            for pattern, priority in patterns:
                matches = re.findall(pattern, predict_lower, re.IGNORECASE)
                if matches and priority > best_priority:
                    best_match = matches[0].lower()
                    best_priority = priority
                    # Stop early if we found a high-confidence pattern
                    if priority >= 10:
                        break

            # Check if match is correct
            if best_match and best_match == answer_lower:
                is_correct = True
            # Fallback: check first character
            elif (
                not best_match
                and len(predict_lower) > 0
                and predict_lower[0] in "abcdef"
            ):
                if predict_lower[0] == answer_lower:
                    is_correct = True

        else:
            # Open-ended question - check if answer appears in prediction
            # Exact substring match (case-insensitive)
            if answer_lower in predict_lower:
                is_correct = True
            # For numeric answers, try numeric comparison
            elif answer.replace(".", "").replace("-", "").replace(",", "").isdigit():
                numbers = re.findall(r"-?\d+\.?\d*", predict)
                answer_num = normalize_number(answer)
                for num_str in numbers:
                    try:
                        if abs(normalize_number(num_str) - answer_num) < 0.01:
                            is_correct = True
                            break
                    except (TypeError, ValueError):
                        pass
            # Word-level match for text answers
            else:
                answer_words = set(answer_lower.split())
                predict_words = set(predict_lower.split())
                if answer_words and answer_words.issubset(predict_words):
                    is_correct = True

        if is_correct:
            total_correct += 1
            subject_scores[subject] += 1
            line["score"] = 1
        else:
            line["score"] = 0

    # Calculate final scores
    results = {}
    results["overall_accuracy"] = (
        float(total_correct) / float(total_questions) if total_questions > 0 else 0.0
    )
    results["total_correct"] = total_correct
    results["total_questions"] = total_questions

    # Calculate subject scores
    for subject in sorted(subject_scores.keys()):
        if subject_counters[subject] > 0:
            results[f"subject_{subject}_accuracy"] = float(
                subject_scores[subject]
            ) / float(subject_counters[subject])
            results[f"subject_{subject}_correct"] = subject_scores[subject]
            results[f"subject_{subject}_total"] = subject_counters[subject]

    # Print scores
    print("\nMMMU Evaluation Results:")
    print("=" * 80)
    print(f"Model: {eval_file.split('/')[-1].split('_MMMU_')[0]}")
    print(f"Total Questions: {total_questions}")
    print(f"Total Correct: {total_correct}")
    print(
        f"Overall Accuracy: {results['overall_accuracy']:.4f} ({total_correct}/{total_questions})"
    )
    print("=" * 80)
    print("Subject Breakdown:")
    for subject in sorted(subject_scores.keys()):
        acc = results.get(f"subject_{subject}_accuracy", 0.0)
        correct = results.get(f"subject_{subject}_correct", 0)
        total = results.get(f"subject_{subject}_total", 0)
        print(f"  {subject}: {acc:.4f} ({correct}/{total})")
    print("=" * 80)

    # Save scores as JSON and rewrite the predictions CSV with per-row scores
    score_pth = eval_file.replace(".csv", "_score.json")
    with open(score_pth, "w") as f:
        dump(results, f, indent=2)

    with open(eval_file, "w", newline="", encoding="utf-8") as f:
        if data:
            writer = csv.DictWriter(f, fieldnames=data[0].keys())
            writer.writeheader()
            writer.writerows(data)

    logging.info(
        f"MMMU_eval successfully finished evaluating {eval_file}, results saved in {score_pth}"
    )

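
# Worked example of the multiple-choice extraction above (illustrative input):
# for the prediction "I think the correct answer is B because ...", the
# priority-10 pattern r"answer\s+is:?\s+([a-f])\b" captures "b" and the loop
# breaks early, so the low-priority isolated-letter pattern r"\b([a-f])\b"
# never gets a chance to match a stray standalone letter elsewhere.
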
def process_question(example):
    """
    Process an MMMU question to format it properly.
    MMMU questions may have options and images.
    """
    question = example.get("question", "")

    # Add options if they exist
    options = example.get("options", None)
    options = re.sub(r'[\[\]"\']', "", options).split(", ") if options else None

    if options and isinstance(options, list):
        question += "\n\nOptions:"
        for i, option in enumerate(options):
            letter = chr(65 + i)  # A, B, C, D, ...
            question += f"\n{letter}. {option}"

    # Remove <image n> tags from the question
    question = re.sub(r"<image \d+>", "", question).strip()

    return question


def get_images(example):
    """
    Extract images from an MMMU example.
    MMMU can have multiple images per question.
    """
    images = []

    # Examples may carry a single "image" field or numbered image_0 ... image_7
    if "image" in example and example["image"] is not None:
        try:
            img = example["image"].convert("RGB")
            images.append(img)
        except Exception as e:
            print(f"Warning: Could not process image - {e}")
    else:
        for i in range(0, 8):  # Check keys image_0 through image_7
            img_key = f"image_{i}"
            if img_key in example and example[img_key] is not None:
                try:
                    img = example[img_key].convert("RGB")
                    images.append(img)
                except Exception as e:
                    print(f"Warning: Could not process image for key {img_key} - {e}")
                    continue
    return images

def list_subjects():
    """Print all available MMMU subjects."""
    print("\n" + "=" * 80)
    print("MMMU Pro Subjects (3 total)")
    print("=" * 80)
    for i, subject in enumerate(MMMU_PRO_SUBJECTS, 1):
        print(f"{i:2d}. {subject}")
    print("\n" + "=" * 80)
    print("MMMU Available Subjects (30 total)")
    print("=" * 80)
    for i, subject in enumerate(MMMU_SUBJECTS, 1):
        print(f"{i:2d}. {subject}")
    print("=" * 80 + "\n")


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="MMMU Evaluation - Massive Multi-discipline Multimodal Understanding",
        epilog="Use --subset to evaluate a specific subject, or omit to evaluate all 30 subjects.",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="mlx-community/Qwen2-VL-2B-Instruct-bf16",
        help="Model path",
    )
    parser.add_argument("--adapter_path", type=str, default=None, help="Adapter path")
    parser.add_argument("--dataset", type=str, default="MMMU/MMMU", help="Dataset path")
    parser.add_argument(
        "--split", type=str, default="validation", help="Split to use for evaluation"
    )
    parser.add_argument(
        "--subset",
        type=str,
        default=None,
        help=f"Subset to use - one of 30 subjects: {', '.join(MMMU_SUBJECTS[:5])}... (see SUBJECTS.md for full list)",
    )
    parser.add_argument(
        "--streaming", action="store_true", help="Use streaming dataset loading"
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=3000,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Temperature for sampling (0.0 for greedy)",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=0.9,
        help="Top-p sampling parameter",
    )
    parser.add_argument(
        "--repetition-penalty",
        type=float,
        default=1.0,
        help="Repetition penalty parameter",
    )
    parser.add_argument(
        "--resize-shape",
        type=int,
        nargs=2,
        default=None,
        help="Resize shape for the image",
    )
    parser.add_argument("--verbose", action="store_true", help="Verbose output")
    parser.add_argument(
        "--max-samples",
        type=int,
        default=None,
        help="Maximum number of samples to evaluate (for testing)",
    )
    parser.add_argument(
        "--list-subjects",
        action="store_true",
        help="List all 30 available subjects and exit",
    )
    parser.add_argument(
        "--prediction-file",
        type=str,
        default=None,
        help="Path to the prediction file",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="results/mmmu",
        help="Directory to save evaluation results",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    return parser.parse_args()

def main():
    args = parse_arguments()

    random.seed(args.seed)

    # Setup logging
    logging.basicConfig(
        level=logging.INFO if args.verbose else logging.WARNING,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )

    if "pro" in args.dataset.lower():
        subjects = MMMU_PRO_SUBJECTS
    else:
        subjects = MMMU_SUBJECTS

    if args.prediction_file:
        # Load predictions from file
        logging.info(f"\033[32mLoading predictions from {args.prediction_file}\033[0m")
        results = []
        with open(args.prediction_file, "r", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            for row in reader:
                results.append(row)

        # Evaluate loaded predictions
        MMMU_eval(results, args.prediction_file)
        logging.info("\033[32mEvaluation complete\033[0m")
        return

    # Handle --list-subjects flag
    if args.list_subjects:
        list_subjects()
        return

    logging.info("\033[32mStarting MMMU Evaluation\033[0m")

    # Validate subset if provided
    if args.subset and args.subset not in subjects:
        logging.error(f"\033[31mError: Invalid subset '{args.subset}'\033[0m")
        logging.error(f"\033[31mValid subjects are: {', '.join(subjects)}\033[0m")
        logging.error("\033[31mSee SUBJECTS.md for more details\033[0m")
        return

    logging.info(f"\033[32mLoading dataset from {args.dataset}\033[0m")

    # Load dataset
    if args.subset:
        logging.info(f"\033[32mUsing subset: {args.subset}\033[0m")
        datasets = {
            args.subset: load_dataset(
                args.dataset, args.subset, split=args.split, streaming=args.streaming
            )
        }
        subset_name = args.subset
    else:
        logging.info("\033[32mEvaluating all 30 subjects\033[0m")
        datasets = {}

        for subject in subjects:
            try:
                datasets[subject] = load_dataset(
                    args.dataset,
                    name=subject,
                    split=args.split,
                    streaming=args.streaming,
                )
            except Exception as e:
                logging.error(
                    f"\033[31mError loading dataset for {subject}: {e}\033[0m"
                )
                continue

        subset_name = "all"

    # Limit samples if specified
    if args.max_samples:
        datasets = {
            k: v.select(range(min(args.max_samples, len(v))))
            for k, v in datasets.items()
        }
        logging.info(
            f"\033[33mLimited to {args.max_samples} samples per subject for testing\033[0m"
        )

    logging.info(f"\033[32mNumber of dataset subsets: {len(datasets.keys())}\033[0m")
    logging.info(f"\033[32mLoading model from {args.model}\033[0m")

    model, processor = load(
        args.model, adapter_path=args.adapter_path, trust_remote_code=True
    )
    config = model.config
    logging.info(f"\033[32mConfig: {config}\033[0m")

    # Create results directory
    model_name = args.model.split("/")[-1]
    result_file = f"{args.output_dir}/{model_name}_MMMU_{subset_name}_{args.split}_predictions.csv"
    os.makedirs(args.output_dir, exist_ok=True)

    results = []
    for subject, dataset in tqdm(datasets.items(), desc="Processing subjects"):
        for idx, example in enumerate(tqdm(dataset, desc=f"Processing {subject}")):
            question = process_question(example)

            images = get_images(example)
            try:
                # Get prediction
                prediction = inference(
                    model,
                    processor,
                    question,
                    images,
                    args.max_tokens,
                    args.temperature,
                    args.resize_shape,
                    args.verbose,
                )
            except Exception as e:
                print("Error during inference:", question, images, "error message:", e)
                prediction = ""

            # Store result
            result = {
                "id": example.get("id", idx),
                "question": question,
                "answer": example.get("answer", ""),
                "subfield": example.get("subfield", "Unknown"),
                "topic_difficulty": example.get("topic_difficulty", "Unknown"),
                "question_type": example.get("question_type", "Unknown"),
                "prediction": prediction,
                "subject": example.get("subject", None) or subject,
            }
            results.append(result)

            # Show progress
            if (idx + 1) % 10 == 0 or idx < 5:
                logging.info(
                    f"Sample {idx + 1}: Answer={result['answer']}, Prediction={prediction[:50]}..."
                )

    # Print first few results
    print("\nFirst 5 results:")
    for i, result in enumerate(results[:5]):
        print(
            f"{i+1}. Question: {result['question'][:50]}... | Answer: {result['answer']} | Prediction: {result['prediction'][:50]}..."
        )

    # Save results to CSV
    with open(result_file, "w", newline="", encoding="utf-8") as f:
        if results:
            writer = csv.DictWriter(f, fieldnames=results[0].keys())
            writer.writeheader()
            writer.writerows(results)

    logging.info(f"\033[32mSaved results to {result_file}\033[0m")

    # Evaluate results
    MMMU_eval(results, result_file)

    logging.info("\033[32mEvaluation complete\033[0m")


if __name__ == "__main__":
    main()
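
For reference, a minimal sketch of driving the scorer directly after installing the wheel. The import path follows the file layout above; the two prediction rows and the output path are hypothetical:

```python
import os

from mlx_vlm.evals.mmmu import MMMU_eval

# Hypothetical rows in the shape the script writes to its predictions CSV.
rows = [
    # Multiple choice: "B" is recovered by the priority-10 "answer is" pattern.
    {"id": 0, "prediction": "The correct answer is B.", "answer": "B", "subject": "Math"},
    # Open-ended numeric: "3.14" matches by substring (and within 0.01 numerically).
    {"id": 1, "prediction": "I get roughly 3.14 here.", "answer": "3.14", "subject": "Physics"},
]

os.makedirs("results/mmmu", exist_ok=True)  # MMMU_eval rewrites the CSV in place
MMMU_eval(rows, "results/mmmu/demo_MMMU_all_validation_predictions.csv")
# Prints the per-subject breakdown and writes ..._score.json next to the CSV.
```

The full pipeline (dataset loading, generation, then this scoring pass) is what `python -m mlx_vlm.evals.mmmu --model <model> --subset Math --max-samples 10` runs end to end, since the module guards `main()` under `__name__ == "__main__"`.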