fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
mlx_vlm/evals/mmstar.py
ADDED
@@ -0,0 +1,343 @@
+import argparse
+import csv
+import logging
+import os
+import random
+import re
+from copy import deepcopy
+from json import dump
+
+from datasets import load_dataset
+from tqdm import tqdm
+
+from mlx_vlm import load
+from mlx_vlm.evals.utils import inference
+
+
+def extract_answer(predict, answer):
+    """
+    Extracts the answer from the model's predictions.
+    predict: Model prediction text
+    answer: Ground truth answer (A, B, C, or D)
+    Returns: bool: True if answer matches, False otherwise
+    """
+    text = predict.lower().replace("\n", " ").strip()
+    answer_lower = answer.lower()
+
+    general_templates = [
+        r"^{0}\b",
+        r"^\({0}",
+        r"^option {0}\b",
+        r"\b{0}\s*[:\.\)]",
+        r"(?:^|\.|\s)\s*{0}\.",
+        r"\({0}\)",
+        r"option\s+{0}\b",
+        r"choice\s+{0}\b",
+    ]
+
+    concluding_templates = [
+        r"^the answer is {0}\b",
+        r"answer:\s*{0}\b",
+        r"answer\s+is\s+{0}\b",
+        r"correct\s+(?:answer|option|choice)\s+is:?\s+{0}\b",
+        r"the\s+answer\s+is\s+{0}\b",
+        r"is\s+{0}\s*:",
+        r"(?:therefore|thus|hence)[,\s]+(?:the\s+)?(?:answer\s+is\s+)?{0}\b",
+        r"(?:select|choose)\s+{0}\b",
+        r"it\s+is\s+{0}\b",
+        r"would\s+be\s+{0}\b",
+        r"\*\*(?:revised\s+)?answer\*\*:\s*{0}\b",
+        r"(?:correct\s+)?category\s+(?:for\s+this\s+image\s+)?is\s+\*\*{0}[:\s]",
+    ]
+
+    possible_answers = ["a", "b", "c", "d", "e"]
+    matches = []
+
+    for ans in possible_answers:
+        for pri, template_list in [(2, concluding_templates), (1, general_templates)]:
+            for template in template_list:
+                pattern = template.format(ans)
+                for match in re.finditer(pattern, text):
+                    matches.append((match.end(), ans, pri))
+
+    if not matches:
+        return False
+
+    # Sort ascending by (-priority, -end_position) to prefer higher priority first, then latest position
+    matches.sort(key=lambda m: (-m[2], -m[0]))
+    latest_ans = matches[0][1]
+
+    return latest_ans == answer_lower
+
+
+def MMStar_eval(data: list, eval_file: str):
+    MMStar_score_l2 = {
+        "coarse perception": {
+            "image scene and topic": 0,
+            "image style & quality": 0,
+            "image emotion": 0,
+        },
+        "fine-grained perception": {
+            "object counting": 0,
+            "recognition": 0,
+            "localization": 0,
+        },
+        "instance reasoning": {
+            "single-instance reasoning": 0,
+            "cross-instance attribute reasoning": 0,
+            "cross-instance relation reasoning": 0,
+        },
+        "logical reasoning": {
+            "code & sequence reasoning": 0,
+            "diagram reasoning": 0,
+            "common reasoning": 0,
+        },
+        "science & technology": {
+            "biology & chemistry & physics": 0,
+            "electronics & energy & mechanical eng.": 0,
+            "geography & earth science & agriculture": 0,
+        },
+        "math": {
+            "geometry": 0,
+            "numeric commonsense and calculation": 0,
+            "statistical reasoning": 0,
+        },
+    }
+    MMStar_counter = deepcopy(MMStar_score_l2)
+
+    for line in tqdm(data, desc="Evaluating"):
+        predict = str(line["prediction"])
+        answers = str(line["answer"])
+        category = str(line["category"])
+        l2_category = str(line["l2_category"])
+
+        MMStar_counter[category][l2_category] += 1
+
+        # Use comprehensive extraction
+        if extract_answer(predict, answers):
+            MMStar_score_l2[category][l2_category] += 1
+
+            line["score"] = 1
+        else:
+            line["score"] = 0
+
+    # Calculate scores
+    MMStar_score = {}
+    MMStar_score["final score"] = 0
+    total_correct = 0
+
+    for k, v in MMStar_score_l2.items():
+        cat_total = sum(MMStar_counter[k].values())
+        cat_correct = 0
+        for l2_k, l2_v in v.items():
+            count = MMStar_counter[k][l2_k]
+            if count > 0:
+                MMStar_score[f"{k}({l2_k})"] = float(l2_v) / float(count)
+            else:
+                MMStar_score[f"{k}({l2_k})"] = 0.0
+            cat_correct += l2_v
+            total_correct += l2_v
+        MMStar_score[k] = float(cat_correct) / cat_total if cat_total > 0 else 0.0
+        MMStar_score["final score"] += cat_correct
+
+    if len(data) > 0:
+        MMStar_score["final score"] = float(MMStar_score["final score"]) / float(
+            len(data)
+        )
+
+    # Print results
+    print("\n" + "=" * 80)
+    print("MMStar Evaluation Results")
+    print("=" * 80)
+    print(
+        f"\nFinal Score: {total_correct}/{len(data)} = {MMStar_score['final score']*100:.2f}%\n"
+    )
+
+    print("-" * 80)
+    print("Category Scores:")
+    print("-" * 80)
+    for category in [
+        "coarse perception",
+        "fine-grained perception",
+        "instance reasoning",
+        "logical reasoning",
+        "science & technology",
+        "math",
+    ]:
+        if category in MMStar_score:
+            cat_total = sum(MMStar_counter[category].values())
+            cat_correct = sum(MMStar_score_l2[category].values())
+            print(
+                f"{category:30s}: {cat_correct:4d}/{cat_total:4d} = {MMStar_score[category]*100:6.2f}%"
+            )
+
+    print("\n" + "-" * 80)
+    print("Subcategory Scores:")
+    print("-" * 80)
+    for category in [
+        "coarse perception",
+        "fine-grained perception",
+        "instance reasoning",
+        "logical reasoning",
+        "science & technology",
+        "math",
+    ]:
+        print(f"\n{category.upper()}:")
+        for l2_cat, score in MMStar_score_l2[category].items():
+            count = MMStar_counter[category][l2_cat]
+            pct = (score / count * 100) if count > 0 else 0
+            print(f"  {l2_cat:55s}: {score:4d}/{count:4d} = {pct:6.2f}%")
+
+    print("\n" + "=" * 80)
+
+    # Save scores
+    score_pth = eval_file.replace(".csv", "_score.json")
+    with open(score_pth, "w") as f:
+        dump(MMStar_score, f, indent=2)
+
+    with open(eval_file, "w", newline="", encoding="utf-8") as f:
+        if data:
+            writer = csv.DictWriter(f, fieldnames=data[0].keys())
+            writer.writeheader()
+            writer.writerows(data)
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser(description="MMStar Evaluation")
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="mlx-community/Qwen2-VL-2B-Instruct-bf16",
+        help="Model path",
+    )
+    parser.add_argument("--adapter-path", type=str, default=None, help="Adapter path")
+    parser.add_argument(
+        "--dataset", type=str, default="Lin-Chen/MMStar", help="Dataset path"
+    )
+    parser.add_argument(
+        "--split", type=str, default="val", help="Split to use for evaluation"
+    )
+    parser.add_argument(
+        "--streaming", action="store_true", help="Use streaming dataset loading"
+    )
+    parser.add_argument(
+        "--max-samples",
+        type=int,
+        default=None,
+        help="Maximum number of samples to evaluate (for debugging)",
+    )
+    parser.add_argument(
+        "--max-tokens",
+        type=int,
+        default=3000,
+        help="Maximum number of tokens to generate",
+    )
+    parser.add_argument(
+        "--temperature", type=float, default=0.7, help="Temperature for sampling"
+    )
+    parser.add_argument(
+        "--resize-shape",
+        type=int,
+        nargs=2,
+        default=None,
+        help="Resize shape for the image",
+    )
+    parser.add_argument("--verbose", action="store_true", help="Verbose output")
+    parser.add_argument(
+        "--prediction-file", type=str, default=None, help="Path to the prediction file"
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default="results/mmstar",
+        help="Directory to save evaluation results",
+    )
+    parser.add_argument("--seed", type=int, default=42, help="Random seed")
+
+    return parser.parse_args()
+
+
+def main():
+    args = parse_arguments()
+
+    random.seed(args.seed)
+
+    # Setup logging
+    logging.basicConfig(
+        level=logging.INFO if args.verbose else logging.WARNING,
+        format="%(asctime)s - %(levelname)s - %(message)s",
+    )
+
+    logging.info("\033[32mStarting MMStar Evaluation\033[0m")
+    if args.prediction_file:
+        logging.info(
+            f"\033[32mLoading predictions from {args.prediction_file} for evaluation\033[0m"
+        )
+        results = []
+        with open(args.prediction_file, "r", encoding="utf-8") as f:
+            reader = csv.DictReader(f)
+            results = [row for row in reader]
+        MMStar_eval(results, args.prediction_file)
+        logging.info(f"\033[32mEvaluation complete\033[0m")
+        return
+    logging.info(f"\033[32mLoading dataset from {args.dataset}\033[0m")
+    dataset = load_dataset(args.dataset, split=args.split, streaming=args.streaming)
+    if args.max_samples:
+        dataset = dataset.take(args.max_samples)
+
+    logging.info(f"\033[32mLoading model from {args.model}\033[0m")
+    model, processor = load(
+        args.model, adapter_path=args.adapter_path, trust_remote_code=True
+    )
+    config = model.config
+    logging.info(f"\033[32mConfig: {config}\033[0m")
+
+    result_file = f'{args.output_dir}/{args.model.split("/")[-1]}_{args.dataset.split("/")[-1]}_{args.split}_predictions.csv'
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    results = []
+    for example in tqdm(dataset, desc="Running inference"):
+        question = example["question"]
+        image = example["image"].convert("RGB")
+        prediction = inference(
+            model,
+            processor,
+            question,
+            image,
+            args.max_tokens,
+            args.temperature,
+            args.resize_shape,
+            args.verbose,
+        )
+
+        results.append(
+            {
+                "question": question,
+                "answer": example["answer"],
+                "category": example["category"],
+                "l2_category": example["l2_category"],
+                "meta_info": example["meta_info"],
+                "prediction": prediction,
+            }
+        )
+
+    print("\nFirst 5 results:")
+    for i, result in enumerate(results[:5]):
+        print(
+            f"{i+1}. Question: {result['question'][:50]}... | Answer: {result['answer']} | Prediction: {result['prediction'][:50]}..."
+        )
+
+    with open(result_file, "w", newline="", encoding="utf-8") as f:
+        if results:
+            writer = csv.DictWriter(f, fieldnames=results[0].keys())
+            writer.writeheader()
+            writer.writerows(results)
+
+    MMStar_eval(results, result_file)
+
+    logging.info(f"\033[32mSaving results to {result_file}\033[0m")
+    logging.info(f"\033[32mEvaluation complete\033[0m")
+
+
+if __name__ == "__main__":
+    main()
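
For reviewers who want to spot-check the answer-extraction heuristic in the added module, here is a minimal sketch. It assumes the wheel is installed so that mlx_vlm.evals.mmstar is importable; the sample prediction strings are hypothetical, not taken from the package.

    from mlx_vlm.evals.mmstar import extract_answer

    # Concluding templates ("the answer is ...") carry priority 2 and outrank
    # the general templates (priority 1), so the leading "A" loses to "C".
    assert extract_answer("A and B look plausible, but the answer is C.", "C")

    # A parenthesized option letter at the start is caught by the general templates.
    assert extract_answer("(B) The chart shows a downward trend.", "B")

    # No recognizable option letter at all, so the prediction scores as incorrect.
    assert not extract_answer("I cannot tell from the image.", "D")

The module is also runnable directly (python -m mlx_vlm.evals.mmstar, via the argparse entry point above); per the code, it writes per-category scores to a _score.json file alongside the predictions CSV.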