evalscope 0.7.2__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff reflects the changes between publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Potentially problematic release.
This version of evalscope might be problematic.
- evalscope/__init__.py +1 -1
- evalscope/arguments.py +73 -0
- evalscope/backend/base.py +5 -1
- evalscope/backend/opencompass/api_meta_template.py +8 -14
- evalscope/backend/opencompass/backend_manager.py +24 -15
- evalscope/backend/opencompass/tasks/eval_api.py +1 -6
- evalscope/backend/opencompass/tasks/eval_datasets.py +26 -28
- evalscope/backend/rag_eval/__init__.py +3 -3
- evalscope/backend/rag_eval/backend_manager.py +21 -25
- evalscope/backend/rag_eval/clip_benchmark/__init__.py +1 -1
- evalscope/backend/rag_eval/clip_benchmark/arguments.py +6 -6
- evalscope/backend/rag_eval/clip_benchmark/dataset_builder.py +62 -79
- evalscope/backend/rag_eval/clip_benchmark/task_template.py +29 -43
- evalscope/backend/rag_eval/clip_benchmark/tasks/image_caption.py +20 -22
- evalscope/backend/rag_eval/clip_benchmark/tasks/zeroshot_classification.py +16 -23
- evalscope/backend/rag_eval/clip_benchmark/tasks/zeroshot_retrieval.py +14 -35
- evalscope/backend/rag_eval/clip_benchmark/utils/webdataset_convert.py +69 -90
- evalscope/backend/rag_eval/cmteb/__init__.py +3 -3
- evalscope/backend/rag_eval/cmteb/arguments.py +25 -27
- evalscope/backend/rag_eval/cmteb/base.py +22 -23
- evalscope/backend/rag_eval/cmteb/task_template.py +15 -17
- evalscope/backend/rag_eval/cmteb/tasks/Classification.py +98 -79
- evalscope/backend/rag_eval/cmteb/tasks/Clustering.py +17 -22
- evalscope/backend/rag_eval/cmteb/tasks/CustomTask.py +17 -19
- evalscope/backend/rag_eval/cmteb/tasks/PairClassification.py +35 -29
- evalscope/backend/rag_eval/cmteb/tasks/Reranking.py +18 -5
- evalscope/backend/rag_eval/cmteb/tasks/Retrieval.py +163 -163
- evalscope/backend/rag_eval/cmteb/tasks/STS.py +126 -104
- evalscope/backend/rag_eval/cmteb/tasks/__init__.py +33 -34
- evalscope/backend/rag_eval/ragas/__init__.py +2 -2
- evalscope/backend/rag_eval/ragas/arguments.py +3 -8
- evalscope/backend/rag_eval/ragas/prompts/chinese/AnswerCorrectness/correctness_prompt_chinese.json +9 -9
- evalscope/backend/rag_eval/ragas/prompts/chinese/AnswerCorrectness/long_form_answer_prompt_chinese.json +2 -2
- evalscope/backend/rag_eval/ragas/prompts/chinese/AnswerRelevancy/question_generation_chinese.json +3 -3
- evalscope/backend/rag_eval/ragas/prompts/chinese/ContextPrecision/context_precision_prompt_chinese.json +5 -5
- evalscope/backend/rag_eval/ragas/prompts/chinese/CustomNodeFilter/scoring_prompt_chinese.json +7 -0
- evalscope/backend/rag_eval/ragas/prompts/chinese/Faithfulness/nli_statements_message_chinese.json +8 -8
- evalscope/backend/rag_eval/ragas/prompts/chinese/Faithfulness/statement_prompt_chinese.json +5 -5
- evalscope/backend/rag_eval/ragas/prompts/chinese/HeadlinesExtractor/prompt_chinese.json +7 -5
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopAbstractQuerySynthesizer/concept_combination_prompt_chinese.json +2 -2
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopAbstractQuerySynthesizer/generate_query_reference_prompt_chinese.json +27 -4
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopAbstractQuerySynthesizer/theme_persona_matching_prompt_chinese.json +2 -2
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopSpecificQuerySynthesizer/generate_query_reference_prompt_chinese.json +27 -4
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopSpecificQuerySynthesizer/theme_persona_matching_prompt_chinese.json +2 -2
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiModalFaithfulness/faithfulness_prompt_chinese.json +2 -2
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiModalRelevance/relevance_prompt_chinese.json +5 -5
- evalscope/backend/rag_eval/ragas/prompts/chinese/NERExtractor/prompt_chinese.json +3 -3
- evalscope/backend/rag_eval/ragas/prompts/chinese/SingleHopSpecificQuerySynthesizer/generate_query_reference_prompt_chinese.json +21 -4
- evalscope/backend/rag_eval/ragas/prompts/chinese/SingleHopSpecificQuerySynthesizer/theme_persona_matching_prompt_chinese.json +3 -3
- evalscope/backend/rag_eval/ragas/prompts/chinese/SummaryExtractor/prompt_chinese.json +4 -4
- evalscope/backend/rag_eval/ragas/prompts/chinese/ThemesExtractor/prompt_chinese.json +2 -2
- evalscope/backend/rag_eval/ragas/prompts/persona_prompt.py +0 -1
- evalscope/backend/rag_eval/ragas/task_template.py +10 -15
- evalscope/backend/rag_eval/ragas/tasks/__init__.py +1 -1
- evalscope/backend/rag_eval/ragas/tasks/build_distribution.py +45 -0
- evalscope/backend/rag_eval/ragas/tasks/build_transform.py +135 -0
- evalscope/backend/rag_eval/ragas/tasks/testset_generation.py +17 -133
- evalscope/backend/rag_eval/ragas/tasks/translate_prompt.py +8 -18
- evalscope/backend/rag_eval/utils/clip.py +46 -50
- evalscope/backend/rag_eval/utils/embedding.py +12 -11
- evalscope/backend/rag_eval/utils/llm.py +8 -6
- evalscope/backend/rag_eval/utils/tools.py +12 -11
- evalscope/backend/vlm_eval_kit/__init__.py +1 -1
- evalscope/backend/vlm_eval_kit/custom_dataset.py +7 -8
- evalscope/benchmarks/arc/__init__.py +3 -2
- evalscope/benchmarks/arc/ai2_arc.py +19 -16
- evalscope/benchmarks/arc/arc_adapter.py +32 -24
- evalscope/benchmarks/bbh/__init__.py +1 -2
- evalscope/benchmarks/bbh/bbh_adapter.py +28 -25
- evalscope/benchmarks/bbh/cot_prompts/boolean_expressions.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/causal_judgement.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/date_understanding.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/disambiguation_qa.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/dyck_languages.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/formal_fallacies.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/geometric_shapes.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/hyperbaton.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/logical_deduction_five_objects.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/logical_deduction_seven_objects.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/logical_deduction_three_objects.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/movie_recommendation.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/multistep_arithmetic_two.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/navigate.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/object_counting.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/penguins_in_a_table.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/reasoning_about_colored_objects.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/ruin_names.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/salient_translation_error_detection.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/snarks.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/sports_understanding.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/temporal_sequences.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/tracking_shuffled_objects_five_objects.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/tracking_shuffled_objects_seven_objects.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/tracking_shuffled_objects_three_objects.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/web_of_lies.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/word_sorting.txt +1 -1
- evalscope/benchmarks/benchmark.py +16 -16
- evalscope/benchmarks/ceval/__init__.py +3 -2
- evalscope/benchmarks/ceval/ceval_adapter.py +80 -69
- evalscope/benchmarks/ceval/ceval_exam.py +18 -31
- evalscope/benchmarks/cmmlu/__init__.py +3 -2
- evalscope/benchmarks/cmmlu/cmmlu.py +87 -92
- evalscope/benchmarks/cmmlu/cmmlu_adapter.py +109 -155
- evalscope/benchmarks/cmmlu/samples.jsonl +1 -1
- evalscope/benchmarks/competition_math/__init__.py +3 -2
- evalscope/benchmarks/competition_math/competition_math.py +7 -16
- evalscope/benchmarks/competition_math/competition_math_adapter.py +32 -34
- evalscope/benchmarks/data_adapter.py +24 -24
- evalscope/benchmarks/general_qa/__init__.py +3 -2
- evalscope/benchmarks/general_qa/general_qa_adapter.py +34 -38
- evalscope/benchmarks/gsm8k/__init__.py +1 -1
- evalscope/benchmarks/gsm8k/gsm8k.py +6 -12
- evalscope/benchmarks/gsm8k/gsm8k_adapter.py +26 -24
- evalscope/benchmarks/hellaswag/__init__.py +3 -2
- evalscope/benchmarks/hellaswag/hellaswag.py +15 -19
- evalscope/benchmarks/hellaswag/hellaswag_adapter.py +27 -23
- evalscope/benchmarks/humaneval/__init__.py +1 -1
- evalscope/benchmarks/humaneval/humaneval.py +15 -18
- evalscope/benchmarks/humaneval/humaneval_adapter.py +0 -1
- evalscope/benchmarks/mmlu/__init__.py +3 -2
- evalscope/benchmarks/mmlu/mmlu.py +15 -29
- evalscope/benchmarks/mmlu/mmlu_adapter.py +85 -77
- evalscope/benchmarks/race/__init__.py +3 -2
- evalscope/benchmarks/race/race.py +21 -35
- evalscope/benchmarks/race/race_adapter.py +32 -29
- evalscope/benchmarks/race/samples.jsonl +1 -1
- evalscope/benchmarks/trivia_qa/__init__.py +3 -2
- evalscope/benchmarks/trivia_qa/samples.jsonl +1 -1
- evalscope/benchmarks/trivia_qa/trivia_qa.py +19 -34
- evalscope/benchmarks/trivia_qa/trivia_qa_adapter.py +27 -22
- evalscope/benchmarks/truthful_qa/__init__.py +3 -2
- evalscope/benchmarks/truthful_qa/truthful_qa.py +25 -29
- evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +36 -37
- evalscope/cli/cli.py +6 -5
- evalscope/cli/start_eval.py +31 -0
- evalscope/cli/start_perf.py +0 -3
- evalscope/cli/start_server.py +27 -41
- evalscope/config.py +119 -95
- evalscope/constants.py +61 -29
- evalscope/evaluator/__init__.py +1 -0
- evalscope/evaluator/evaluator.py +96 -377
- evalscope/evaluator/humaneval_evaluator.py +158 -0
- evalscope/evaluator/rating_eval.py +12 -33
- evalscope/evaluator/reviewer/auto_reviewer.py +47 -76
- evalscope/metrics/bundled_rouge_score/rouge_scorer.py +10 -20
- evalscope/metrics/code_metric.py +3 -9
- evalscope/metrics/math_accuracy.py +3 -6
- evalscope/metrics/metrics.py +21 -21
- evalscope/metrics/rouge_metric.py +11 -25
- evalscope/models/__init__.py +1 -2
- evalscope/models/api/openai_api.py +40 -29
- evalscope/models/custom/__init__.py +0 -1
- evalscope/models/custom/custom_model.py +3 -3
- evalscope/models/dummy_chat_model.py +7 -8
- evalscope/models/model_adapter.py +89 -156
- evalscope/models/openai_model.py +20 -20
- evalscope/perf/arguments.py +15 -3
- evalscope/perf/benchmark.py +7 -9
- evalscope/perf/http_client.py +3 -8
- evalscope/perf/main.py +10 -0
- evalscope/perf/plugin/api/custom_api.py +1 -2
- evalscope/perf/plugin/api/dashscope_api.py +1 -2
- evalscope/perf/plugin/api/openai_api.py +2 -3
- evalscope/perf/plugin/datasets/base.py +1 -2
- evalscope/perf/plugin/datasets/flickr8k.py +1 -2
- evalscope/perf/plugin/datasets/longalpaca.py +1 -2
- evalscope/perf/plugin/datasets/openqa.py +1 -2
- evalscope/perf/utils/analysis_result.py +1 -2
- evalscope/perf/utils/benchmark_util.py +1 -2
- evalscope/perf/utils/db_util.py +11 -8
- evalscope/perf/utils/local_server.py +19 -13
- evalscope/registry/config/cfg_arena_zhihu.yaml +1 -1
- evalscope/registry/tasks/arc.yaml +2 -3
- evalscope/registry/tasks/bbh.yaml +3 -4
- evalscope/registry/tasks/bbh_mini.yaml +3 -4
- evalscope/registry/tasks/ceval.yaml +3 -3
- evalscope/registry/tasks/ceval_mini.yaml +3 -4
- evalscope/registry/tasks/cmmlu.yaml +3 -3
- evalscope/registry/tasks/eval_qwen-7b-chat_v100.yaml +1 -1
- evalscope/registry/tasks/general_qa.yaml +1 -1
- evalscope/registry/tasks/gsm8k.yaml +2 -2
- evalscope/registry/tasks/mmlu.yaml +3 -3
- evalscope/registry/tasks/mmlu_mini.yaml +3 -3
- evalscope/run.py +184 -375
- evalscope/run_arena.py +20 -25
- evalscope/summarizer.py +16 -17
- evalscope/third_party/longbench_write/README.md +99 -42
- evalscope/third_party/longbench_write/default_task.json +1 -1
- evalscope/third_party/longbench_write/default_task.yaml +8 -7
- evalscope/third_party/longbench_write/eval.py +29 -28
- evalscope/third_party/longbench_write/infer.py +16 -104
- evalscope/third_party/longbench_write/longbench_write.py +5 -5
- evalscope/third_party/longbench_write/resources/judge.txt +1 -1
- evalscope/third_party/longbench_write/tools/data_etl.py +4 -5
- evalscope/third_party/longbench_write/utils.py +0 -1
- evalscope/third_party/toolbench_static/eval.py +14 -15
- evalscope/third_party/toolbench_static/infer.py +48 -69
- evalscope/third_party/toolbench_static/llm/swift_infer.py +4 -12
- evalscope/third_party/toolbench_static/requirements.txt +1 -1
- evalscope/third_party/toolbench_static/toolbench_static.py +3 -3
- evalscope/tools/combine_reports.py +25 -30
- evalscope/tools/rewrite_eval_results.py +14 -46
- evalscope/utils/__init__.py +0 -1
- evalscope/utils/arena_utils.py +18 -48
- evalscope/{perf/utils → utils}/chat_service.py +3 -4
- evalscope/utils/completion_parsers.py +3 -8
- evalscope/utils/logger.py +9 -7
- evalscope/utils/model_utils.py +11 -0
- evalscope/utils/utils.py +12 -138
- evalscope/version.py +2 -2
- {evalscope-0.7.2.dist-info → evalscope-0.8.0.dist-info}/METADATA +123 -118
- evalscope-0.8.0.dist-info/RECORD +285 -0
- tests/cli/test_run.py +54 -15
- tests/perf/test_perf.py +4 -0
- tests/rag/test_clip_benchmark.py +38 -38
- tests/rag/test_mteb.py +3 -2
- tests/rag/test_ragas.py +5 -5
- tests/swift/test_run_swift_eval.py +2 -3
- tests/swift/test_run_swift_vlm_eval.py +2 -3
- tests/swift/test_run_swift_vlm_jugde_eval.py +2 -3
- evalscope/backend/rag_eval/ragas/metrics/__init__.py +0 -2
- evalscope/backend/rag_eval/ragas/metrics/multi_modal_faithfulness.py +0 -91
- evalscope/backend/rag_eval/ragas/metrics/multi_modal_relevance.py +0 -99
- evalscope/cache.py +0 -98
- evalscope/models/template.py +0 -1446
- evalscope/run_ms.py +0 -140
- evalscope/utils/task_cfg_parser.py +0 -10
- evalscope/utils/task_utils.py +0 -22
- evalscope-0.7.2.dist-info/RECORD +0 -286
- {evalscope-0.7.2.dist-info → evalscope-0.8.0.dist-info}/LICENSE +0 -0
- {evalscope-0.7.2.dist-info → evalscope-0.8.0.dist-info}/WHEEL +0 -0
- {evalscope-0.7.2.dist-info → evalscope-0.8.0.dist-info}/entry_points.txt +0 -0
- {evalscope-0.7.2.dist-info → evalscope-0.8.0.dist-info}/top_level.txt +0 -0

evalscope/backend/rag_eval/clip_benchmark/dataset_builder.py

@@ -1,8 +1,9 @@
 import os
 import torch
-from torch.utils.data import DataLoader
-from
+from torch.utils.data import DataLoader
+from torch.utils.data import Dataset as TorchDataset
 
+from evalscope.utils.logger import get_logger
 
 logger = get_logger()
 
@@ -11,7 +12,7 @@ def build_dataset(
     dataset_name,
     root=None,
     transform=None,
-    split=
+    split='test',
     wds_cache_dir=None,
     **kwargs,
 ):
@@ -40,9 +41,9 @@ def build_dataset(
 
     """
 
-    if dataset_name ==
+    if dataset_name == 'dummy':
         ds = Dummy()
-    elif dataset_name ==
+    elif dataset_name == 'custom':
         ds = build_custom_dataset(dataset_name, data_dir=root, transform=transform)
     else:
         # WebDataset support using `webdataset` library
@@ -60,7 +61,7 @@ def build_dataset(
 class Dummy:
 
     def __init__(self):
-        self.classes = [
+        self.classes = ['blank image', 'noisy image']
 
     def __getitem__(self, i):
         return torch.zeros(3, 224, 224), 0
@@ -70,7 +71,8 @@ class Dummy:
 
 
 class DatasetWrapper(TorchDataset):
-
+
+    def __init__(self, dataset, transform=None, image_key='image', text_key='query'):
         self.dataset = dataset
         self.transform = transform
         self.image_key = image_key
@@ -85,7 +87,7 @@ class DatasetWrapper(TorchDataset):
         # 加载图像
         image = item[self.image_key]
         if self.transform is not None:
-            image = self.transform(image, return_tensors=
+            image = self.transform(image, return_tensors='pt')
 
         # 获取查询列表
         query = item[self.text_key]
@@ -97,24 +99,24 @@ class DatasetWrapper(TorchDataset):
 
 def get_dataset_default_task(dataset):
     if dataset in (
-
-
-
-
-
-
-
-
-
-
+            'custom',
+            'muge',
+            'flickr30k',
+            'flickr8k',
+            'mscoco_captions',
+            'mscoco_captions2017',
+            'multilingual_mscoco_captions',
+            'flickr30k-200',
+            'crossmodal3600',
+            'xtd200',
     ):
-        return
+        return 'zeroshot_retrieval'
     else:
-        return
+        return 'zeroshot_classification'
 
 
 def get_dataloader(dataset_name, dataset, batch_size, num_workers):
-    if dataset_name ==
+    if dataset_name == 'custom':
         dataloader = DataLoader(
             dataset,
             batch_size=batch_size,
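
For orientation, a minimal usage sketch of the helpers touched above, based only on the signatures visible in this diff; the dataset name, root directory and batch settings are illustrative, and in the real flow the transform comes from the loaded model (model.transform in task_template.py):

from evalscope.backend.rag_eval.clip_benchmark.dataset_builder import (build_dataset, get_dataloader,
                                                                        get_dataset_default_task)

dataset_name = 'muge'                          # illustrative; names outside the tuple above default to classification
task = get_dataset_default_task(dataset_name)  # -> 'zeroshot_retrieval' for 'muge'

# transform=None keeps the sketch minimal; task_template.py passes the model's own preprocessing here
dataset = build_dataset(dataset_name, root='data/muge', transform=None, split='test')
dataloader = get_dataloader(dataset_name, dataset, batch_size=32, num_workers=2)
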
@@ -140,26 +142,23 @@ def image_captions_collate_fn(batch):
 
 
 def build_custom_dataset(dataset_name, data_dir, transform=None):
-    from datasets import
+    from datasets import Features, Image, Sequence, Value, load_dataset
 
     qrels_ds = load_dataset(
-
-        data_files=os.path.join(data_dir,
-        features=Features(
-
-
-
+        'json',
+        data_files=os.path.join(data_dir, 'image_queries.jsonl'),
+        features=Features({
+            'image_path': Image(decode=True),
+            'query': Sequence(Value('string'))
+        }),
+        split='train',
     )
 
-    dataset = DatasetWrapper(
-        qrels_ds, transform, image_key="image_path", text_key="query"
-    )
+    dataset = DatasetWrapper(qrels_ds, transform, image_key='image_path', text_key='query')
     return dataset
 
 
-def build_wds_dataset(
-    dataset_name, transform, split="test", data_dir="root", cache_dir=None
-):
+def build_wds_dataset(dataset_name, transform, split='test', data_dir='root', cache_dir=None):
     """
     Load a dataset in WebDataset format. Either local paths or HTTP URLs can be specified.
     Expected file structure is:
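
The 'custom' branch above reads a data_dir containing an image_queries.jsonl file whose records hold an image path and a list of query strings (per the Features spec in the hunk). A hedged sketch of producing such a file; the paths and captions are invented examples:

import json
import os

records = [
    {'image_path': 'images/0001.jpg', 'query': ['a cat sitting on a sofa']},
    {'image_path': 'images/0002.jpg', 'query': ['a dog', 'a puppy running on grass']},
]
os.makedirs('data/custom', exist_ok=True)
with open('data/custom/image_queries.jsonl', 'w', encoding='utf-8') as f:
    for rec in records:
        f.write(json.dumps(rec, ensure_ascii=False) + '\n')
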
@@ -190,41 +189,39 @@ def build_wds_dataset(
     import webdataset as wds
 
     def read_txt(fname):
-        if
-            stream = os.popen("curl -L -s --fail '%s'" % fname,
+        if '://' in fname:
+            stream = os.popen("curl -L -s --fail '%s'" % fname, 'r')
             value = stream.read()
             if stream.close():
-                raise FileNotFoundError(
+                raise FileNotFoundError('Failed to retreive data')
         else:
-            with open(fname,
+            with open(fname, 'r') as file:
                 value = file.read()
         return value
 
     if not data_dir:
-        data_dir = f
+        data_dir = f'https://modelscope.cn/datasets/clip-benchmark/wds_{dataset_name}/resolve/master'
 
     # Git LFS files have a different file path to access the raw data than other files
-    if data_dir.startswith(
-        *split_url_head, _, url_path = data_dir.split(
-        url_head =
-        metadata_dir =
-        tardata_dir =
+    if data_dir.startswith('https://modelscope.cn/datasets'):
+        *split_url_head, _, url_path = data_dir.split('/', 7)
+        url_head = '/'.join(split_url_head)
+        metadata_dir = '/'.join([url_head, 'resolve', url_path])
+        tardata_dir = '/'.join([url_head, 'resolve', url_path])
     else:
         metadata_dir = tardata_dir = data_dir
     # Get number of shards
-    nshards_fname = os.path.join(metadata_dir, split,
-    nshards = int(
-
-    )  # Do not catch FileNotFound, nshards.txt should be mandatory
-
+    nshards_fname = os.path.join(metadata_dir, split, 'nshards.txt')
+    nshards = int(read_txt(nshards_fname))  # Do not catch FileNotFound, nshards.txt should be mandatory
+
     # Get dataset type (classification or retrieval)
-    type_fname = os.path.join(metadata_dir,
+    type_fname = os.path.join(metadata_dir, 'dataset_type.txt')
     try:
         dataset_type = read_txt(type_fname).strip().lower()
     except FileNotFoundError:
-        dataset_type =
-
-    filepattern = os.path.join(tardata_dir, split,
+        dataset_type = 'classification'
+
+    filepattern = os.path.join(tardata_dir, split, '{0..%d}.tar' % (nshards - 1))
     # Load webdataset (support WEBP, PNG, and JPG for now)
     if not cache_dir or not isinstance(cache_dir, str):
         cache_dir = None
@@ -236,42 +233,28 @@ def build_wds_dataset(
         nodesplitter=lambda src: src,
         shardshuffle=False,
         verbose=True,
-    ).decode(
-
-    )
-
+    ).decode(wds.autodecode.ImageHandler('pil', extensions=['webp', 'png', 'jpg', 'jpeg']))
+
     # Load based on classification or retrieval task
-    if dataset_type ==
-        dataset = dataset.to_tuple([
-            transform, str.splitlines
-        )
+    if dataset_type == 'retrieval':
+        dataset = dataset.to_tuple(['webp', 'png', 'jpg', 'jpeg'], 'txt').map_tuple(transform, str.splitlines)
         dataset.classes = dataset.templates = None
     else:
-        label_type = (
-
-        )  # Special case for multilabel
-        dataset = dataset.to_tuple(
-            ["webp", "png", "jpg", "jpeg"], label_type
-        ).map_tuple(transform, None)
+        label_type = ('npy' if dataset_type == 'multilabel' else 'cls')  # Special case for multilabel
+        dataset = dataset.to_tuple(['webp', 'png', 'jpg', 'jpeg'], label_type).map_tuple(transform, None)
     # Get class names if present
-    classnames_fname = os.path.join(metadata_dir,
+    classnames_fname = os.path.join(metadata_dir, 'classnames.txt')
     try:
-        dataset.classes = [
-            line.strip() for line in read_txt(classnames_fname).splitlines()
-        ]
+        dataset.classes = [line.strip() for line in read_txt(classnames_fname).splitlines()]
     except FileNotFoundError:
-        logger.warning(
+        logger.warning('WARNING: classnames.txt not found')
         dataset.classes = None
     # Get zeroshot classification templates if present
-    templates_fname = os.path.join(
-        metadata_dir, "zeroshot_classification_templates.txt"
-    )
+    templates_fname = os.path.join(metadata_dir, 'zeroshot_classification_templates.txt')
     try:
-        dataset.templates = [
-            line.strip() for line in read_txt(templates_fname).splitlines()
-        ]
+        dataset.templates = [line.strip() for line in read_txt(templates_fname).splitlines()]
     except FileNotFoundError:
-        logger.warning(
+        logger.warning('WARNING: zeroshot_classification_templates.txt not found')
         dataset.templates = None
 
     return dataset
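
build_wds_dataset resolves everything relative to a per-dataset root (defaulting to the ModelScope clip-benchmark mirror) and expects the metadata files read above; a small sketch that only assembles the paths it will look for, with an illustrative local root and shard count:

import os

data_dir = 'data/wds_mscoco_captions'   # or an HTTP(S) URL; empty -> ModelScope clip-benchmark mirror
split = 'test'
nshards = 8                             # illustrative; the real value is read from nshards.txt

expected = [
    os.path.join(data_dir, split, 'nshards.txt'),                     # mandatory shard count
    os.path.join(data_dir, 'dataset_type.txt'),                       # 'retrieval', otherwise classification is assumed
    os.path.join(data_dir, 'classnames.txt'),                         # optional class names
    os.path.join(data_dir, 'zeroshot_classification_templates.txt'),  # optional prompt templates
    os.path.join(data_dir, split, '{0..%d}.tar' % (nshards - 1)),     # brace pattern consumed by webdataset
]
print('\n'.join(expected))
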

evalscope/backend/rag_eval/clip_benchmark/task_template.py

@@ -1,19 +1,12 @@
+import json
 import os
 import torch
-import json
 from itertools import product
 
-from evalscope.backend.rag_eval.clip_benchmark.dataset_builder import (
-    build_dataset,
-    get_dataset_default_task,
-    get_dataloader,
-)
-from evalscope.backend.rag_eval.clip_benchmark.tasks import (
-    zeroshot_classification,
-    zeroshot_retrieval,
-    image_caption,
-)
 from evalscope.backend.rag_eval.clip_benchmark.arguments import Arguments
+from evalscope.backend.rag_eval.clip_benchmark.dataset_builder import (build_dataset, get_dataloader,
+                                                                       get_dataset_default_task)
+from evalscope.backend.rag_eval.clip_benchmark.tasks import image_caption, zeroshot_classification, zeroshot_retrieval
 from evalscope.backend.rag_eval.utils.clip import VisionModel
 from evalscope.utils.logger import get_logger
 
@@ -37,21 +30,21 @@ def evaluate(args: Arguments):
     # Iterate over model and dataset combinations
     for model_cfg, dataset_name in product(models, dataset_names):
         task = input_task or get_dataset_default_task(dataset_name)
-        model_name = os.path.basename(model_cfg[
+        model_name = os.path.basename(model_cfg['model_name'])
 
         output_path = os.path.join(output_dir, model_name)
         os.makedirs(output_path, exist_ok=True)
-        output_file = os.path.join(output_path, f
+        output_file = os.path.join(output_path, f'{dataset_name}_{task}.json')
 
         # Skip evaluation if the result already exists and skip_existing is True
         if os.path.exists(output_file) and skip_existing:
             if verbose:
-                logger.info(f
+                logger.info(f'Skip {output_dir}, exists already.')
             return
 
         # Determine device (CPU or GPU)
-        device =
-        model_cfg[
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        model_cfg['device'] = device
         # Initialize the model
         model = VisionModel.load(**model_cfg)
 
@@ -61,23 +54,20 @@ def evaluate(args: Arguments):
             root=data_dir,
             transform=model.transform,
             split=split,
-            wds_cache_dir=f
+            wds_cache_dir=f'{cache_dir}/{dataset_name}',
         )
 
         # Create the dataloader
         dataloader = get_dataloader(dataset_name, dataset, batch_size, num_workers)
 
         # Evaluate based on the task
-        if task ==
-            zeroshot_templates = (
-                dataset.templates if hasattr(dataset, "templates") else None
-            )
+        if task == 'zeroshot_classification':
+            zeroshot_templates = (dataset.templates if hasattr(dataset, 'templates') else None)
             if verbose:
-                logger.info(f
-            classnames = dataset.classes if hasattr(dataset,
-            assert (
-
-            ), "Dataset does not support classification"
+                logger.info(f'Zero-shot templates: {zeroshot_templates}')
+            classnames = dataset.classes if hasattr(dataset, 'classes') else None
+            assert (zeroshot_templates is not None
+                    and classnames is not None), 'Dataset does not support classification'
             metrics = zeroshot_classification.evaluate(
                 model,
                 dataloader,
@@ -87,33 +77,29 @@ def evaluate(args: Arguments):
                 verbose=verbose,
                 limit=limit,
             )
-        elif task ==
-            metrics = zeroshot_retrieval.evaluate(
-
-            )
-
-            output_path = os.path.join(output_path, dataset_name, "retrieval_data")
-            metrics = image_caption.evaluate(
-                model, dataloader, limit=limit, output_path=output_path
-            )
+        elif task == 'zeroshot_retrieval':
+            metrics = zeroshot_retrieval.evaluate(model, dataloader, recall_k_list=[5], device=device, limit=limit)
+        elif task == 'image_caption':
+            output_path = os.path.join(output_path, dataset_name, 'retrieval_data')
+            metrics = image_caption.evaluate(model, dataloader, limit=limit, output_path=output_path)
 
         # Prepare dump data
         dump = {
-
-
-
-
+            'dataset': dataset_name,
+            'model': model_name,
+            'task': task,
+            'metrics': metrics,
        }
 
         if verbose:
-            logger.info(f
+            logger.info(f'Evaluation results: {dump}')
 
         # Write the results to output file
         if verbose:
-            logger.info(f
-        with open(output_file,
+            logger.info(f'Dump results to: {output_file}')
+        with open(output_file, 'w') as f:
             json.dump(dump, f)
 
 
-if __name__ ==
+if __name__ == '__main__':
     evaluate()
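
Each (model, dataset) run now writes one JSON file named '{dataset_name}_{task}.json' under the model's output directory, carrying the keys of the dump dict above; a sketch of reading it back, with an invented path:

import json

# hypothetical location: <output_dir>/<model_name>/<dataset_name>_<task>.json
with open('outputs/my-clip-model/muge_zeroshot_retrieval.json') as f:
    result = json.load(f)

print(result['dataset'], result['model'], result['task'])
print(result['metrics'])
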

evalscope/backend/rag_eval/clip_benchmark/tasks/image_caption.py

@@ -1,14 +1,14 @@
-from tqdm import tqdm
-import pandas as pd
 import os
-
+import pandas as pd
+from tqdm import tqdm
 
+from evalscope.backend.rag_eval.utils.tools import save_to_jsonl, save_to_tsv
 from evalscope.utils.logger import get_logger
 
 logger = get_logger()
 
 
-def evaluate(model, dataloader, limit=None, output_path=
+def evaluate(model, dataloader, limit=None, output_path=''):
     """
     Evaluate the model on the dataset
     Parameters
@@ -31,9 +31,7 @@ def evaluate(model, dataloader, limit=None, output_path=""):
         captions = model.encode_image(batch_images)
         querys = [text for texts in batch_texts for text in texts]
 
-        batch_texts_image_index = [
-            ind for ind, texts in zip(inds, batch_texts) for text in texts
-        ]
+        batch_texts_image_index = [ind for ind, texts in zip(inds, batch_texts) for text in texts]
 
         total_captions.extend(captions)
         total_querys.extend(querys)
@@ -47,30 +45,30 @@ def evaluate(model, dataloader, limit=None, output_path=""):
             break
 
     write_file(total_querys, total_captions, query_caption_index, output_path)
-    return {
+    return {'convertion_successful': True, 'save_path': output_path}
 
 
 def write_file(query_list, corpus_list, qrels_list, output_path):
     # 处理 query_list
-    query_df = pd.DataFrame(query_list, columns=[
-    query_df[
-    query_df = query_df[[
-    save_to_jsonl(query_df, os.path.join(output_path,
+    query_df = pd.DataFrame(query_list, columns=['text'])
+    query_df['_id'] = query_df.index
+    query_df = query_df[['_id', 'text']]
+    save_to_jsonl(query_df, os.path.join(output_path, 'queries.jsonl'))
 
     # 处理 corpus_list
-    corpus_df = pd.DataFrame(corpus_list, columns=[
-    corpus_df[
-    corpus_df = corpus_df[[
-    save_to_jsonl(corpus_df, os.path.join(output_path,
+    corpus_df = pd.DataFrame(corpus_list, columns=['text'])
+    corpus_df['_id'] = corpus_df.index
+    corpus_df = corpus_df[['_id', 'text']]
+    save_to_jsonl(corpus_df, os.path.join(output_path, 'corpus.jsonl'))
 
     # 处理 qrels_list
-    qrels_df = pd.DataFrame(qrels_list, columns=[
-    qrels_df[
-    qrels_df[
-    qrels_df = qrels_df[[
-    save_to_tsv(qrels_df, os.path.join(output_path,
+    qrels_df = pd.DataFrame(qrels_list, columns=['corpus-id'])
+    qrels_df['query-id'] = qrels_df.index
+    qrels_df['score'] = 1
+    qrels_df = qrels_df[['query-id', 'corpus-id', 'score']]
+    save_to_tsv(qrels_df, os.path.join(output_path, 'qrels', 'test.tsv'))
 
-    logger.info(
+    logger.info('Write files to {}'.format(output_path))
     return
 
 
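
write_file now emits a small retrieval bundle under output_path: queries.jsonl and corpus.jsonl with _id/text columns, plus qrels/test.tsv with query-id, corpus-id and score. A hedged sketch of loading it with pandas, assuming the save_to_jsonl/save_to_tsv helpers write plain JSON Lines and a tab-separated file with a header row; the directory is illustrative:

import os
import pandas as pd

out = 'outputs/my-clip-model/muge/retrieval_data'                        # illustrative path
queries = pd.read_json(os.path.join(out, 'queries.jsonl'), lines=True)   # columns: _id, text
corpus = pd.read_json(os.path.join(out, 'corpus.jsonl'), lines=True)     # columns: _id, text
qrels = pd.read_csv(os.path.join(out, 'qrels', 'test.tsv'), sep='\t')    # columns: query-id, corpus-id, score
print(len(queries), len(corpus), len(qrels))
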

evalscope/backend/rag_eval/clip_benchmark/tasks/zeroshot_classification.py

@@ -4,14 +4,12 @@ Thanks to the authors of OpenCLIP
 """
 
 import logging
-from contextlib import suppress
-
 import torch
 import torch.nn.functional as F
+from contextlib import suppress
+from sklearn.metrics import balanced_accuracy_score, classification_report
 from tqdm import tqdm
 
-from sklearn.metrics import classification_report, balanced_accuracy_score
-
 from evalscope.utils.logger import get_logger
 
 logger = get_logger()
@@ -49,7 +47,7 @@ def zero_shot_classifier(model, classnames, templates, device, amp=True):
             # generic prompts tht are specialized for each class by replacing {c} with the class name
             texts = [template.format(c=classname) for template in templates]
         else:
-            raise ValueError(
+            raise ValueError('templates must be a list or a dict')
         class_embedding = model.encode_text(texts).mean(dim=0)
         class_embedding = F.normalize(class_embedding, dim=0)
         zeroshot_weights.append(class_embedding)
@@ -57,7 +55,7 @@ def zero_shot_classifier(model, classnames, templates, device, amp=True):
     return zeroshot_weights
 
 
-def accuracy(output, target, topk=(1,)):
+def accuracy(output, target, topk=(1, )):
     """
     Compute top-k accuracy
 
@@ -79,10 +77,7 @@ def accuracy(output, target, topk=(1,)):
     pred = output.topk(max(topk), 1, True, True)[1].t()
     correct = pred.eq(target.view(1, -1).expand_as(pred))
     n = len(target)
-    return [
-        float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) / n
-        for k in topk
-    ]
+    return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) / n for k in topk]
 
 
 def run_classification(model, classifier, dataloader, device, amp=True, limit=None):
@@ -115,7 +110,7 @@ def run_classification(model, classifier, dataloader, device, amp=True, limit=None):
             # predict
             image_features = model.encode_image(images)
             logits = 100.0 * image_features @ classifier
-
+
             if limit is not None:
                 # Update sample counter
                 sample_count += len(images)
@@ -217,15 +212,13 @@ def evaluate(
 
     if is_multilabel:
         if verbose:
-            logger.info(
+            logger.info('Detected a multi-label classification dataset')
         # Multiple labels per image, multiple classes on the dataset
         ap_per_class = average_precision_per_class(logits, target)
         if verbose:
-            for class_name, ap in zip(
-
-
-                logger.info(f"Class: {class_name}, AveragePrecision: {ap}")
-        return {"mean_average_precision": ap_per_class.mean().item()}
+            for class_name, ap in zip(dataloader.dataset.classes, ap_per_class.tolist()):
+                logger.info(f'Class: {class_name}, AveragePrecision: {ap}')
+        return {'mean_average_precision': ap_per_class.mean().item()}
     else:
         # Single label per image, multiple classes on the dataset
         # just compute accuracy and mean_per_class_recall
@@ -235,13 +228,13 @@ def evaluate(
         if len(dataloader.dataset.classes) >= 5:
             acc1, acc5 = accuracy(logits, target, topk=(1, 5))
         else:
-            (acc1,) = accuracy(logits, target, topk=(1,))
-            acc5 = float(
+            (acc1, ) = accuracy(logits, target, topk=(1, ))
+            acc5 = float('nan')
         mean_per_class_recall = balanced_accuracy_score(target, pred)
         if verbose:
-            logger.info(
+            logger.info('\n' + classification_report(target, pred, digits=3))
         return {
-
-
-
+            'acc1': acc1,
+            'acc5': acc5,
+            'mean_per_class_recall': mean_per_class_recall,
         }
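
The reflowed accuracy helper is the usual top-k computation; a self-contained sketch of the same idea on dummy tensors (shapes and values are illustrative, not taken from the package):

import torch


def topk_accuracy(logits, target, topk=(1, 5)):
    # logits: (n_samples, n_classes); target: (n_samples,)
    pred = logits.topk(max(topk), dim=1, largest=True, sorted=True).indices.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [correct[:k].reshape(-1).float().sum().item() / len(target) for k in topk]


logits = torch.randn(8, 10)              # batch of 8 samples over 10 classes
target = torch.randint(0, 10, (8, ))
acc1, acc5 = topk_accuracy(logits, target)
print(acc1, acc5)
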

evalscope/backend/rag_eval/clip_benchmark/tasks/zeroshot_retrieval.py

@@ -1,9 +1,9 @@
 import logging
-from contextlib import suppress
-
 import torch
 import torch.nn.functional as F
+from contextlib import suppress
 from tqdm import tqdm
+
 from evalscope.utils.logger import get_logger
 
 logger = get_logger()
@@ -51,9 +51,7 @@ def evaluate(model, dataloader, device, amp=True, recall_k_list=[5], limit=None)
     for batch_images, batch_texts, inds in tqdm(dataloader):
 
         # store the index of image for each text
-        batch_texts_image_index = [
-            ind for ind, texts in zip(inds, batch_texts) for text in texts
-        ]
+        batch_texts_image_index = [ind for ind, texts in zip(inds, batch_texts) for text in texts]
 
         # compute the embedding of images and texts
         batch_images_emb = model.encode_image(batch_images)
@@ -93,33 +91,16 @@ def evaluate(model, dataloader, device, amp=True, recall_k_list=[5], limit=None)
     # so we can easily compute that using the actual recall, by checking whether there is at least one true positive,
     # which would be the case if the recall is greater than 0. One we compute the recal for each image (or text), we average
     # it over the dataset.
-        metrics[f
-        (
-
-
-
-
-
-
-
-
-        )
-        metrics[f"text_retrieval_recall@{recall_k}"] = (
-            (
-                batchify(
-                    recall_at_k,
-                    scores.T,
-                    positive_pairs.T,
-                    batch_size,
-                    device,
-                    k=recall_k,
-                )
-                > 0
-            )
-            .float()
-            .mean()
-            .item()
-        )
+        metrics[f'image_retrieval_recall@{recall_k}'] = ((batchify(
+            recall_at_k, scores, positive_pairs, batch_size, device, k=recall_k) > 0).float().mean().item())
+        metrics[f'text_retrieval_recall@{recall_k}'] = ((batchify(
+            recall_at_k,
+            scores.T,
+            positive_pairs.T,
+            batch_size,
+            device,
+            k=recall_k,
+        ) > 0).float().mean().item())
 
     return metrics
 
@@ -147,9 +128,7 @@ def recall_at_k(scores, positive_pairs, k):
     # compute number of positives for each text
     nb_positive = positive_pairs.sum(dim=1)
    # nb_texts, k, nb_images
-    topk_indices_onehot = torch.nn.functional.one_hot(
-        topk_indices, num_classes=nb_images
-    )
+    topk_indices_onehot = torch.nn.functional.one_hot(topk_indices, num_classes=nb_images)
     # compute number of true positives
     positive_pairs_reshaped = positive_pairs.view(nb_texts, 1, nb_images)
|