evalscope 0.5.5rc0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (49)
  1. evalscope/backend/__init__.py +0 -3
  2. evalscope/backend/opencompass/tasks/eval_datasets.py +1 -1
  3. evalscope/backend/rag_eval/__init__.py +4 -0
  4. evalscope/backend/rag_eval/backend_manager.py +80 -0
  5. evalscope/backend/rag_eval/clip_benchmark/__init__.py +2 -0
  6. evalscope/backend/rag_eval/clip_benchmark/arguments.py +34 -0
  7. evalscope/backend/rag_eval/clip_benchmark/dataset_builder.py +277 -0
  8. evalscope/backend/rag_eval/clip_benchmark/task_template.py +119 -0
  9. evalscope/backend/rag_eval/clip_benchmark/tasks/__init__.py +0 -0
  10. evalscope/backend/rag_eval/clip_benchmark/tasks/image_caption.py +83 -0
  11. evalscope/backend/rag_eval/clip_benchmark/tasks/zeroshot_classification.py +247 -0
  12. evalscope/backend/rag_eval/clip_benchmark/tasks/zeroshot_retrieval.py +170 -0
  13. evalscope/backend/rag_eval/cmteb/__init__.py +4 -0
  14. evalscope/backend/rag_eval/cmteb/arguments.py +61 -0
  15. evalscope/backend/rag_eval/cmteb/base.py +91 -0
  16. evalscope/backend/rag_eval/cmteb/task_template.py +85 -0
  17. evalscope/backend/rag_eval/cmteb/tasks/Classification.py +302 -0
  18. evalscope/backend/rag_eval/cmteb/tasks/Clustering.py +252 -0
  19. evalscope/backend/rag_eval/cmteb/tasks/CustomTask.py +61 -0
  20. evalscope/backend/rag_eval/cmteb/tasks/PairClassification.py +113 -0
  21. evalscope/backend/rag_eval/cmteb/tasks/Reranking.py +151 -0
  22. evalscope/backend/rag_eval/cmteb/tasks/Retrieval.py +345 -0
  23. evalscope/backend/rag_eval/cmteb/tasks/STS.py +302 -0
  24. evalscope/backend/rag_eval/cmteb/tasks/__init__.py +70 -0
  25. evalscope/backend/rag_eval/ragas/__init__.py +2 -0
  26. evalscope/backend/rag_eval/ragas/arguments.py +47 -0
  27. evalscope/backend/rag_eval/ragas/metrics/__init__.py +2 -0
  28. evalscope/backend/rag_eval/ragas/metrics/multi_modal_faithfulness.py +91 -0
  29. evalscope/backend/rag_eval/ragas/metrics/multi_modal_relevance.py +99 -0
  30. evalscope/backend/rag_eval/ragas/task_template.py +61 -0
  31. evalscope/backend/rag_eval/ragas/tasks/__init__.py +2 -0
  32. evalscope/backend/rag_eval/ragas/tasks/testset_generation.py +263 -0
  33. evalscope/backend/rag_eval/ragas/tasks/translate_prompt.py +72 -0
  34. evalscope/backend/vlm_eval_kit/backend_manager.py +0 -1
  35. evalscope/backend/vlm_eval_kit/custom_dataset.py +1 -1
  36. evalscope/evaluator/evaluator.py +1 -0
  37. evalscope/metrics/bundled_rouge_score/rouge_scorer.py +19 -0
  38. evalscope/models/api/openai_api.py +2 -2
  39. evalscope/perf/http_client.py +1 -1
  40. evalscope/perf/openai_api.py +2 -0
  41. evalscope/run.py +4 -0
  42. evalscope/utils/logger.py +44 -14
  43. evalscope/utils/task_utils.py +3 -0
  44. evalscope/version.py +2 -2
  45. {evalscope-0.5.5rc0.dist-info → evalscope-0.6.0.dist-info}/METADATA +95 -99
  46. {evalscope-0.5.5rc0.dist-info → evalscope-0.6.0.dist-info}/RECORD +49 -18
  47. {evalscope-0.5.5rc0.dist-info → evalscope-0.6.0.dist-info}/WHEEL +1 -1
  48. {evalscope-0.5.5rc0.dist-info → evalscope-0.6.0.dist-info}/entry_points.txt +0 -0
  49. {evalscope-0.5.5rc0.dist-info → evalscope-0.6.0.dist-info}/top_level.txt +0 -0
evalscope/backend/__init__.py
@@ -1,3 +0,0 @@
- # Copyright (c) Alibaba, Inc. and its affiliates.
-
- from evalscope.backend.opencompass.backend_manager import OpenCompassBackendManager
evalscope/backend/opencompass/tasks/eval_datasets.py
@@ -50,12 +50,12 @@ with read_base():
  from opencompass.configs.datasets.nq.nq_gen_c788f6 import nq_datasets
  from opencompass.configs.datasets.triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
  from opencompass.configs.datasets.cmb.cmb_gen_dfb5c4 import cmb_datasets
- from opencompass.configs.datasets.bbh.bbh_gen_5b92b0 import bbh_datasets

  # Note: to be supported
  # from opencompass.configs.datasets.flores.flores_gen_806ede import flores_datasets
  # from opencompass.configs.datasets.TheoremQA.TheoremQA_5shot_gen_6f0af8 import TheoremQA_datasets
  # from opencompass.configs.datasets.commonsenseqa.commonsenseqa_gen_c946f2 import commonsenseqa_datasets
+ # from opencompass.configs.datasets.bbh.bbh_gen_5b92b0 import bbh_datasets


  datasets = []
evalscope/backend/rag_eval/__init__.py
@@ -0,0 +1,4 @@
+ from evalscope.backend.rag_eval.utils.embedding import EmbeddingModel
+ from evalscope.backend.rag_eval.utils.llm import LLM, LocalLLM, ChatOpenAI
+ from evalscope.backend.rag_eval.utils.clip import VisionModel
+ from evalscope.backend.rag_eval.backend_manager import RAGEvalBackendManager
evalscope/backend/rag_eval/backend_manager.py
@@ -0,0 +1,80 @@
+ import os
+ from typing import Optional, Union
+ from evalscope.utils import is_module_installed, get_valid_list
+ from evalscope.backend.base import BackendManager
+ from evalscope.utils.logger import get_logger
+
+
+ logger = get_logger()
+
+
+ class RAGEvalBackendManager(BackendManager):
+     def __init__(self, config: Union[str, dict], **kwargs):
+         """BackendManager for VLM Evaluation Kit
+
+         Args:
+             config (Union[str, dict]): the configuration yaml-file or the configuration dictionary
+         """
+         super().__init__(config, **kwargs)
+
+     @staticmethod
+     def _check_env(module_name: str):
+         if is_module_installed(module_name):
+             logger.info(f"Check `{module_name}` Installed")
+         else:
+             logger.error(f"Please install `{module_name}` first")
+
+     @staticmethod
+     def run_mteb(model_args, eval_args):
+         from evalscope.backend.rag_eval.cmteb import ModelArguments, EvalArguments
+         from evalscope.backend.rag_eval.cmteb import one_stage_eval, two_stage_eval
+
+         if len(model_args) > 2:
+             raise ValueError("Not support multiple models yet")
+
+         # Convert arguments to dictionary
+         model_args_list = [ModelArguments(**args).to_dict() for args in model_args]
+         eval_args = EvalArguments(**eval_args).to_dict()
+
+         if len(model_args_list) == 1:
+             one_stage_eval(model_args_list[0], eval_args)
+         else:  # len(model_args_list) == 2
+             two_stage_eval(model_args_list[0], model_args_list[1], eval_args)
+
+     @staticmethod
+     def run_ragas(testset_args, eval_args):
+         from evalscope.backend.rag_eval.ragas import rag_eval
+         from evalscope.backend.rag_eval.ragas.tasks import generate_testset
+         from evalscope.backend.rag_eval.ragas import (
+             TestsetGenerationArguments,
+             EvaluationArguments,
+         )
+
+         if testset_args is not None:
+             generate_testset(TestsetGenerationArguments(**testset_args))
+         if eval_args is not None:
+             rag_eval(EvaluationArguments(**eval_args))
+
+     @staticmethod
+     def run_clip_benchmark(args):
+         from evalscope.backend.rag_eval.clip_benchmark import Arguments, evaluate
+
+         evaluate(Arguments(**args))
+
+     def run(self, *args, **kwargs):
+         tool = self.config_d.pop("tool")
+         if tool.lower() == "mteb":
+             self._check_env("mteb")
+             model_args = self.config_d["model"]
+             eval_args = self.config_d["eval"]
+             self.run_mteb(model_args, eval_args)
+         elif tool.lower() == "ragas":
+             self._check_env("ragas")
+             testset_args = self.config_d.get("testset_generation", None)
+             eval_args = self.config_d.get("eval", None)
+             self.run_ragas(testset_args, eval_args)
+         elif tool.lower() == "clip_benchmark":
+             self._check_env("webdataset")
+             self.run_clip_benchmark(self.config_d["eval"])
+         else:
+             raise ValueError(f"Unknown tool: {tool}")
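For orientation: `RAGEvalBackendManager.run` dispatches on a top-level `tool` key and reads the `model`/`eval` (or `testset_generation`) sub-sections from the parsed config. A minimal sketch of driving the `mteb` branch follows; the `tool`/`model`/`eval` keys and the one-vs-two-model dispatch come from the code above, while the nested field names and values are placeholders, since the CMTEB `ModelArguments`/`EvalArguments` definitions are not shown in this hunk.

```python
# Hypothetical driver for the "mteb" branch of RAGEvalBackendManager.run();
# nested field names and values are illustrative, not taken from this diff.
from evalscope.backend.rag_eval import RAGEvalBackendManager

config = {
    "tool": "mteb",                                      # or "ragas" / "clip_benchmark"
    "model": [
        {"model_name_or_path": "<embedding-model-id>"},  # one entry -> one_stage_eval, two -> two_stage_eval
    ],
    "eval": {"output_folder": "outputs"},                # placeholder EvalArguments fields
}

RAGEvalBackendManager(config=config).run()
```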
evalscope/backend/rag_eval/clip_benchmark/__init__.py
@@ -0,0 +1,2 @@
+ from evalscope.backend.rag_eval.clip_benchmark.task_template import evaluate
+ from evalscope.backend.rag_eval.clip_benchmark.arguments import Arguments
evalscope/backend/rag_eval/clip_benchmark/arguments.py
@@ -0,0 +1,34 @@
+ from dataclasses import dataclass, field
+ from typing import List, Dict
+
+
+ @dataclass
+ class Arguments:
+     # fmt: off
+     """
+     A dataclass to store and manage the arguments for the model configuration and data processing.
+     """
+     """
+     For CLIP model support, you can use the following fields:
+         model_name: str
+         revision: str = "master"
+         hub: str = "modelscope"
+
+     For API VLM model support, you can use the following fields, (image caption only):
+         model_name="gpt-4o-mini"
+         api_base: str = "",
+         api_key: Optional[str] = None
+         prompt: str = None
+     """
+     models: List[Dict] = field(default_factory=dict)  # List of paths to the pre-trained models or model identifiers
+     dataset_name: List[str] = field(default_factory=list)  # List of dataset names to be used
+     data_dir: str = None  # Root directory where the datasets are stored
+     split: str = "test"  # Split of the dataset to be used (e.g., 'train', 'validation', 'test')
+     task: str = None
+     batch_size: int = 128  # Batch size for data loading
+     num_workers: int = 1  # Number of workers for data loading
+     verbose: bool = True  # Flag to enable verbose logging
+     output_dir: str = "outputs"  # Directory where the outputs (e.g., predictions, logs) will be saved
+     cache_dir: str = "cache"  # Directory where the dataset cache will be stored
+     skip_existing: bool = False  # Flag to skip processing if outputs already exist
+     limit: int = None  # Limit the number of samples to be processed
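Read together with `evaluate` in `task_template.py` further down, the dataclass above would be populated roughly like this; the field names follow the dataclass, while the model identifier and dataset name are examples only.

```python
# Sketch of filling in Arguments for a zero-shot retrieval run; the model id
# and dataset name are placeholders, not values taken from this diff.
from evalscope.backend.rag_eval.clip_benchmark import Arguments, evaluate

args = Arguments(
    models=[{"model_name": "<clip-model-id-or-path>"}],  # CLIP fields per the docstring above
    dataset_name=["muge"],     # listed as a retrieval dataset in get_dataset_default_task below
    split="test",
    batch_size=128,
    limit=1000,                # evaluate at most 1000 samples
)
evaluate(args)
```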
evalscope/backend/rag_eval/clip_benchmark/dataset_builder.py
@@ -0,0 +1,277 @@
+ import os
+ import torch
+ from torch.utils.data import DataLoader, Dataset as TorchDataset
+ from evalscope.utils.logger import get_logger
+
+
+ logger = get_logger()
+
+
+ def build_dataset(
+     dataset_name,
+     root=None,
+     transform=None,
+     split="test",
+     wds_cache_dir=None,
+     **kwargs,
+ ):
+     """
+     Main function to use in order to build a dataset instance,
+
+     dataset_name: str
+         name of the dataset
+
+     root: str
+         root folder where the dataset is downloaded and stored. can be shared among datasets.
+
+     transform: torchvision transform applied to images
+
+     split: str
+         split to use, depending on the dataset can have different options.
+         In general, `train` and `test` are available.
+         For specific splits, please look at the corresponding dataset.
+
+     custom_classname_file: str or None
+         Custom classname file where keys are dataset names and values are list of classnames.
+
+     custom_template_file: str or None
+         Custom template file where keys are dataset names and values are list of prompts, or dicts
+         where keys are classnames and values are class-specific prompts.
+
+     """
+
+     if dataset_name == "dummy":
+         ds = Dummy()
+     elif dataset_name == "custom":
+         ds = build_custom_dataset(dataset_name, data_dir=root, transform=transform)
+     else:
+         # WebDataset support using `webdataset` library
+         ds = build_wds_dataset(
+             dataset_name,
+             transform=transform,
+             split=split,
+             data_dir=root,
+             cache_dir=wds_cache_dir,
+         )
+
+     return ds
+
+
+ class Dummy:
+
+     def __init__(self):
+         self.classes = ["blank image", "noisy image"]
+
+     def __getitem__(self, i):
+         return torch.zeros(3, 224, 224), 0
+
+     def __len__(self):
+         return 1
+
+
+ class DatasetWrapper(TorchDataset):
+     def __init__(self, dataset, transform=None, image_key="image", text_key="query"):
+         self.dataset = dataset
+         self.transform = transform
+         self.image_key = image_key
+         self.text_key = text_key
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         item = self.dataset[idx]
+
+         # Load the image
+         image = item[self.image_key]
+         if self.transform is not None:
+             image = self.transform(image, return_tensors="pt")
+
+         # Get the list of queries
+         query = item[self.text_key]
+         if isinstance(query, str):
+             query = [query]
+
+         return image, query
+
+
+ def get_dataset_default_task(dataset):
+     if dataset in (
+         "custom",
+         "muge",
+         "flickr30k",
+         "flickr8k",
+         "mscoco_captions",
+         "mscoco_captions2017",
+         "multilingual_mscoco_captions",
+         "flickr30k-200",
+         "crossmodal3600",
+         "xtd200",
+     ):
+         return "zeroshot_retrieval"
+     else:
+         return "zeroshot_classification"
+
+
+ def get_dataloader(dataset_name, dataset, batch_size, num_workers):
+     if dataset_name == "custom":
+         dataloader = DataLoader(
+             dataset,
+             batch_size=batch_size,
+             shuffle=False,
+             num_workers=num_workers,
+             collate_fn=image_captions_collate_fn,
+         )
+     else:
+         dataloader = DataLoader(
+             dataset.batched(batch_size),
+             batch_size=None,
+             shuffle=False,
+             num_workers=num_workers,
+         )
+     return dataloader
+
+
+ def image_captions_collate_fn(batch):
+     transposed = list(zip(*batch))
+     imgs = transposed[0]
+     texts = transposed[1]
+     return imgs, texts
+
+
+ def build_custom_dataset(dataset_name, data_dir, transform=None):
+     from datasets import load_dataset, Features, Image, Sequence, Value
+
+     qrels_ds = load_dataset(
+         "json",
+         data_files=os.path.join(data_dir, "image_queries.jsonl"),
+         features=Features(
+             {"image_path": Image(decode=True), "query": Sequence(Value("string"))}
+         ),
+         split="train",
+     )
+
+     dataset = DatasetWrapper(
+         qrels_ds, transform, image_key="image_path", text_key="query"
+     )
+     return dataset
+
+
+ def build_wds_dataset(
+     dataset_name, transform, split="test", data_dir="root", cache_dir=None
+ ):
+     """
+     Load a dataset in WebDataset format. Either local paths or HTTP URLs can be specified.
+     Expected file structure is:
+     ```
+     data_dir/
+         train/
+             nshards.txt
+             0.tar
+             1.tar
+             ...
+         test/
+             nshards.txt
+             0.tar
+             1.tar
+             ...
+         classnames.txt
+         zeroshot_classification_templates.txt
+         dataset_type.txt
+     ```
+     Classnames and templates are required for zeroshot classification, while dataset type
+     (equal to "retrieval") is required for zeroshot retrieval datasets.
+
+     You can use the `clip_benchmark_export_wds` or corresponding API
+     (`clip_benchmark.webdataset_builder.convert_dataset`) to convert datasets to this format.
+
+     Set `cache_dir` to a path to cache the dataset, otherwise, no caching will occur.
+     """
+     import webdataset as wds
+
+     def read_txt(fname):
+         if "://" in fname:
+             stream = os.popen("curl -L -s --fail '%s'" % fname, "r")
+             value = stream.read()
+             if stream.close():
+                 raise FileNotFoundError("Failed to retreive data")
+         else:
+             with open(fname, "r") as file:
+                 value = file.read()
+         return value
+
+     if not data_dir:
+         data_dir = f"https://modelscope.cn/datasets/clip-benchmark/wds_{dataset_name}/resolve/master"
+
+     # Git LFS files have a different file path to access the raw data than other files
+     if data_dir.startswith("https://modelscope.cn/datasets"):
+         *split_url_head, _, url_path = data_dir.split("/", 7)
+         url_head = "/".join(split_url_head)
+         metadata_dir = "/".join([url_head, "resolve", url_path])
+         tardata_dir = "/".join([url_head, "resolve", url_path])
+     else:
+         metadata_dir = tardata_dir = data_dir
+     # Get number of shards
+     nshards_fname = os.path.join(metadata_dir, split, "nshards.txt")
+     nshards = int(
+         read_txt(nshards_fname)
+     )  # Do not catch FileNotFound, nshards.txt should be mandatory
+
+     # Get dataset type (classification or retrieval)
+     type_fname = os.path.join(metadata_dir, "dataset_type.txt")
+     try:
+         dataset_type = read_txt(type_fname).strip().lower()
+     except FileNotFoundError:
+         dataset_type = "classification"
+
+     filepattern = os.path.join(tardata_dir, split, "{0..%d}.tar" % (nshards - 1))
+     # Load webdataset (support WEBP, PNG, and JPG for now)
+     if not cache_dir or not isinstance(cache_dir, str):
+         cache_dir = None
+     else:
+         os.makedirs(cache_dir, exist_ok=True)
+     dataset = wds.WebDataset(
+         filepattern,
+         cache_dir=cache_dir,
+         nodesplitter=lambda src: src,
+         shardshuffle=False,
+         verbose=True,
+     ).decode(
+         wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"])
+     )
+
+     # Load based on classification or retrieval task
+     if dataset_type == "retrieval":
+         dataset = dataset.to_tuple(["webp", "png", "jpg", "jpeg"], "txt").map_tuple(
+             transform, str.splitlines
+         )
+         dataset.classes = dataset.templates = None
+     else:
+         label_type = (
+             "npy" if dataset_type == "multilabel" else "cls"
+         )  # Special case for multilabel
+         dataset = dataset.to_tuple(
+             ["webp", "png", "jpg", "jpeg"], label_type
+         ).map_tuple(transform, None)
+         # Get class names if present
+         classnames_fname = os.path.join(metadata_dir, "classnames.txt")
+         try:
+             dataset.classes = [
+                 line.strip() for line in read_txt(classnames_fname).splitlines()
+             ]
+         except FileNotFoundError:
+             logger.warning("WARNING: classnames.txt not found")
+             dataset.classes = None
+         # Get zeroshot classification templates if present
+         templates_fname = os.path.join(
+             metadata_dir, "zeroshot_classification_templates.txt"
+         )
+         try:
+             dataset.templates = [
+                 line.strip() for line in read_txt(templates_fname).splitlines()
+             ]
+         except FileNotFoundError:
+             logger.warning("WARNING: zeroshot_classification_templates.txt not found")
+             dataset.templates = None
+
+     return dataset
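The `custom` path above reads an `image_queries.jsonl` file and wraps it in `DatasetWrapper`; a minimal sketch of driving it is below. The directory and the example record are illustrative, inferred from `build_custom_dataset` and the declared features.

```python
# Sketch: loading the "custom" image-query dataset defined above.
# Assumes <root>/image_queries.jsonl exists with records such as
#   {"image_path": "imgs/001.jpg", "query": ["a red bicycle"]}
# (record contents are illustrative).
from evalscope.backend.rag_eval.clip_benchmark.dataset_builder import (
    build_dataset,
    get_dataloader,
)

dataset = build_dataset("custom", root="./data", transform=None)
dataloader = get_dataloader("custom", dataset, batch_size=128, num_workers=1)

for images, queries in dataloader:
    # images: tuple of PIL images (no transform applied here),
    # queries: tuple of query lists, collated by image_captions_collate_fn
    break
```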
evalscope/backend/rag_eval/clip_benchmark/task_template.py
@@ -0,0 +1,119 @@
+ import os
+ import torch
+ import json
+ from itertools import product
+
+ from evalscope.backend.rag_eval.clip_benchmark.dataset_builder import (
+     build_dataset,
+     get_dataset_default_task,
+     get_dataloader,
+ )
+ from evalscope.backend.rag_eval.clip_benchmark.tasks import (
+     zeroshot_classification,
+     zeroshot_retrieval,
+     image_caption,
+ )
+ from evalscope.backend.rag_eval.clip_benchmark.arguments import Arguments
+ from evalscope.backend.rag_eval.utils.clip import VisionModel
+ from evalscope.utils.logger import get_logger
+
+ logger = get_logger()
+
+
+ def evaluate(args: Arguments):
+     models = args.models
+     dataset_names = args.dataset_name
+     data_dir = args.data_dir
+     split = args.split
+     batch_size = args.batch_size
+     num_workers = args.num_workers
+     verbose = args.verbose
+     input_task = args.task
+     output_dir = args.output_dir
+     cache_dir = args.cache_dir
+     skip_existing = args.skip_existing
+     limit = args.limit
+
+     # Iterate over model and dataset combinations
+     for model_cfg, dataset_name in product(models, dataset_names):
+         task = input_task or get_dataset_default_task(dataset_name)
+         model_name = os.path.basename(model_cfg["model_name"])
+
+         output_path = os.path.join(output_dir, model_name)
+         os.makedirs(output_path, exist_ok=True)
+         output_file = os.path.join(output_path, f"{dataset_name}_{task}.json")
+
+         # Skip evaluation if the result already exists and skip_existing is True
+         if os.path.exists(output_file) and skip_existing:
+             if verbose:
+                 logger.info(f"Skip {output_dir}, exists already.")
+             return
+
+         # Determine device (CPU or GPU)
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         model_cfg["device"] = device
+         # Initialize the model
+         model = VisionModel.load(**model_cfg)
+
+         # Build the dataset
+         dataset = build_dataset(
+             dataset_name=dataset_name,
+             root=data_dir,
+             transform=model.transform,
+             split=split,
+             wds_cache_dir=f"{cache_dir}/{dataset_name}",
+         )
+
+         # Create the dataloader
+         dataloader = get_dataloader(dataset_name, dataset, batch_size, num_workers)
+
+         # Evaluate based on the task
+         if task == "zeroshot_classification":
+             zeroshot_templates = (
+                 dataset.templates if hasattr(dataset, "templates") else None
+             )
+             if verbose:
+                 logger.info(f"Zero-shot templates: {zeroshot_templates}")
+             classnames = dataset.classes if hasattr(dataset, "classes") else None
+             assert (
+                 zeroshot_templates is not None and classnames is not None
+             ), "Dataset does not support classification"
+             metrics = zeroshot_classification.evaluate(
+                 model,
+                 dataloader,
+                 classnames,
+                 zeroshot_templates,
+                 device=device,
+                 verbose=verbose,
+                 limit=limit,
+             )
+         elif task == "zeroshot_retrieval":
+             metrics = zeroshot_retrieval.evaluate(
+                 model, dataloader, recall_k_list=[5], device=device, limit=limit
+             )
+         elif task == "image_caption":
+             output_path = os.path.join(output_path, dataset_name, "retrieval_data")
+             metrics = image_caption.evaluate(
+                 model, dataloader, limit=limit, output_path=output_path
+             )
+
+         # Prepare dump data
+         dump = {
+             "dataset": dataset_name,
+             "model": model_name,
+             "task": task,
+             "metrics": metrics,
+         }
+
+         if verbose:
+             logger.info(f"Evaluation results: {dump}")
+
+         # Write the results to output file
+         if verbose:
+             logger.info(f"Dump results to: {output_file}")
+         with open(output_file, "w") as f:
+             json.dump(dump, f)
+
+
+ if __name__ == "__main__":
+     evaluate()
evalscope/backend/rag_eval/clip_benchmark/tasks/image_caption.py
@@ -0,0 +1,83 @@
+ from tqdm import tqdm
+ import pandas as pd
+ import os
+ from evalscope.backend.rag_eval.utils.tools import save_to_jsonl, save_to_tsv
+
+ from evalscope.utils.logger import get_logger
+
+ logger = get_logger()
+
+
+ def evaluate(model, dataloader, limit=None, output_path=""):
+     """
+     Evaluate the model on the dataset
+     Parameters
+     ----------
+     model: MultiModalModel
+         model to caption the image
+     dataloader: torch.utils.data.Dataloader
+     limit: int
+         limit the number of samples to evaluate
+     Returns
+     -------
+     dict of retrieval metrics
+     """
+     sample_count = 0
+     dataloader = dataloader_with_indices(dataloader)
+     query_caption_index = []
+     total_captions = []
+     total_querys = []
+     for batch_images, batch_texts, inds in tqdm(dataloader):
+         captions = model.encode_image(batch_images)
+         querys = [text for texts in batch_texts for text in texts]
+
+         batch_texts_image_index = [
+             ind for ind, texts in zip(inds, batch_texts) for text in texts
+         ]
+
+         total_captions.extend(captions)
+         total_querys.extend(querys)
+         query_caption_index.extend(batch_texts_image_index)
+
+         if limit is not None:
+             # Update sample counter
+             sample_count += len(batch_images)
+
+             if sample_count >= limit:
+                 break
+
+     write_file(total_querys, total_captions, query_caption_index, output_path)
+     return {"convertion_successful": True, "save_path": output_path}
+
+
+ def write_file(query_list, corpus_list, qrels_list, output_path):
+     # Process query_list
+     query_df = pd.DataFrame(query_list, columns=["text"])
+     query_df["_id"] = query_df.index
+     query_df = query_df[["_id", "text"]]
+     save_to_jsonl(query_df, os.path.join(output_path, "queries.jsonl"))
+
+     # Process corpus_list
+     corpus_df = pd.DataFrame(corpus_list, columns=["text"])
+     corpus_df["_id"] = corpus_df.index
+     corpus_df = corpus_df[["_id", "text"]]
+     save_to_jsonl(corpus_df, os.path.join(output_path, "corpus.jsonl"))
+
+     # Process qrels_list
+     qrels_df = pd.DataFrame(qrels_list, columns=["corpus-id"])
+     qrels_df["query-id"] = qrels_df.index
+     qrels_df["score"] = 1
+     qrels_df = qrels_df[["query-id", "corpus-id", "score"]]
+     save_to_tsv(qrels_df, os.path.join(output_path, "qrels", "test.tsv"))
+
+     logger.info("Write files to {}".format(output_path))
+     return
+
+
+ def dataloader_with_indices(dataloader):
+     start = 0
+     for x, y in dataloader:
+         end = start + len(x)
+         inds = list(range(start, end))
+         yield x, y, inds
+         start = end
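The index bookkeeping in `dataloader_with_indices` is what lets `evaluate` map each query back to the caption generated for its image; a toy illustration follows (the batches are made up for demonstration).

```python
# Toy batches of (images, per-image query lists); the contents are illustrative.
from evalscope.backend.rag_eval.clip_benchmark.tasks.image_caption import (
    dataloader_with_indices,
)

batches = [
    (["img0", "img1"], [["q0"], ["q1a", "q1b"]]),
    (["img2"], [["q2"]]),
]

for x, y, inds in dataloader_with_indices(batches):
    print(inds)
# prints [0, 1] for the first batch and [2] for the second: global image indices
# that end up as the "corpus-id" column written to qrels/test.tsv by write_file
```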