fount-vlm-nell-02 0.3.11 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
mlx_vlm/evals/ocrbench.py
ADDED
@@ -0,0 +1,453 @@

import argparse
import csv
import json
import logging
import random
import traceback
from pathlib import Path
from typing import Optional

import mlx.core as mx
from datasets import load_dataset
from tqdm import tqdm

from mlx_vlm import load
from mlx_vlm.evals.utils import inference
from mlx_vlm.generate import batch_generate
from mlx_vlm.sample_utils import top_p_sampling


def process_question(sample: dict) -> str:
    """Format the question."""
    return sample["question"]


def normalize_answer(response: str, problem: dict) -> Optional[str]:
    """Normalize the model's response to extract the answer."""
    if not response:
        return None
    return response.strip()


def evaluate_answer(prediction: Optional[str], ground_truth: list) -> bool:
    """Check if any ground truth answer is contained in the prediction."""
    if prediction is None:
        return False
    pred = prediction.strip().lower()
    return any(str(a).strip().lower() in pred for a in ground_truth)


def OCRBench_val(results_list, args, model_name, dataset="OCRBench"):
    correct = 0
    total = len(results_list)
    category_scores = {}
    for row in results_list:
        ground_truth = row["ground_truth"]
        if isinstance(ground_truth, str):
            ground_truth = [a.strip() for a in ground_truth.split(";")]
        prediction = row["prediction"]

        is_correct = evaluate_answer(prediction, ground_truth)
        row["correct"] = is_correct
        if is_correct:
            correct += 1
        category = row["type"]
        if category not in category_scores:
            category_scores[category] = {"correct": 0, "total": 0}
        category_scores[category]["total"] += 1
        if is_correct:
            category_scores[category]["correct"] += 1

    accuracy = correct / total if total > 0 else 0

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    results_file = output_dir / f"{model_name}_{dataset}_{args.split}.csv"

    fieldnames = [
        "id",
        "question",
        "dataset",
        "type",
        "ground_truth",
        "response",
        "prediction",
        "correct",
    ]

    with open(results_file, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for row in results_list:
            out_row = row.copy()
            if isinstance(out_row["ground_truth"], list):
                out_row["ground_truth"] = "; ".join(map(str, out_row["ground_truth"]))
            writer.writerow(out_row)

    summary = {
        "model": model_name,
        "dataset": args.dataset,
        "split": args.split,
        "total_samples": total,
        "correct": correct,
        "accuracy": accuracy,
        "category_scores": category_scores,
    }

    summary_file = output_dir / f"{model_name}_{dataset}_{args.split}.json"
    with open(summary_file, "w") as f:
        json.dump(summary, f, indent=2)

    print(f"\n{'='*80}")
    print(f"{dataset} Evaluation Results")
    print(f"{'='*80}")
    print(f"Model: {summary['model']}")
    print(f"Split: {args.split}")
    print(f"Total Samples: {total}")
    print(f"Correct: {correct}")
    print(f"Accuracy: {accuracy*100:.2f}%")

    if len(category_scores.items()) > 1:
        print("\n" + "-" * 80)
        print(f"Subcategory Scores:")
        print(f"{'-'*80}")
        for category, scores in category_scores.items():
            cat_total = scores["total"]
            cat_correct = scores["correct"]
            cat_accuracy = cat_correct / cat_total if cat_total > 0 else 0
            print(f" {category}: {cat_correct}/{cat_total} ({cat_accuracy*100:.2f}%)")
    print(f"{'='*80}")
    print(f"\nResults saved to {results_file} and {summary_file}")


def parse_args():
    parser = argparse.ArgumentParser(
        description="Evaluate models on OCRBench benchmark"
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="The path to the MLX VLM model",
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Optional path for the trained adapter weights and config",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="echo840/OCRBench",
        help="Hugging Face dataset name",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        choices=["test"],
        help="Dataset split to evaluate on",
    )
    parser.add_argument(
        "--streaming",
        action="store_true",
        help="Use streaming dataset loading",
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=None,
        help="Maximum number of samples to evaluate (for debugging)",
    )
    parser.add_argument(
        "--predictions-file",
        type=str,
        default=None,
        help="File with predictions",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="results/ocrbench",
        help="Directory to save results",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=512,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Temperature for generation",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Print detailed output for debugging",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1,
        help="Batch size for generation (1 = sequential, >1 = batch generation)",
    )
    return parser.parse_args()


def create_sampler(temperature: float, top_p: float = 1.0):
    """Create a sampler function for batch generation.

    For accuracy consistency across batch sizes, we use deterministic sampling
    (temperature=0) by default. This ensures the same outputs regardless of batch size.
    """

    def sampler(logits: mx.array) -> mx.array:
        if temperature == 0:
            return mx.argmax(logits, axis=-1)
        else:
            if top_p > 0 and top_p < 1.0:
                return top_p_sampling(logits, top_p, temperature)
            else:
                return mx.random.categorical(logits * (1 / temperature))

    return sampler


def process_batch(
    model,
    processor,
    batch_samples,
    args,
):
    """Process a batch of samples using batch_generate.

    batch_generate now handles image size sorting internally to minimize
    padding effects and maintain accuracy.
    """
    prompts = []
    images = []
    sample_metadata = []

    for sample in batch_samples:
        pid = sample.get("id", str(sample.get("_idx", 0)))

        # Load and process image
        if "image" in sample and sample["image"]:
            image = sample["image"].convert("RGB")
        else:
            logging.warning(f"No image for sample {pid}, skipping")
            continue

        images.append(image)

        # Create prompt
        prompt = process_question(sample)
        prompts.append(prompt)

        # Store metadata for results
        sample_metadata.append(
            {
                "id": pid,
                "question": sample["question"],
                "dataset": sample.get("dataset", ""),
                "type": sample.get("type", ""),
                "ground_truth": (
                    sample.get("answers", [])
                    if hasattr(sample, "answers")
                    else sample.get("answer", [])
                ),
            }
        )

    if not prompts:
        return []

    # Create sampler for deterministic output (temperature=0 by default)
    sampler = create_sampler(args.temperature)

    # Use batch_generate for processing
    # batch_generate now handles image size sorting internally to avoid padding issues
    batch_response = batch_generate(
        model,
        processor,
        images=images,
        prompts=prompts,
        max_tokens=args.max_tokens,
        sampler=sampler,
        verbose=args.verbose,
    )

    # Process results
    results = []
    for text, metadata in zip(batch_response.texts, sample_metadata):
        response = text.strip()
        prediction = normalize_answer(response, {"question": metadata["question"]})

        result = {
            **metadata,
            "response": response,
            "prediction": prediction,
            "correct": False,
        }
        results.append(result)

        if args.verbose:
            logging.info(f"\nSample {metadata['id']}:")
            logging.info(f"Question: {metadata['question']}")
            logging.info(f"Response: {response}")
            logging.info(f"Prediction: {prediction}")
            logging.info(f"Ground Truth: {metadata['ground_truth']}")

    return results


def main():
    args = parse_args()

    random.seed(args.seed)

    # Setup logging
    logging.basicConfig(
        level=logging.INFO if args.verbose else logging.WARNING,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )

    if args.predictions_file:
        logging.info(
            f"\033[32mLoading predictions from {args.predictions_file} for evaluation\033[0m"
        )
        with open(args.predictions_file, "r", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            loaded_results = list(reader)
        model_name = Path(args.predictions_file).stem.split("_OCRBench")[0]
        dataset = (
            "OCRBench-v2" if "OCRBench-v2" in args.predictions_file else "OCRBench"
        )
        OCRBench_val(loaded_results, args, model_name, dataset)
        logging.info(f"\033[32mEvaluation complete\033[0m")
        return

    logging.info(f"Loading model from {args.model}")
    model, processor = load(
        args.model, adapter_path=args.adapter_path, trust_remote_code=True
    )

    # Load dataset
    logging.info(f"Loading dataset {args.dataset}, split {args.split}")
    dataset = load_dataset(args.dataset, split=args.split, streaming=args.streaming)

    if args.max_samples:
        dataset = dataset.take(args.max_samples)

    # Convert to list for batching if streaming
    if args.streaming:
        dataset = list(dataset)

    results = {}
    batch_size = args.batch_size

    if batch_size > 1:
        # Batch generation mode
        logging.info(f"Using batch generation with batch_size={batch_size}")

        # Collect samples into batches
        batch = []
        all_samples = list(dataset) if hasattr(dataset, "__iter__") else dataset

        # Add index to samples for tracking
        for idx, sample in enumerate(all_samples):
            sample["_idx"] = idx

        for idx, sample in enumerate(
            tqdm(all_samples, desc=f"Evaluating (batch_size={batch_size})")
        ):
            batch.append(sample)

            # Process batch when full or at the end
            if len(batch) >= batch_size or idx == len(all_samples) - 1:
                try:
                    batch_results = process_batch(model, processor, batch, args)
                    for result in batch_results:
                        results[result["id"]] = result
                except Exception as e:
                    logging.error(f"Error processing batch: {e}")
                    traceback.print_exc()

                batch = []

                # Clear memory after each batch
                mx.clear_cache()

    else:
        # Sequential generation mode (original behavior)
        for idx, sample in enumerate(tqdm(dataset, desc="Evaluating")):
            pid = sample.get("id", str(idx))

            try:
                # Load and process image
                if "image" in sample and sample["image"]:
                    image = sample["image"].convert("RGB")
                else:
                    logging.warning(f"No image for sample {pid}, skipping")
                    continue

                # Create prompt
                prompt = process_question(sample)

                # Generate response
                output = inference(
                    model,
                    processor,
                    prompt,
                    image=image,
                    max_tokens=args.max_tokens,
                    temperature=args.temperature,
                )

                response = output.strip()

                # Normalize answer
                prediction = normalize_answer(response, sample)

                # Store results (evaluation happens later)
                results[pid] = {
                    "id": pid,
                    "question": sample["question"],
                    "dataset": sample.get("dataset", ""),
                    "type": sample.get("type", ""),
                    "ground_truth": (
                        sample.get("answers", [])
                        if hasattr(sample, "answers")
                        else sample.get("answer", [])
                    ),
                    "response": response,
                    "prediction": prediction,
                    "correct": False,
                }

                if args.verbose:
                    logging.info(f"\nSample {pid}:")
                    logging.info(f"Question: {sample['question']}")
                    logging.info(f"Response: {response}")
                    logging.info(f"Prediction: {prediction}")
                    logging.info(f"Ground Truth: {sample.get('answers', [])}")

            except Exception as e:
                traceback.print_exc()
                logging.error(f"Error processing sample {pid}: {e}")
                continue

    results_list = list(results.values())
    model_name = args.model.split("/")[-1]
    dataset = args.dataset.split("/")[-1]
    OCRBench_val(results_list, args, model_name, dataset)


if __name__ == "__main__":
    main()
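The accuracy metric in this file is plain case-insensitive substring containment: a response is scored correct if any ground-truth string appears anywhere in it. A minimal sketch of how that behaves, assuming the module is importable as `mlx_vlm.evals.ocrbench` (the path shown in the file list above); the receipt strings are made-up examples:

```python
from mlx_vlm.evals.ocrbench import evaluate_answer

# Correct: the ground-truth string "42.50" occurs inside the response.
assert evaluate_answer("The total shown on the receipt is $42.50.", ["42.50"])

# Incorrect: no ground-truth string is contained in the response.
assert not evaluate_answer("I cannot read the text.", ["42.50"])

# Empty responses are normalized to None and always score as incorrect.
assert not evaluate_answer(None, ["42.50"])
```

Because matching is containment-based, long or verbose answers are not penalized as long as the expected string shows up somewhere in the output.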
mlx_vlm/evals/utils.py
ADDED
@@ -0,0 +1,37 @@

from mlx_vlm import generate
from mlx_vlm.prompt_utils import apply_chat_template


def inference(
    model,
    processor,
    question,
    image,
    max_tokens=3000,
    temperature=0.0,
    resize_shape=None,
    verbose=False,
):
    """Run inference on a single question."""
    if image is None:
        num_images = 0
    elif isinstance(image, list):
        num_images = len(image)
    else:
        num_images = 1

    prompt = apply_chat_template(
        processor, model.config, question, num_images=num_images
    )

    response = generate(
        model,
        processor,
        prompt,
        image=image,
        max_tokens=max_tokens,
        temperature=temperature,
        resize_shape=resize_shape,
        verbose=verbose,
    )
    return response.text
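For context, a minimal sketch of calling this helper directly; the checkpoint id and image filename below are placeholders, not part of the package:

```python
from PIL import Image

from mlx_vlm import load
from mlx_vlm.evals.utils import inference

# Placeholder checkpoint and image; any MLX-VLM-compatible model directory
# or Hugging Face repo id should work with load().
model, processor = load("mlx-community/Qwen2-VL-2B-Instruct-4bit")
image = Image.open("receipt.png").convert("RGB")

# inference() wraps apply_chat_template + generate and returns the decoded text.
answer = inference(
    model,
    processor,
    "What is the total amount on the receipt?",
    image=image,
    max_tokens=128,
)
print(answer)
```

This is the same helper the ocrbench evaluation above uses for its sequential (batch_size=1) path.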