capevalkit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- capevalkit/__init__.py +30 -0
- capevalkit/api.py +509 -0
- capevalkit/benchmarks.py +1093 -0
- capevalkit/cli.py +379 -0
- capevalkit/compat.py +12 -0
- capevalkit/context.py +110 -0
- capevalkit/correlations.py +73 -0
- capevalkit/dispatcher.py +204 -0
- capevalkit/downloads.py +370 -0
- capevalkit/launcher.py +31 -0
- capevalkit/manifests.py +127 -0
- capevalkit/metrics/__init__.py +2 -0
- capevalkit/metrics/clipscore_metric.py +239 -0
- capevalkit/metrics/fleur_metric.py +254 -0
- capevalkit/metrics/pacscore_metric.py +227 -0
- capevalkit/metrics/polos.py +95 -0
- capevalkit/metrics/polos_validate.py +63 -0
- capevalkit/metrics/pycocoevalcap_metrics.py +147 -0
- capevalkit/metrics/vela_metric.py +151 -0
- capevalkit/overlays.py +46 -0
- capevalkit/paths.py +21 -0
- capevalkit/progress.py +24 -0
- capevalkit/reproduce.py +929 -0
- capevalkit/resources/benchmarks/expected/bleu/composite.json +8 -0
- capevalkit/resources/benchmarks/expected/bleu/flickr8k-cf.json +8 -0
- capevalkit/resources/benchmarks/expected/bleu/flickr8k-ex.json +8 -0
- capevalkit/resources/benchmarks/expected/bleu/nebula.json +8 -0
- capevalkit/resources/benchmarks/expected/bleu/polaris.json +7 -0
- capevalkit/resources/benchmarks/expected/cider/composite.json +8 -0
- capevalkit/resources/benchmarks/expected/cider/flickr8k-cf.json +8 -0
- capevalkit/resources/benchmarks/expected/cider/flickr8k-ex.json +8 -0
- capevalkit/resources/benchmarks/expected/cider/nebula.json +8 -0
- capevalkit/resources/benchmarks/expected/cider/polaris.json +7 -0
- capevalkit/resources/benchmarks/expected/clipscore/composite.json +8 -0
- capevalkit/resources/benchmarks/expected/clipscore/flickr8k-cf.json +8 -0
- capevalkit/resources/benchmarks/expected/clipscore/flickr8k-ex.json +8 -0
- capevalkit/resources/benchmarks/expected/clipscore/nebula.json +8 -0
- capevalkit/resources/benchmarks/expected/clipscore/polaris.json +7 -0
- capevalkit/resources/benchmarks/expected/fleur/composite.json +7 -0
- capevalkit/resources/benchmarks/expected/fleur/flickr8k-cf.json +7 -0
- capevalkit/resources/benchmarks/expected/fleur/flickr8k-ex.json +7 -0
- capevalkit/resources/benchmarks/expected/meteor/composite.json +8 -0
- capevalkit/resources/benchmarks/expected/meteor/flickr8k-cf.json +8 -0
- capevalkit/resources/benchmarks/expected/meteor/flickr8k-ex.json +8 -0
- capevalkit/resources/benchmarks/expected/meteor/nebula.json +8 -0
- capevalkit/resources/benchmarks/expected/meteor/polaris.json +7 -0
- capevalkit/resources/benchmarks/expected/pacscore/composite.json +8 -0
- capevalkit/resources/benchmarks/expected/pacscore/flickr8k-cf.json +8 -0
- capevalkit/resources/benchmarks/expected/pacscore/flickr8k-ex.json +8 -0
- capevalkit/resources/benchmarks/expected/pacscore/nebula.json +8 -0
- capevalkit/resources/benchmarks/expected/pacscore/polaris.json +7 -0
- capevalkit/resources/benchmarks/expected/polos/composite.json +8 -0
- capevalkit/resources/benchmarks/expected/polos/flickr8k-cf.json +8 -0
- capevalkit/resources/benchmarks/expected/polos/flickr8k-ex.json +8 -0
- capevalkit/resources/benchmarks/expected/polos/nebula.json +8 -0
- capevalkit/resources/benchmarks/expected/polos/polaris.json +7 -0
- capevalkit/resources/benchmarks/expected/refclipscore/composite.json +8 -0
- capevalkit/resources/benchmarks/expected/refclipscore/flickr8k-cf.json +8 -0
- capevalkit/resources/benchmarks/expected/refclipscore/flickr8k-ex.json +8 -0
- capevalkit/resources/benchmarks/expected/refclipscore/nebula.json +7 -0
- capevalkit/resources/benchmarks/expected/refclipscore/polaris.json +7 -0
- capevalkit/resources/benchmarks/expected/reffleur/composite.json +7 -0
- capevalkit/resources/benchmarks/expected/reffleur/flickr8k-cf.json +7 -0
- capevalkit/resources/benchmarks/expected/reffleur/flickr8k-ex.json +7 -0
- capevalkit/resources/benchmarks/expected/refpacscore/composite.json +7 -0
- capevalkit/resources/benchmarks/expected/refpacscore/flickr8k-cf.json +7 -0
- capevalkit/resources/benchmarks/expected/refpacscore/flickr8k-ex.json +7 -0
- capevalkit/resources/benchmarks/expected/refpacscore/nebula.json +7 -0
- capevalkit/resources/benchmarks/expected/refpacscore/polaris.json +7 -0
- capevalkit/resources/benchmarks/expected/rouge/composite.json +8 -0
- capevalkit/resources/benchmarks/expected/rouge/flickr8k-cf.json +8 -0
- capevalkit/resources/benchmarks/expected/rouge/flickr8k-ex.json +8 -0
- capevalkit/resources/benchmarks/expected/rouge/nebula.json +8 -0
- capevalkit/resources/benchmarks/expected/rouge/polaris.json +7 -0
- capevalkit/resources/benchmarks/expected/spice/composite.json +8 -0
- capevalkit/resources/benchmarks/expected/spice/flickr8k-cf.json +8 -0
- capevalkit/resources/benchmarks/expected/spice/flickr8k-ex.json +8 -0
- capevalkit/resources/benchmarks/expected/spice/nebula.json +8 -0
- capevalkit/resources/benchmarks/expected/spice/polaris.json +7 -0
- capevalkit/resources/benchmarks/expected/vela/longcaparena-testa-desc.json +8 -0
- capevalkit/resources/benchmarks/expected/vela/longcaparena-testa-flu.json +8 -0
- capevalkit/resources/benchmarks/expected/vela/longcaparena-testa-rel.json +8 -0
- capevalkit/resources/benchmarks/expected/vela/longcaparena-testb-desc.json +8 -0
- capevalkit/resources/benchmarks/expected/vela/longcaparena-testb-flu.json +8 -0
- capevalkit/resources/benchmarks/expected/vela/longcaparena-testb-rel.json +8 -0
- capevalkit/resources/metrics/bleu/metric.toml +19 -0
- capevalkit/resources/metrics/cider/metric.toml +19 -0
- capevalkit/resources/metrics/clipscore/metric.toml +19 -0
- capevalkit/resources/metrics/clipscore-vitl/metric.toml +18 -0
- capevalkit/resources/metrics/clipscoreavg/metric.toml +18 -0
- capevalkit/resources/metrics/fleur/metric.toml +18 -0
- capevalkit/resources/metrics/meteor/metric.toml +19 -0
- capevalkit/resources/metrics/pacscore/metric.toml +18 -0
- capevalkit/resources/metrics/pacscore-vitl/metric.toml +18 -0
- capevalkit/resources/metrics/pacscoreavg/metric.toml +18 -0
- capevalkit/resources/metrics/pacscorepp/metric.toml +18 -0
- capevalkit/resources/metrics/pacscoreppavg/metric.toml +18 -0
- capevalkit/resources/metrics/polos/metric.toml +28 -0
- capevalkit/resources/metrics/refclipscore/metric.toml +18 -0
- capevalkit/resources/metrics/refclipscore-vitl/metric.toml +18 -0
- capevalkit/resources/metrics/reffleur/metric.toml +18 -0
- capevalkit/resources/metrics/refpacscore/metric.toml +18 -0
- capevalkit/resources/metrics/refpacscore-vitl/metric.toml +18 -0
- capevalkit/resources/metrics/refpacscorepp/metric.toml +18 -0
- capevalkit/resources/metrics/rouge/metric.toml +19 -0
- capevalkit/resources/metrics/spice/metric.toml +19 -0
- capevalkit/resources/metrics/vela/metric.toml +18 -0
- capevalkit/resources/overlays/metrics/upstreams/clipscore/pyproject.toml +18 -0
- capevalkit/resources/overlays/metrics/upstreams/clipscore/uv.toml +2 -0
- capevalkit/resources/overlays/metrics/upstreams/fleur/fleur_wrapper/__init__.py +1 -0
- capevalkit/resources/overlays/metrics/upstreams/fleur/pyproject.toml +22 -0
- capevalkit/resources/overlays/metrics/upstreams/fleur/uv.toml +1 -0
- capevalkit/resources/overlays/metrics/upstreams/pacscore/pyproject.toml +22 -0
- capevalkit/resources/overlays/metrics/upstreams/pacscore/uv.toml +2 -0
- capevalkit/resources/overlays/metrics/upstreams/polos/polos/models/encoders/__init__.py +14 -0
- capevalkit/resources/overlays/metrics/upstreams/polos/polos/models/encoders/bert.py +106 -0
- capevalkit/resources/overlays/metrics/upstreams/polos/polos/models/estimators/polos_estimator.py +239 -0
- capevalkit/resources/overlays/metrics/upstreams/polos/polos/models/model_base.py +276 -0
- capevalkit/resources/overlays/metrics/upstreams/polos/polos/tokenizers_/__init__.py +13 -0
- capevalkit/resources/overlays/metrics/upstreams/polos/pyproject.toml +31 -0
- capevalkit/resources/overlays/metrics/upstreams/polos/uv.toml +2 -0
- capevalkit/resources/overlays/metrics/upstreams/pycocoevalcap/pyproject.toml +12 -0
- capevalkit/resources/overlays/metrics/upstreams/pycocoevalcap/uv.toml +2 -0
- capevalkit/resources/overlays/metrics/upstreams/vela/configs/config_regressor.yaml +50 -0
- capevalkit/resources/overlays/metrics/upstreams/vela/pyproject.toml +30 -0
- capevalkit/resources/overlays/metrics/upstreams/vela/uv.toml +1 -0
- capevalkit/resources/upstreams.lock.json +79 -0
- capevalkit/runtime.py +219 -0
- capevalkit/runtime_env.py +20 -0
- capevalkit/verify.py +118 -0
- capevalkit-0.1.0.dist-info/METADATA +482 -0
- capevalkit-0.1.0.dist-info/RECORD +135 -0
- capevalkit-0.1.0.dist-info/WHEEL +4 -0
- capevalkit-0.1.0.dist-info/entry_points.txt +2 -0
- capevalkit-0.1.0.dist-info/licenses/LICENSE +32 -0
capevalkit/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""Image captioning metric evaluation kit."""
|
|
2
|
+
|
|
3
|
+
from .api import (
|
|
4
|
+
CaptionBatch,
|
|
5
|
+
CaptionEvalRun,
|
|
6
|
+
CaptionSample,
|
|
7
|
+
MetricOutput,
|
|
8
|
+
benchmark,
|
|
9
|
+
evaluate_caption_model,
|
|
10
|
+
evaluate_captions,
|
|
11
|
+
evaluate_metric,
|
|
12
|
+
load_samples,
|
|
13
|
+
score,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
# Release identifier of the capevalkit distribution.
__version__ = "0.1.0"

# Public surface of the package: the names imported from ``.api`` above,
# plus the version string itself.
__all__ = [
    # data containers
    "CaptionBatch",
    "CaptionEvalRun",
    "CaptionSample",
    "MetricOutput",
    # entry points
    "benchmark",
    "evaluate_caption_model",
    "evaluate_captions",
    "evaluate_metric",
    "load_samples",
    "score",
    # metadata
    "__version__",
]
|
capevalkit/api.py
ADDED
|
@@ -0,0 +1,509 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable, Iterable, Mapping, Sequence
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
import json
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Union
|
|
8
|
+
|
|
9
|
+
from .benchmarks import (
|
|
10
|
+
BenchmarkItem,
|
|
11
|
+
benchmark_metric,
|
|
12
|
+
benchmark_result,
|
|
13
|
+
load_benchmark,
|
|
14
|
+
)
|
|
15
|
+
from .compat import zip_strict
|
|
16
|
+
from .dispatcher import dispatch
|
|
17
|
+
from .manifests import get_manifest, load_manifests
|
|
18
|
+
from .paths import repo_root
|
|
19
|
+
from .reproduce import NO_REFERENCE_METRICS
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass(frozen=True)
class CaptionSample:
    """Immutable record describing one image and its caption data."""

    # Identifier used to match captions, references and scores to this sample.
    id: str
    # Image location, stored as a string.
    image: str
    # Reference captions for the image; the list may be empty.
    references: list[str]
    # Candidate caption to score, when one is known.
    prediction: str | None = None
    # Score attached to the sample (filled from ``BenchmarkItem.score`` in this
    # module) -- presumably a human rating; confirm against the benchmark data.
    human_score: float | None = None
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass(frozen=True)
class CaptionBatch:
    """Immutable group of samples handed to a captioning callable.

    ``ids``, ``images`` and ``references`` are parallel lists holding the
    corresponding field of every entry in ``samples``.
    """

    # Sample identifiers, in batch order.
    ids: list[str]
    # Image locations, in batch order.
    images: list[str]
    # Reference captions per sample, in batch order.
    references: list[list[str]]
    # The full sample records the lists above were taken from.
    samples: list[CaptionSample]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass(frozen=True)
class MetricOutput:
    """Immutable result a metric callable may return.

    When ``score`` is left as ``None`` the normalisation code in this module
    substitutes the mean of ``per_item``.
    """

    # Key under which the metric is reported.
    name: str
    # Score for each sample, keyed by sample id.
    per_item: Mapping[str, float]
    # Optional aggregate score.
    score: float | None = None
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# Shape of a user-supplied metric: it receives the list of samples and returns
# an id -> score mapping, a ``MetricOutput``, or a dict in metric-output form.
MetricCallable = Callable[
    [list[CaptionSample]],
    Union[Mapping[str, float], MetricOutput, dict[str, Any]],
]
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def score(
    metric: str,
    predictions: str,
    output: str,
    *,
    references: str | None = None,
    image_dir: str | None = None,
    extra_args: list[str] | None = None,
) -> int:
    """Build the runner command for ``metric`` and dispatch it.

    Args:
        metric: Name used to look up the metric manifest.
        predictions: Path to the predictions file; resolved to an absolute path.
        output: Path the runner is told to write to; resolved to an absolute path.
        references: Optional references path, passed as ``--references``.
        image_dir: Optional image directory, passed as ``--image-dir``.
        extra_args: Additional arguments appended verbatim to the command.

    Returns:
        Whatever ``dispatch`` returns for the assembled command.
    """
    runner = get_manifest(metric).runner
    command = [
        *runner,
        "--predictions", str(Path(predictions).resolve()),
        "--output", str(Path(output).resolve()),
    ]
    # Optional path flags are only emitted when a non-empty value was given.
    optional_paths = (("--references", references), ("--image-dir", image_dir))
    for flag, value in optional_paths:
        if value:
            command.extend([flag, str(Path(value).resolve())])
    if extra_args:
        command.extend(extra_args)
    return dispatch(metric, command)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def benchmark(
    metric: str,
    benchmark_name: str,
    output: str,
    *,
    data_root: str | None = None,
    extra_args: list[str] | None = None,
    use_references: bool = True,
    score_key: str | None = None,
    limit: int | None = None,
) -> int:
    """Run ``metric`` on a named benchmark via ``benchmark_metric``.

    Every argument is forwarded unchanged, except that ``extra_args`` is
    passed on under the name ``metric_args``.

    Returns:
        The value returned by ``benchmark_metric``.
    """
    forwarded = {
        "data_root": data_root,
        "metric_args": extra_args,
        "use_references": use_references,
        "score_key": score_key,
        "limit": limit,
    }
    return benchmark_metric(metric, benchmark_name, output, **forwarded)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def load_samples(
    benchmark_name: str,
    *,
    data_root: str | None = None,
    predictions: str | Path | Mapping[str, str] | None = None,
    limit: int | None = None,
) -> list[CaptionSample]:
    """Load a benchmark and return its items as ``CaptionSample`` objects.

    Args:
        benchmark_name: Benchmark to load.
        data_root: Optional data location handed to ``load_benchmark``.
        predictions: Optional caption overrides, either an id -> caption
            mapping or a path to a JSONL file. When omitted, each sample keeps
            the caption stored on the benchmark item.
        limit: Optional item limit handed to ``load_benchmark``.
    """
    loaded = load_benchmark(benchmark_name, data_root, limit=limit)
    overrides = _load_prediction_map(predictions)
    samples: list[CaptionSample] = []
    for entry in loaded:
        samples.append(_sample_from_item(entry, overrides))
    return samples
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def evaluate_metric(
    *,
    benchmark: str,
    metric: MetricCallable,
    metric_name: str = "metric",
    data_root: str | None = None,
    predictions: str | Path | Mapping[str, str] | None = None,
    output: str | Path | None = None,
    score_key: str | None = None,
    limit: int | None = None,
) -> dict[str, Any]:
    """Score a benchmark with an in-process metric callable.

    Args:
        benchmark: Benchmark to load.
        metric: Callable that receives the samples and returns scores.
        metric_name: Name used for the metric when the callable returns a
            plain id -> score mapping, and passed to ``benchmark_result``.
        data_root: Optional data location handed to ``load_benchmark``.
        predictions: Optional caption overrides (mapping or JSONL path).
        output: Optional file to which the result is written as JSON.
        score_key: Forwarded to ``benchmark_result``.
        limit: Optional item limit handed to ``load_benchmark``.

    Returns:
        The dictionary produced by ``benchmark_result``.

    Raises:
        ValueError: If the benchmark yields no items.
    """
    benchmark_items = load_benchmark(benchmark, data_root, limit=limit)
    if not benchmark_items:
        raise ValueError(f"{benchmark} has no benchmark items")

    overrides = _load_prediction_map(predictions)
    samples = [_sample_from_item(entry, overrides) for entry in benchmark_items]
    normalized = _normalize_metric_output(metric(samples), metric_name)

    # Rebuild the items so each carries the caption that was actually scored.
    scored_items = []
    for entry, sample in zip_strict(benchmark_items, samples):
        scored_items.append(
            BenchmarkItem(
                id=entry.id,
                image=entry.image,
                caption=sample.prediction or "",
                references=entry.references,
                score=entry.score,
            )
        )

    report = benchmark_result(
        metric_name,
        benchmark,
        items=scored_items,
        metric_output=normalized,
        score_key=score_key,
    )
    if output is None:
        return report

    destination = Path(output)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_text(json.dumps(report, indent=2, sort_keys=True))
    return report
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class CaptionEvalRun:
    """Collect captions for a fixed set of images and score them with metrics.

    Typical flow: iterate ``iter_batches()``, pass the generated captions to
    ``record()``, then call ``evaluate()``. The object can be used as a
    context manager; entering returns the run itself and exiting performs no
    cleanup.
    """

    def __init__(
        self,
        *,
        images: Sequence[str | Path],
        metrics: Sequence[str],
        ids: Sequence[str] | None = None,
        references: Sequence[str | Sequence[str]] | Mapping[str, str | Sequence[str]] | None = None,
        output_dir: str | Path | None = None,
        limit: int | None = None,
    ) -> None:
        """Prepare the samples and the output location.

        Args:
            images: Image paths to caption.
            metrics: Names of the metrics to run in ``evaluate()``.
            ids: Optional sample ids; positional indices are used when omitted.
            references: Optional reference captions, positional or keyed by id.
            output_dir: Directory for generated files; defaults to
                ``outputs/caption-model`` under ``repo_root()``.
            limit: Optional cap on the number of samples.

        Raises:
            ValueError: If no samples remain after applying ``limit``.
        """
        self.metrics = list(metrics)
        self.output_dir = (
            Path(output_dir).resolve()
            if output_dir is not None
            else repo_root() / "outputs" / "caption-model"
        )
        self.samples = _samples_from_images(images, ids=ids, references=references, limit=limit)
        if not self.samples:
            raise ValueError("no images were provided")
        self._captions: dict[str, str] = {}

    def __enter__(self) -> CaptionEvalRun:
        return self

    def __exit__(self, exc_type: object, exc: object, traceback: object) -> None:
        return None

    def iter_batches(self, batch_size: int = 1) -> Iterable[CaptionBatch]:
        """Return an iterator over the samples in chunks of ``batch_size``.

        Raises:
            ValueError: If ``batch_size`` is not positive. The check runs when
                this method is called, not when iteration starts.
        """
        # Validate here rather than inside the generator so a bad argument
        # fails at the call site instead of on the first ``next()``.
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        return self._generate_batches(batch_size)

    def _generate_batches(self, batch_size: int) -> Iterable[CaptionBatch]:
        """Yield consecutive ``CaptionBatch`` slices of ``self.samples``."""
        for start in range(0, len(self.samples), batch_size):
            samples = self.samples[start:start + batch_size]
            yield CaptionBatch(
                ids=[sample.id for sample in samples],
                images=[sample.image for sample in samples],
                references=[sample.references for sample in samples],
                samples=samples,
            )

    def record(
        self,
        ids: Sequence[str] | Mapping[str, str],
        captions: Sequence[str] | None = None,
    ) -> None:
        """Store captions for known sample ids.

        Accepts either an id -> caption mapping (``captions`` omitted) or a
        sequence of ids with a matching sequence of captions. Ids and captions
        are converted with ``str()``. Nothing is stored unless the whole call
        validates.

        Raises:
            ValueError: If the argument combination is inconsistent.
            KeyError: If an id does not belong to this run.
        """
        if isinstance(ids, Mapping):
            if captions is not None:
                raise ValueError("captions must be omitted when ids is a mapping")
            items = ids.items()
        else:
            if captions is None:
                raise ValueError("captions are required when ids is a sequence")
            items = zip_strict(ids, captions)
        known_ids = {sample.id for sample in self.samples}
        # Stage the batch first so an unknown id or a length mismatch cannot
        # leave a half-recorded batch behind.
        pending: dict[str, str] = {}
        for item_id, caption in items:
            item_id = str(item_id)
            if item_id not in known_ids:
                raise KeyError(f"unknown sample id: {item_id}")
            pending[item_id] = str(caption)
        self._captions.update(pending)

    def evaluate(
        self,
        *,
        extra_args_by_metric: Mapping[str, Sequence[str]] | None = None,
        quiet: bool = False,
    ) -> dict[str, Any]:
        """Write the recorded captions to disk and run every configured metric.

        ``predictions.jsonl`` and ``references.jsonl`` are written to
        ``output_dir``; each metric is dispatched with its output directed to
        ``<metric>.json`` there, and that file is parsed as the result.

        Args:
            extra_args_by_metric: Optional extra command arguments per metric.
            quiet: Forwarded to ``dispatch``.

        Returns:
            Mapping of metric name to the parsed JSON the metric produced.

        Raises:
            ValueError: If any sample has no recorded caption.
            RuntimeError: If a metric command exits with a non-zero code.
        """
        missing = [sample.id for sample in self.samples if sample.id not in self._captions]
        if missing:
            raise ValueError(f"missing captions for sample ids: {missing[:5]}")

        self.output_dir.mkdir(parents=True, exist_ok=True)
        predictions_path = self.output_dir / "predictions.jsonl"
        references_path = self.output_dir / "references.jsonl"
        _write_jsonl(
            predictions_path,
            [
                {
                    "id": sample.id,
                    "caption": self._captions[sample.id],
                    "image": sample.image,
                }
                for sample in self.samples
            ],
        )
        _write_jsonl(
            references_path,
            [{"id": sample.id, "references": sample.references} for sample in self.samples],
        )

        extra_args = extra_args_by_metric or {}
        results: dict[str, Any] = {}
        for metric in self.metrics:
            output_path = self.output_dir / f"{metric}.json"
            command = _metric_score_command(
                metric,
                predictions=predictions_path,
                references=references_path,
                output=output_path,
                extra_args=list(extra_args.get(metric, ())),
            )
            code = dispatch(metric, command, quiet=quiet)
            if code != 0:
                raise RuntimeError(f"{metric} exited with code {code}")
            # JSON is decoded as UTF-8 explicitly instead of relying on the
            # platform's locale encoding.
            results[metric] = json.loads(output_path.read_text(encoding="utf-8"))
        return results
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def evaluate_caption_model(
    *,
    images: Sequence[str | Path],
    metrics: Sequence[str],
    predict: Callable[[CaptionBatch], Sequence[str] | Mapping[str, str]],
    ids: Sequence[str] | None = None,
    references: Sequence[str | Sequence[str]] | Mapping[str, str | Sequence[str]] | None = None,
    output_dir: str | Path | None = None,
    batch_size: int = 1,
    limit: int | None = None,
    extra_args_by_metric: Mapping[str, Sequence[str]] | None = None,
    quiet: bool = False,
) -> dict[str, Any]:
    """Caption every image with ``predict`` and score the captions.

    ``predict`` is called once per batch. It may return an id -> caption
    mapping, or a sequence of captions aligned with ``batch.ids``.

    Returns:
        The per-metric results from ``CaptionEvalRun.evaluate``.
    """
    run = CaptionEvalRun(
        images=images,
        metrics=metrics,
        ids=ids,
        references=references,
        output_dir=output_dir,
        limit=limit,
    )
    with run:
        for batch in run.iter_batches(batch_size=batch_size):
            predicted = predict(batch)
            if not isinstance(predicted, Mapping):
                # Positional output: pair it with the ids of this batch.
                run.record(batch.ids, predicted)
                continue
            run.record(predicted)
        return run.evaluate(extra_args_by_metric=extra_args_by_metric, quiet=quiet)
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def evaluate_captions(
    *,
    metrics: Sequence[str],
    pairs: Sequence[Mapping[str, Any]] | None = None,
    images: Sequence[str | Path] | None = None,
    captions: Sequence[str] | Mapping[str, str] | None = None,
    ids: Sequence[str] | None = None,
    references: Sequence[str | Sequence[str]] | Mapping[str, str | Sequence[str]] | None = None,
    output_dir: str | Path | None = None,
    limit: int | None = None,
    extra_args_by_metric: Mapping[str, Sequence[str]] | None = None,
    quiet: bool = False,
) -> dict[str, Any]:
    """Score captions that already exist.

    Input is given either as ``pairs`` (one mapping per sample) or as
    ``images`` plus ``captions`` with optional ``ids`` and ``references``;
    ``_caption_inputs`` validates and normalises whichever form is used.

    Returns:
        The per-metric results from ``CaptionEvalRun.evaluate``.
    """
    image_values, id_values, caption_by_id, reference_values = _caption_inputs(
        pairs=pairs,
        images=images,
        captions=captions,
        ids=ids,
        references=references,
    )
    run = CaptionEvalRun(
        images=image_values,
        metrics=metrics,
        ids=id_values,
        references=reference_values,
        output_dir=output_dir,
        limit=limit,
    )
    with run:
        # Only the samples kept by the run (after ``limit``) are recorded.
        selected: dict[str, str] = {}
        for sample in run.samples:
            selected[sample.id] = caption_by_id[sample.id]
        run.record(selected)
        return run.evaluate(extra_args_by_metric=extra_args_by_metric, quiet=quiet)
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def _sample_from_item(item: BenchmarkItem, predictions: Mapping[str, str] | None) -> CaptionSample:
    """Convert a benchmark item into a ``CaptionSample``.

    When ``predictions`` is given, the caption is looked up by ``item.id``
    (a missing id raises ``KeyError``); otherwise the item's own caption is
    used as the prediction.
    """
    if predictions is None:
        caption = item.caption
    else:
        caption = predictions[item.id]
    return CaptionSample(
        id=item.id,
        image=item.image,
        references=item.references,
        prediction=caption,
        human_score=item.score,
    )
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def _samples_from_images(
    images: Sequence[str | Path],
    *,
    ids: Sequence[str] | None,
    references: Sequence[str | Sequence[str]] | Mapping[str, str | Sequence[str]] | None,
    limit: int | None,
) -> list[CaptionSample]:
    """Build prediction-less samples from image paths.

    Image paths are resolved to absolute strings. ``limit`` truncates the
    images (and ``ids`` when given). Without ``ids`` the positional index is
    used as the id.

    Raises:
        ValueError: If ``ids`` and ``images`` differ in length after
            truncation.
    """
    resolved = [str(Path(image).resolve()) for image in images]
    if limit is not None:
        resolved = resolved[:limit]

    if ids is not None:
        labels = [str(item_id) for item_id in ids]
        if limit is not None:
            labels = labels[:limit]
    else:
        labels = [str(position) for position in range(len(resolved))]

    if len(labels) != len(resolved):
        raise ValueError("ids and images must have the same length")

    refs_by_id = _references_by_id(labels, references, limit=limit)
    samples: list[CaptionSample] = []
    for label, image_path in zip_strict(labels, resolved):
        samples.append(CaptionSample(id=label, image=image_path, references=refs_by_id[label]))
    return samples
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
def _caption_inputs(
|
|
354
|
+
*,
|
|
355
|
+
pairs: Sequence[Mapping[str, Any]] | None,
|
|
356
|
+
images: Sequence[str | Path] | None,
|
|
357
|
+
captions: Sequence[str] | Mapping[str, str] | None,
|
|
358
|
+
ids: Sequence[str] | None,
|
|
359
|
+
references: Sequence[str | Sequence[str]] | Mapping[str, str | Sequence[str]] | None,
|
|
360
|
+
) -> tuple[list[str], list[str], dict[str, str], Sequence[str | Sequence[str]] | Mapping[str, str | Sequence[str]] | None]:
|
|
361
|
+
if pairs is not None:
|
|
362
|
+
if images is not None or captions is not None or ids is not None or references is not None:
|
|
363
|
+
raise ValueError("pairs cannot be combined with images, captions, ids, or references")
|
|
364
|
+
image_values: list[str] = []
|
|
365
|
+
id_values: list[str] = []
|
|
366
|
+
caption_values: dict[str, str] = {}
|
|
367
|
+
reference_values: dict[str, str | Sequence[str]] = {}
|
|
368
|
+
for index, pair in enumerate(pairs):
|
|
369
|
+
item_id = str(pair.get("id", index))
|
|
370
|
+
caption = pair.get("caption", pair.get("prediction"))
|
|
371
|
+
if not isinstance(caption, str):
|
|
372
|
+
raise ValueError("each pair must contain caption or prediction")
|
|
373
|
+
image = pair.get("image", pair.get("image_path"))
|
|
374
|
+
if image is None:
|
|
375
|
+
raise ValueError("each pair must contain image or image_path")
|
|
376
|
+
image_values.append(str(image))
|
|
377
|
+
id_values.append(item_id)
|
|
378
|
+
caption_values[item_id] = caption
|
|
379
|
+
if "references" in pair:
|
|
380
|
+
reference_values[item_id] = pair["references"]
|
|
381
|
+
elif "captions" in pair:
|
|
382
|
+
reference_values[item_id] = pair["captions"]
|
|
383
|
+
return image_values, id_values, caption_values, reference_values
|
|
384
|
+
|
|
385
|
+
if images is None or captions is None:
|
|
386
|
+
raise ValueError("either pairs or both images and captions are required")
|
|
387
|
+
image_values = [str(image) for image in images]
|
|
388
|
+
id_values = [str(item_id) for item_id in ids] if ids is not None else [str(index) for index in range(len(image_values))]
|
|
389
|
+
if len(id_values) != len(image_values):
|
|
390
|
+
raise ValueError("ids and images must have the same length")
|
|
391
|
+
if isinstance(captions, Mapping):
|
|
392
|
+
caption_values = {str(item_id): str(caption) for item_id, caption in captions.items()}
|
|
393
|
+
else:
|
|
394
|
+
caption_list = [str(caption) for caption in captions]
|
|
395
|
+
if len(caption_list) != len(image_values):
|
|
396
|
+
raise ValueError("captions and images must have the same length")
|
|
397
|
+
caption_values = dict(zip_strict(id_values, caption_list))
|
|
398
|
+
return image_values, id_values, caption_values, references
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def _references_by_id(
|
|
402
|
+
ids: Sequence[str],
|
|
403
|
+
references: Sequence[str | Sequence[str]] | Mapping[str, str | Sequence[str]] | None,
|
|
404
|
+
*,
|
|
405
|
+
limit: int | None,
|
|
406
|
+
) -> dict[str, list[str]]:
|
|
407
|
+
if references is None:
|
|
408
|
+
return {item_id: [] for item_id in ids}
|
|
409
|
+
if isinstance(references, Mapping):
|
|
410
|
+
return {item_id: _reference_list(references.get(item_id, [])) for item_id in ids}
|
|
411
|
+
reference_values = list(references)
|
|
412
|
+
if limit is not None:
|
|
413
|
+
reference_values = reference_values[:limit]
|
|
414
|
+
if len(reference_values) != len(ids):
|
|
415
|
+
raise ValueError("references and images must have the same length")
|
|
416
|
+
return {
|
|
417
|
+
item_id: _reference_list(reference)
|
|
418
|
+
for item_id, reference in zip_strict(ids, reference_values)
|
|
419
|
+
}
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
def _reference_list(value: str | Sequence[str] | Any) -> list[str]:
|
|
423
|
+
if value is None:
|
|
424
|
+
return []
|
|
425
|
+
if isinstance(value, str):
|
|
426
|
+
return [value]
|
|
427
|
+
if isinstance(value, Sequence):
|
|
428
|
+
return [str(item) for item in value]
|
|
429
|
+
raise ValueError("references must be strings or sequences of strings")
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
def _load_prediction_map(predictions: str | Path | Mapping[str, str] | None) -> dict[str, str] | None:
|
|
433
|
+
if predictions is None:
|
|
434
|
+
return None
|
|
435
|
+
if isinstance(predictions, Mapping):
|
|
436
|
+
return {str(item_id): str(caption) for item_id, caption in predictions.items()}
|
|
437
|
+
rows = {}
|
|
438
|
+
for line in Path(predictions).read_text().splitlines():
|
|
439
|
+
if not line.strip():
|
|
440
|
+
continue
|
|
441
|
+
row = json.loads(line)
|
|
442
|
+
caption = row.get("caption", row.get("prediction"))
|
|
443
|
+
if not isinstance(caption, str):
|
|
444
|
+
raise ValueError("prediction rows must contain caption or prediction")
|
|
445
|
+
rows[str(row["id"])] = caption
|
|
446
|
+
return rows
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
def _normalize_metric_output(result: Mapping[str, float] | MetricOutput | dict[str, Any], metric_name: str) -> dict[str, Any]:
    """Bring a metric callable's return value into metric-output form.

    * ``MetricOutput`` -> ``{result.name: {"score", "per_item"}}``; the mean
      of the per-item scores is used when ``result.score`` is ``None``.
    * A mapping that already looks like metric output -> shallow dict copy.
    * Any other mapping is read as id -> score and reported under
      ``metric_name`` with the mean as the aggregate score.

    Raises:
        TypeError: If ``result`` is neither a mapping nor a ``MetricOutput``.
    """
    if isinstance(result, MetricOutput):
        scores = {str(item_id): float(value) for item_id, value in result.per_item.items()}
        if result.score is None:
            overall = _mean(scores.values())
        else:
            overall = float(result.score)
        return {result.name: {"score": overall, "per_item": scores}}

    if not isinstance(result, Mapping):
        raise TypeError("metric callable must return a mapping or MetricOutput")

    if _looks_like_metric_output(result):
        return dict(result)

    scores = {str(item_id): float(value) for item_id, value in result.items()}
    summary = {"score": _mean(scores.values()), "per_item": scores}
    return {metric_name: summary}
|
|
464
|
+
|
|
465
|
+
|
|
466
|
+
def _looks_like_metric_output(result: Mapping[str, Any]) -> bool:
|
|
467
|
+
if isinstance(result.get("per_item"), Mapping):
|
|
468
|
+
return True
|
|
469
|
+
return any(isinstance(value, Mapping) and isinstance(value.get("per_item"), Mapping) for value in result.values())
|
|
470
|
+
|
|
471
|
+
|
|
472
|
+
def _mean(values: Iterable[float]) -> float:
|
|
473
|
+
numbers = list(values)
|
|
474
|
+
return sum(numbers) / len(numbers) if numbers else 0.0
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
def _metric_score_command(
    metric: str,
    *,
    predictions: Path,
    references: Path,
    output: Path,
    extra_args: list[str],
) -> list[str]:
    """Assemble the runner command line for one metric.

    The command is the manifest's runner followed by ``--predictions``,
    then ``--references`` unless the metric is listed in
    ``NO_REFERENCE_METRICS``, then ``--output`` and finally ``extra_args``.

    Returns:
        The argument list to hand to ``dispatch``.
    """
    manifest = get_manifest(metric)
    command = [*manifest.runner, "--predictions", str(predictions)]
    # Append in final order rather than splicing before a searched-for
    # "--output" token, which would land in the wrong place if the runner
    # argv itself happened to contain that token.
    if metric not in NO_REFERENCE_METRICS:
        command += ["--references", str(references)]
    command += ["--output", str(output), *extra_args]
    return command
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
def _write_jsonl(path: Path, rows: Sequence[Mapping[str, Any]]) -> None:
|
|
493
|
+
path.write_text("".join(json.dumps(row, ensure_ascii=False) + "\n" for row in rows))
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
# Names exported by ``from capevalkit.api import *``. ``get_manifest`` and
# ``load_manifests`` are re-exported from ``.manifests``.
__all__ = [
    # data containers
    "CaptionBatch",
    "CaptionEvalRun",
    "CaptionSample",
    "MetricOutput",
    # evaluation entry points
    "benchmark",
    "evaluate_caption_model",
    "evaluate_captions",
    "evaluate_metric",
    # manifest access
    "get_manifest",
    "load_manifests",
    # sample loading and scoring
    "load_samples",
    "score",
]
|