fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
mlx_vlm/evals/math_vista.py

@@ -0,0 +1,565 @@

import argparse
import csv
import json
import logging
import os
import random
import re
from pathlib import Path
from typing import Optional

from datasets import load_dataset
from PIL import Image
from tqdm import tqdm

from mlx_vlm import load
from mlx_vlm.evals.utils import inference


def process_question(sample: dict) -> str:
    """Format the question with choices if it's multiple choice."""
    question = sample["query"]

    if sample["question_type"] == "multi_choice" and sample["choices"]:
        choices_text = "\n".join(
            [f"({chr(65+i)}) {choice}" for i, choice in enumerate(sample["choices"])]
        )
        question = f"{question}\n{choices_text}"
    return question

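As a quick illustration, a minimal sketch of what process_question yields for a hypothetical multi-choice sample; the field values are made up, only the dict shape follows the MathVista schema used above:

from mlx_vlm.evals.math_vista import process_question

sample = {
    "query": "Which shape is largest?",  # hypothetical sample fields
    "question_type": "multi_choice",
    "choices": ["square", "circle"],
}
# Choices are lettered (A), (B), ... via chr(65 + i) and appended below the query.
print(process_question(sample))
# Which shape is largest?
# (A) square
# (B) circle
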
def normalize_answer(response: str, problem: dict) -> Optional[str]:
    """Normalize the model's response to extract the answer."""
    response = response.strip()

    if not response:
        return None

    question_type = problem["question_type"]
    answer_type = problem["answer_type"]
    choices = problem.get("choices", [])

    # For multiple choice, try to extract the letter
    if question_type == "multi_choice":
        # First, try to find boxed answers
        boxed_match = re.search(r"\\boxed\{([^}]+)\}", response)
        if boxed_match:
            boxed_content = boxed_match.group(1)
            # Check if it's a choice letter
            letter_match = re.match(
                r"^\(?([A-Z])\)?\.?$", boxed_content.strip().upper()
            )
            if letter_match:
                letter = letter_match.group(1)
                idx = ord(letter) - ord("A")
                if 0 <= idx < len(choices):
                    return choices[idx]
            # Check if it's directly one of the choices
            if boxed_content.strip() in choices:
                return boxed_content.strip()

        # Try to find Chinese answer pattern "故选:X" or "故选X"
        chinese_match = re.search(r"故选[::]\s*([A-Z])", response.upper())
        if not chinese_match:
            chinese_match = re.search(r"故选\s*([A-Z])", response.upper())
        if chinese_match:
            letter = chinese_match.group(1)
            idx = ord(letter) - ord("A")
            if 0 <= idx < len(choices):
                return choices[idx]

        # Try to find "the answer is X" or "answer: X" patterns near the end
        answer_patterns = [
            r"(?:the\s+)?answer\s+is\s+\(?([A-Z])\)?",
            r"answer:\s*\(?([A-Z])\)?",
            r"choose\s+\(?([A-Z])\)?",
            r"option\s+\(?([A-Z])\)?",
        ]

        # Search from the end of the response (last 500 chars)
        end_section = response[-500:] if len(response) > 500 else response
        for pattern in answer_patterns:
            matches = list(re.finditer(pattern, end_section, re.IGNORECASE))
            if matches:
                # Take the last match
                letter = matches[-1].group(1).upper()
                idx = ord(letter) - ord("A")
                if 0 <= idx < len(choices):
                    return choices[idx]

        # Look for patterns like "(A)", "A)", "A.", "A" - prioritize from the end
        matches = list(re.finditer(r"\(?([A-Z])\)?\.?", response.upper()))
        if matches:
            # Try the last few matches first
            for match in reversed(matches[-5:]):
                letter = match.group(1)
                idx = ord(letter) - ord("A")
                if 0 <= idx < len(choices):
                    return choices[idx]

        # If the response is exactly one of the choices
        if response in choices:
            return response

        # Try to find the most similar choice using edit distance
        def distance(s1, s2):
            if len(s1) < len(s2):
                return distance(s2, s1)
            if len(s2) == 0:
                return len(s1)

            previous_row = range(len(s2) + 1)
            for i, c1 in enumerate(s1):
                current_row = [i + 1]
                for j, c2 in enumerate(s2):
                    insertions = previous_row[j + 1] + 1
                    deletions = current_row[j] + 1
                    substitutions = previous_row[j] + (c1 != c2)
                    current_row.append(min(insertions, deletions, substitutions))
                previous_row = current_row

            return previous_row[-1]

        if choices:
            distances = [
                distance(response.lower(), choice.lower()) for choice in choices
            ]
            return choices[distances.index(min(distances))]

    # For integer answers
    elif answer_type == "integer":
        # First try to find boxed answer
        boxed_match = re.search(r"\\boxed\{([^}]+)\}", response)
        if boxed_match:
            boxed_content = boxed_match.group(1)
            # Remove commas from numbers
            boxed_content = boxed_content.replace(",", "")
            # Try scientific notation first
            sci_numbers = re.findall(r"-?\d+\.?\d*[eE][+-]?\d+", boxed_content)
            if sci_numbers:
                try:
                    return str(int(float(sci_numbers[0])))
                except:
                    pass
            # Then regular numbers
            numbers = re.findall(r"-?\d+", boxed_content)
            if numbers:
                try:
                    return str(int(numbers[0]))
                except:
                    pass

        # Try common answer patterns near the end
        end_section = response[-500:] if len(response) > 500 else response
        answer_patterns = [
            r"(?:the\s+)?answer\s+is\s+(-?[\d,]+\.?\d*[eE][+-]?\d+|-?[\d,]+)",
            r"answer:\s*(-?[\d,]+\.?\d*[eE][+-]?\d+|-?[\d,]+)",
            r"(?:total|result|left|remaining)(?:\s+is|\s+are|:)\s*(-?[\d,]+\.?\d*[eE][+-]?\d+|-?[\d,]+)",
        ]

        for pattern in answer_patterns:
            matches = list(re.finditer(pattern, end_section, re.IGNORECASE))
            if matches:
                try:
                    # Remove commas before converting
                    num_str = matches[-1].group(1).replace(",", "")
                    return str(int(float(num_str)))
                except:
                    pass

        # Look for scientific notation anywhere in response
        sci_numbers = re.findall(r"-?\d+\.?\d*[eE][+-]?\d+", response)
        if sci_numbers:
            try:
                return str(int(float(sci_numbers[-1])))
            except:
                pass

        # Fall back to finding all numbers (including comma-formatted) and taking the last one
        # Match numbers with optional commas: 7,518 or 7518
        numbers = re.findall(r"-?[\d,]+", response)
        if numbers:
            try:
                # Remove commas and try the last number first
                return str(int(numbers[-1].replace(",", "")))
            except:
                pass

    # For float answers
    elif answer_type == "float":
        precision = int(problem.get("precision", 2))

        # First try to find boxed answer
        boxed_match = re.search(r"\\boxed\{([^}]+)\}", response)
        if boxed_match:
            boxed_content = boxed_match.group(1)
            # Try scientific notation first
            sci_numbers = re.findall(r"-?\d+\.?\d*[eE][+-]?\d+", boxed_content)
            if sci_numbers:
                try:
                    return str(round(float(sci_numbers[0]), precision))
                except:
                    pass
            # Then regular numbers
            numbers = re.findall(r"-?\d+\.?\d*", boxed_content)
            if numbers:
                try:
                    return str(round(float(numbers[0]), precision))
                except:
                    pass

        # Try common answer patterns near the end
        end_section = response[-500:] if len(response) > 500 else response
        answer_patterns = [
            r"(?:the\s+)?answer\s+is\s+(-?\d+\.?\d*[eE][+-]?\d+|-?\d+\.?\d*)",
            r"answer:\s*(-?\d+\.?\d*[eE][+-]?\d+|-?\d+\.?\d*)",
            r"d\s*=\s*(-?\d+\.?\d*[eE][+-]?\d+|-?\d+\.?\d*)",  # For physics problems with d=
        ]

        for pattern in answer_patterns:
            matches = list(re.finditer(pattern, end_section, re.IGNORECASE))
            if matches:
                try:
                    return str(round(float(matches[-1].group(1)), precision))
                except:
                    pass

        # Look for scientific notation anywhere in response
        sci_numbers = re.findall(r"-?\d+\.?\d*[eE][+-]?\d+", response)
        if sci_numbers:
            try:
                return str(round(float(sci_numbers[-1]), precision))
            except:
                pass

        # Fall back to finding all numbers and taking the last one
        numbers = re.findall(r"-?\d+\.?\d*", response)
        if numbers:
            try:
                # Try the last number first
                return str(round(float(numbers[-1]), precision))
            except:
                pass

    return response

def evaluate_answer(prediction: Optional[str], ground_truth: str) -> bool:
    """Check if the prediction matches the ground truth."""
    if prediction is None:
        return False
    try:
        # First check exact match
        if str(prediction).strip() == str(ground_truth).strip():
            return True

        # Handle numeric word representations
        word_to_num = {
            "zero": "0",
            "one": "1",
            "two": "2",
            "three": "3",
            "four": "4",
            "five": "5",
            "six": "6",
            "seven": "7",
            "eight": "8",
            "nine": "9",
            "ten": "10",
            "eleven": "11",
            "twelve": "12",
            "thirteen": "13",
            "fourteen": "14",
            "fifteen": "15",
            "sixteen": "16",
            "seventeen": "17",
            "eighteen": "18",
            "nineteen": "19",
            "twenty": "20",
        }

        pred_normalized = str(prediction).strip().lower()
        gt_normalized = str(ground_truth).strip().lower()

        # Convert words to numbers
        if pred_normalized in word_to_num:
            pred_normalized = word_to_num[pred_normalized]
        if gt_normalized in word_to_num:
            gt_normalized = word_to_num[gt_normalized]

        return pred_normalized == gt_normalized
    except:
        return False

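A minimal sanity-check sketch for the two helpers above, assuming the module is importable as packaged; the sample dicts are hypothetical and exercise only the multi-choice and integer paths:

from mlx_vlm.evals.math_vista import evaluate_answer, normalize_answer

mc_problem = {
    "question_type": "multi_choice",  # hypothetical MathVista-shaped sample
    "answer_type": "text",
    "choices": ["red", "green", "blue"],
}
# Boxed letters are resolved first: "(B)" maps to index 1 of choices.
pred = normalize_answer("Reasoning... The answer is \\boxed{(B)}.", mc_problem)
print(pred)                            # green
print(evaluate_answer(pred, "green"))  # True

int_problem = {"question_type": "free_form", "answer_type": "integer"}
# The "total is ..." pattern fires and commas are stripped before int().
print(normalize_answer("So the total is 7,518 apples.", int_problem))  # 7518
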
def parse_args():
    parser = argparse.ArgumentParser(
        description="Evaluate models on MathVista benchmark"
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="The path to the MLX VLM model",
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Optional path for the trained adapter weights and config",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="AI4Math/MathVista",
        help="Hugging Face dataset name",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="testmini",
        choices=["testmini", "test"],
        help="Dataset split to evaluate on",
    )
    parser.add_argument(
        "--streaming",
        action="store_true",
        help="Use streaming dataset loading",
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=None,
        help="Maximum number of samples to evaluate (for debugging)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="results/mathvista",
        help="Directory to save results",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=512,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Temperature for generation",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Print detailed output for debugging",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    return parser.parse_args()

def main():
    args = parse_args()

    random.seed(args.seed)

    # Setup logging
    logging.basicConfig(
        level=logging.INFO if args.verbose else logging.WARNING,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )

    logging.info(f"Loading model from {args.model}")
    model, processor = load(
        args.model, adapter_path=args.adapter_path, trust_remote_code=True
    )

    # Load dataset
    logging.info(f"Loading dataset {args.dataset}, split {args.split}")
    dataset = load_dataset(args.dataset, split=args.split, streaming=args.streaming)

    if args.max_samples:
        dataset = dataset.select(range(min(args.max_samples, len(dataset))))

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    results = {}
    category_scores = {}
    correct = 0
    total = 0

    # Evaluate each sample
    for idx, sample in enumerate(tqdm(dataset, desc="Evaluating")):
        pid = sample["pid"]

        try:
            # Load and process image
            if "decoded_image" in sample and sample["decoded_image"]:
                if isinstance(sample["decoded_image"], str):
                    image_path = sample["decoded_image"]
                    if os.path.exists(image_path):
                        image = Image.open(image_path).convert("RGB")
                    else:
                        logging.warning(
                            f"Image not found: {image_path}, skipping sample {pid}"
                        )
                        continue
                else:
                    # Image is already loaded
                    image = sample["decoded_image"].convert("RGB")
            else:
                logging.warning(f"No image for sample {pid}, skipping")
                continue

            # Create prompt
            prompt = process_question(sample)

            # Generate response
            output = inference(
                model,
                processor,
                prompt,
                image=image,
                max_tokens=args.max_tokens,
                temperature=args.temperature,
            )

            response = output.strip()

            # Normalize answer
            prediction = normalize_answer(response, sample)

            # Evaluate
            ground_truth = sample.get("answer", "")
            if args.split == "testmini" and ground_truth:
                is_correct = evaluate_answer(prediction, ground_truth)
                if is_correct:
                    correct += 1
            else:
                is_correct = None

            total += 1

            # Store results
            results[pid] = {
                "pid": pid,
                "question": sample["question"],
                "query": sample["query"],
                "question_type": sample["question_type"],
                "answer_type": sample["answer_type"],
                "choices": sample.get("choices", []),
                "unit": sample.get("unit", ""),
                "precision": sample.get("precision", 0),
                "ground_truth": ground_truth,
                "response": response,
                "prediction": prediction,
                "correct": is_correct,
                "metadata": sample.get("metadata", {}),
            }
            # Track category-wise performance
            category = sample.get("metadata", {}).get("category", "unknown")
            if category not in category_scores:
                category_scores[category] = {"correct": 0, "total": 0}

            category_scores[category]["total"] += 1
            if is_correct:
                category_scores[category]["correct"] += 1

            if args.verbose:
                logging.info(f"\nSample {pid}:")
                logging.info(f"Question: {sample['question']}")
                logging.info(f"Response: {response}")
                logging.info(f"Prediction: {prediction}")
                logging.info(f"Ground Truth: {ground_truth}")
                logging.info(f"Correct: {is_correct}")

        except Exception as e:
            logging.error(f"Error processing sample {pid}: {e}")
            continue

    # Calculate accuracy if applicable
    if args.split == "testmini":
        accuracy = correct / total if total > 0 else 0
    else:
        accuracy = None
        correct = None

    # Save results
    model_name = args.model.split("/")[-1]
    results_file = output_dir / f"{model_name}_MathVista_{args.split}.csv"

    # Convert results to list of dictionaries for CSV writing
    fieldnames = [
        "pid",
        "question",
        "query",
        "question_type",
        "answer_type",
        "choices",
        "unit",
        "precision",
        "ground_truth",
        "response",
        "prediction",
        "correct",
        "metadata",
    ]

    with open(results_file, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()

        for result in results.values():
            # Convert list and dict fields to strings for CSV
            row = result.copy()
            if isinstance(row.get("choices"), list):
                row["choices"] = "; ".join(row["choices"])
            if isinstance(row.get("metadata"), dict):
                row["metadata"] = json.dumps(row["metadata"])
            writer.writerow(row)

    # Save summary
    summary = {
        "model": args.model,
        "dataset": args.dataset,
        "split": args.split,
        "total_samples": total,
        "category_scores": category_scores,
    }

    if accuracy is not None:
        summary["correct"] = correct
        summary["accuracy"] = accuracy

    summary_file = output_dir / f"{model_name}_MathVista_{args.split}.json"
    with open(summary_file, "w") as f:
        json.dump(summary, f, indent=2)

    print(f"\n{'='*80}")
    print("MathVista Evaluation Results")
    print(f"{'='*80}")
    print(f"Model: {args.model}")
    print(f"Split: {args.split}")
    print(f"Total Samples: {total}")
    if accuracy is not None:
        print(f"Correct: {correct}")
        print(f"Accuracy: {accuracy*100:.2f}%")
    else:
        print("Accuracy not computed for this split (no ground truth labels)")

    print("\n" + "-" * 80)
    print(f"Subcategory Scores:")
    print(f"{'-'*80}")
    for category, scores in category_scores.items():
        cat_total = scores["total"]
        cat_correct = scores["correct"]
        cat_accuracy = cat_correct / cat_total if cat_total > 0 else 0
        print(f"  {category}: {cat_correct}/{cat_total} ({cat_accuracy*100:.2f}%)")
    print(f"{'='*80}")
    print(f"\nResults saved to {results_file} and {summary_file}")


if __name__ == "__main__":
    main()
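For completeness, a hypothetical sketch of driving the evaluation programmatically rather than from the shell; the model id below is a placeholder, and the flags simply mirror parse_args above:

import sys

from mlx_vlm.evals import math_vista

# Same effect as: python -m mlx_vlm.evals.math_vista --model <id> --split testmini
sys.argv = [
    "math_vista",
    "--model", "mlx-community/some-vlm-4bit",  # placeholder model id
    "--split", "testmini",
    "--max-samples", "8",
    "--verbose",
]
math_vista.main()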