aviary.labbench 0.30.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,155 @@
1
+ Metadata-Version: 2.4
2
+ Name: aviary.labbench
3
+ Version: 0.30.0
4
+ Summary: LAB-Bench environments implemented with aviary
5
+ Author-email: FutureHouse technical staff <hello@futurehouse.org>
6
+ Classifier: Intended Audience :: Developers
7
+ Classifier: License :: OSI Approved :: Apache Software License
8
+ Classifier: Operating System :: OS Independent
9
+ Classifier: Programming Language :: Python :: 3 :: Only
10
+ Classifier: Programming Language :: Python :: 3.11
11
+ Classifier: Programming Language :: Python :: 3.12
12
+ Classifier: Programming Language :: Python :: 3.13
13
+ Classifier: Programming Language :: Python :: 3.14
14
+ Classifier: Programming Language :: Python
15
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
16
+ Requires-Python: >=3.11
17
+ Description-Content-Type: text/markdown
18
+ Requires-Dist: fhaviary>=0.14
19
+ Requires-Dist: fhlmi
20
+ Requires-Dist: ldp>=0.25.2
21
+ Requires-Dist: paper-qa[pymupdf]>=2025
22
+ Requires-Dist: pydantic~=2.0
23
+ Requires-Dist: tenacity
24
+ Requires-Dist: typing-extensions; python_version <= "3.12"
25
+ Provides-Extra: datasets
26
+ Requires-Dist: datasets>=2.15; extra == "datasets"
27
+ Provides-Extra: dev
28
+ Requires-Dist: aviary.labbench[datasets,typing]; extra == "dev"
29
+ Requires-Dist: pandas; extra == "dev"
30
+ Requires-Dist: paper-qa>=5.29.1; extra == "dev"
31
+ Requires-Dist: tantivy>=0.25.0; python_version >= "3.14" and extra == "dev"
32
+ Provides-Extra: typing
33
+ Requires-Dist: pillow; extra == "typing"
34
+
35
+ # aviary.labbench
36
+
37
+ LAB-Bench environments implemented with aviary,
38
+ allowing agents to perform question answering on scientific tasks.
39
+
40
+ ## Installation
41
+
42
+ To install the LAB-Bench environment, run:
43
+
44
+ ```bash
45
+ pip install 'fhaviary[labbench]'
46
+ ```
47
+
48
+ ## Usage
49
+
50
+ In [`labbench/env.py`](src/aviary/envs/labbench/env.py), you will find:
51
+
52
+ - `GradablePaperQAEnvironment`: a PaperQA-backed environment
53
+ that can grade answers given an evaluation function.
54
+ - `ImageQAEnvironment`: a `GradablePaperQAEnvironment`
55
+ subclass for QA where image(s) are pre-added.
56
+
57
+ And in [`labbench/task.py`](src/aviary/envs/labbench/task.py), you will find:
58
+
59
+ - `TextQATaskDataset`: a task dataset designed to
60
+ pull down FigQA, LitQA2, or TableQA from Hugging Face,
61
+ and create one `GradablePaperQAEnvironment` per question.
62
+ - `ImageQATaskDataset`: a task dataset that pairs with `ImageQAEnvironment`
63
+ for FigQA or TableQA.
64
+
65
+ Here is an example of how to use them:
66
+
67
+ ```python
68
+ import os
69
+
70
+ from ldp.agent import SimpleAgent
71
+ from ldp.alg import Evaluator, EvaluatorConfig, MeanMetricsCallback
72
+ from paperqa import Settings
73
+
74
+ from aviary.env import TaskDataset
75
+
76
+
77
+ async def evaluate(folder_of_litqa_v2_papers: str | os.PathLike) -> None:
78
+ settings = Settings(paper_directory=folder_of_litqa_v2_papers)
79
+ dataset = TaskDataset.from_name("litqa2", settings=settings)
80
+ metrics_callback = MeanMetricsCallback(eval_dataset=dataset)
81
+
82
+ evaluator = Evaluator(
83
+ config=EvaluatorConfig(batch_size=3),
84
+ agent=SimpleAgent(),
85
+ dataset=dataset,
86
+ callbacks=[metrics_callback],
87
+ )
88
+ await evaluator.evaluate()
89
+ print(metrics_callback.eval_means)
90
+ ```
91
+
92
+ ### Image Question-Answer
93
+
94
+ This is an environment/dataset for giving PaperQA a `Docs` object with
95
+ the image(s) for one LAB-Bench question.
96
+ It's designed to be a comparison with zero-shotting the question to an LLM,
97
+ but instead of a singular prompt the image is put through the PaperQA agent loop.
98
+
99
+ ```python
100
+ from typing import cast
101
+
102
+ import litellm
103
+ import pytest
104
+ from ldp.agent import Agent
105
+ from ldp.alg import (
106
+ Evaluator,
107
+ EvaluatorConfig,
108
+ MeanMetricsCallback,
109
+ StoreTrajectoriesCallback,
110
+ )
111
+ from paperqa.settings import AgentSettings, IndexSettings
112
+
113
+ from aviary.envs.labbench import (
114
+ ImageQAEnvironment,
115
+ ImageQATaskDataset,
116
+ LABBenchDatasets,
117
+ )
118
+
119
+
120
+ @pytest.mark.asyncio
121
+ async def test_image_qa(tmp_path) -> None:
122
+ litellm.num_retries = 8 # Mitigate connection-related failures
123
+ settings = ImageQAEnvironment.make_base_settings()
124
+ settings.agent = AgentSettings(
125
+ agent_type="ldp.agent.SimpleAgent",
126
+ index=IndexSettings(paper_directory=tmp_path),
127
+ # TODO: add image support for paper_search
128
+ tool_names={"gather_evidence", "gen_answer", "complete", "reset"},
129
+ agent_evidence_n=3, # Bumped up to collect several perspectives
130
+ )
131
+ dataset = ImageQATaskDataset(dataset=LABBenchDatasets.TABLE_QA, settings=settings)
132
+ t_cb = StoreTrajectoriesCallback()
133
+ m_cb = MeanMetricsCallback(eval_dataset=dataset, track_tool_usage=True)
134
+ evaluator = Evaluator(
135
+ config=EvaluatorConfig(
136
+ batch_size=256, # Use batch size greater than FigQA size and TableQA size
137
+ max_rollout_steps=18, # Match aviary paper's PaperQA setting
138
+ ),
139
+ agent=cast(Agent, await settings.make_ldp_agent(settings.agent.agent_type)),
140
+ dataset=dataset,
141
+ callbacks=[t_cb, m_cb],
142
+ )
143
+ await evaluator.evaluate()
144
+ print(m_cb.eval_means)
145
+ ```
146
+
147
+ ## References
148
+
149
+ [1] Skarlinski et al.
150
+ [Language agents achieve superhuman synthesis of scientific knowledge](https://arxiv.org/abs/2409.13740).
151
+ ArXiv:2409.13740, 2024.
152
+
153
+ [2] Laurent et al.
154
+ [LAB-Bench: Measuring Capabilities of Language Models for Biology Research](https://arxiv.org/abs/2407.10362).
155
+ ArXiv:2407.10362, 2024.
@@ -0,0 +1,121 @@
1
+ # aviary.labbench
2
+
3
+ LAB-Bench environments implemented with aviary,
4
+ allowing agents to perform question answering on scientific tasks.
5
+
6
+ ## Installation
7
+
8
+ To install the LAB-Bench environment, run:
9
+
10
+ ```bash
11
+ pip install 'fhaviary[labbench]'
12
+ ```
13
+
14
+ ## Usage
15
+
16
+ In [`labbench/env.py`](src/aviary/envs/labbench/env.py), you will find:
17
+
18
+ - `GradablePaperQAEnvironment`: a PaperQA-backed environment
19
+ that can grade answers given an evaluation function.
20
+ - `ImageQAEnvironment`: a `GradablePaperQAEnvironment`
21
+ subclass for QA where image(s) are pre-added.
22
+
23
+ And in [`labbench/task.py`](src/aviary/envs/labbench/task.py), you will find:
24
+
25
+ - `TextQATaskDataset`: a task dataset designed to
26
+ pull down FigQA, LitQA2, or TableQA from Hugging Face,
27
+ and create one `GradablePaperQAEnvironment` per question.
28
+ - `ImageQATaskDataset`: a task dataset that pairs with `ImageQAEnvironment`
29
+ for FigQA or TableQA.
30
+
31
+ Here is an example of how to use them:
32
+
33
+ ```python
34
+ import os
35
+
36
+ from ldp.agent import SimpleAgent
37
+ from ldp.alg import Evaluator, EvaluatorConfig, MeanMetricsCallback
38
+ from paperqa import Settings
39
+
40
+ from aviary.env import TaskDataset
41
+
42
+
43
+ async def evaluate(folder_of_litqa_v2_papers: str | os.PathLike) -> None:
44
+ settings = Settings(paper_directory=folder_of_litqa_v2_papers)
45
+ dataset = TaskDataset.from_name("litqa2", settings=settings)
46
+ metrics_callback = MeanMetricsCallback(eval_dataset=dataset)
47
+
48
+ evaluator = Evaluator(
49
+ config=EvaluatorConfig(batch_size=3),
50
+ agent=SimpleAgent(),
51
+ dataset=dataset,
52
+ callbacks=[metrics_callback],
53
+ )
54
+ await evaluator.evaluate()
55
+ print(metrics_callback.eval_means)
56
+ ```
57
+
58
+ ### Image Question-Answer
59
+
60
+ This is an environment/dataset for giving PaperQA a `Docs` object with
61
+ the image(s) for one LAB-Bench question.
62
+ It's designed to be a comparison with zero-shotting the question to an LLM,
63
+ but instead of a singular prompt the image is put through the PaperQA agent loop.
64
+
65
+ ```python
66
+ from typing import cast
67
+
68
+ import litellm
69
+ import pytest
70
+ from ldp.agent import Agent
71
+ from ldp.alg import (
72
+ Evaluator,
73
+ EvaluatorConfig,
74
+ MeanMetricsCallback,
75
+ StoreTrajectoriesCallback,
76
+ )
77
+ from paperqa.settings import AgentSettings, IndexSettings
78
+
79
+ from aviary.envs.labbench import (
80
+ ImageQAEnvironment,
81
+ ImageQATaskDataset,
82
+ LABBenchDatasets,
83
+ )
84
+
85
+
86
+ @pytest.mark.asyncio
87
+ async def test_image_qa(tmp_path) -> None:
88
+ litellm.num_retries = 8 # Mitigate connection-related failures
89
+ settings = ImageQAEnvironment.make_base_settings()
90
+ settings.agent = AgentSettings(
91
+ agent_type="ldp.agent.SimpleAgent",
92
+ index=IndexSettings(paper_directory=tmp_path),
93
+ # TODO: add image support for paper_search
94
+ tool_names={"gather_evidence", "gen_answer", "complete", "reset"},
95
+ agent_evidence_n=3, # Bumped up to collect several perspectives
96
+ )
97
+ dataset = ImageQATaskDataset(dataset=LABBenchDatasets.TABLE_QA, settings=settings)
98
+ t_cb = StoreTrajectoriesCallback()
99
+ m_cb = MeanMetricsCallback(eval_dataset=dataset, track_tool_usage=True)
100
+ evaluator = Evaluator(
101
+ config=EvaluatorConfig(
102
+ batch_size=256, # Use batch size greater than FigQA size and TableQA size
103
+ max_rollout_steps=18, # Match aviary paper's PaperQA setting
104
+ ),
105
+ agent=cast(Agent, await settings.make_ldp_agent(settings.agent.agent_type)),
106
+ dataset=dataset,
107
+ callbacks=[t_cb, m_cb],
108
+ )
109
+ await evaluator.evaluate()
110
+ print(m_cb.eval_means)
111
+ ```
112
+
113
+ ## References
114
+
115
+ [1] Skarlinski et al.
116
+ [Language agents achieve superhuman synthesis of scientific knowledge](https://arxiv.org/abs/2409.13740).
117
+ ArXiv:2409.13740, 2024.
118
+
119
+ [2] Laurent et al.
120
+ [LAB-Bench: Measuring Capabilities of Language Models for Biology Research](https://arxiv.org/abs/2407.10362).
121
+ ArXiv:2407.10362, 2024.
@@ -0,0 +1,56 @@
1
+ [build-system]
2
+ build-backend = "setuptools.build_meta"
3
+ requires = ["setuptools>=64", "setuptools_scm>=8"]
4
+
5
+ [project]
6
+ authors = [
7
+ {email = "hello@futurehouse.org", name = "FutureHouse technical staff"},
8
+ ]
9
+ classifiers = [
10
+ "Intended Audience :: Developers",
11
+ "License :: OSI Approved :: Apache Software License",
12
+ "Operating System :: OS Independent",
13
+ "Programming Language :: Python :: 3 :: Only",
14
+ "Programming Language :: Python :: 3.11",
15
+ "Programming Language :: Python :: 3.12",
16
+ "Programming Language :: Python :: 3.13",
17
+ "Programming Language :: Python :: 3.14",
18
+ "Programming Language :: Python",
19
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
20
+ ]
21
+ dependencies = [
22
+ "fhaviary>=0.14", # For MultipleChoiceQuestion
23
+ "fhlmi",
24
+ "ldp>=0.25.2", # Pin for lmi migration
25
+ "paper-qa[pymupdf]>=2025", # Pin for multimodal
26
+ "pydantic~=2.0",
27
+ "tenacity",
28
+ "typing-extensions; python_version <= '3.12'", # For TypeVar default
29
+ ]
30
+ description = "LAB-Bench environments implemented with aviary"
31
+ dynamic = ["version"]
32
+ name = "aviary.labbench"
33
+ readme = "README.md"
34
+ requires-python = ">=3.11"
35
+
36
+ [project.optional-dependencies]
37
+ datasets = [
38
+ "datasets>=2.15", # Lower pin for https://github.com/huggingface/datasets/pull/6404
39
+ ]
40
+ dev = [
41
+ "aviary.labbench[datasets,typing]",
42
+ "pandas",
43
+ "paper-qa>=5.29.1", # Pin for gen_answer's EmptyDocsError, with fix
44
+ "tantivy>=0.25.0; python_version >= '3.14'", # For Python 3.14 support
45
+ ]
46
+ typing = ["pillow"]
47
+
48
+ [tool.ruff]
49
+ extend = "../../pyproject.toml"
50
+
51
+ [tool.setuptools.packages.find]
52
+ where = ["src"]
53
+
54
+ [tool.setuptools_scm]
55
+ root = "../.."
56
+ version_file = "src/aviary/envs/labbench/version.py"
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,31 @@
1
+ from .env import (
2
+ DEFAULT_REWARD_MAPPING,
3
+ GradablePaperQAEnvironment,
4
+ ImageQAEnvironment,
5
+ make_discounted_returns,
6
+ )
7
+ from .task import (
8
+ DEFAULT_AVIARY_PAPER_HF_HUB_NAME,
9
+ DEFAULT_LABBENCH_HF_HUB_NAME,
10
+ ImageQATaskDataset,
11
+ LABBenchDatasets,
12
+ PaperQATaskDataset,
13
+ TextQATaskDataset,
14
+ TextQATaskSplit,
15
+ read_ds_from_hub,
16
+ )
17
+
18
# Public API of this package: the names imported above from the `.env` and
# `.task` submodules, listed in sorted order.
__all__ = [
    "DEFAULT_AVIARY_PAPER_HF_HUB_NAME",
    "DEFAULT_LABBENCH_HF_HUB_NAME",
    "DEFAULT_REWARD_MAPPING",
    "GradablePaperQAEnvironment",
    "ImageQAEnvironment",
    "ImageQATaskDataset",
    "LABBenchDatasets",
    "PaperQATaskDataset",
    "TextQATaskDataset",
    "TextQATaskSplit",
    "make_discounted_returns",
    "read_ds_from_hub",
]
@@ -0,0 +1,276 @@
1
+ import logging
2
+ import sys
3
+ import tempfile
4
+ from collections.abc import Awaitable, Callable, Mapping, Sequence
5
+ from copy import deepcopy
6
+ from pathlib import Path
7
+ from typing import TYPE_CHECKING, Any, Generic, Self, cast
8
+ from uuid import UUID
9
+
10
+ from aviary.core import (
11
+ Messages,
12
+ MultipleChoiceEvaluation,
13
+ MultipleChoiceQuestion,
14
+ ToolRequestMessage,
15
+ )
16
+ from aviary.env import ENV_REGISTRY
17
+ from ldp.utils import discounted_returns
18
+ from lmi import EmbeddingModel, LiteLLMModel
19
+ from paperqa.agents.env import POPULATE_FROM_SETTINGS, PaperQAEnvironment
20
+ from paperqa.agents.search import SearchIndex, maybe_get_manifest
21
+ from paperqa.docs import Docs
22
+ from paperqa.settings import AnswerSettings, ParsingSettings, Settings
23
+
24
+ if TYPE_CHECKING:
25
+ from PIL.Image import Image
26
+
27
+ if sys.version_info >= (3, 13):
28
+ from typing import TypeVar
29
+ else:
30
+ from typing_extensions import TypeVar # For TypeVar.default backport
31
+
32
# Module-level logger, named after this module's import path
logger = logging.getLogger(__name__)

# Type variable for the evaluation type produced when grading an answer,
# falling back to MultipleChoiceEvaluation when left unparameterized
TEvaluation = TypeVar("TEvaluation", default=MultipleChoiceEvaluation)

# Default mapping from an evaluation's value to the reward given for it
DEFAULT_REWARD_MAPPING = {"correct": 1.0, "unsure": 0.1, "incorrect": -1.0}
37
+
38
+
39
def make_discounted_returns(
    evaluation: MultipleChoiceEvaluation,
    num_steps: int,
    rewards: Mapping[str, float] = DEFAULT_REWARD_MAPPING,
    discount: float = 1.0,
) -> list[float]:
    """
    Compute discounted returns for a trajectory that is only rewarded at its end.

    Args:
        evaluation: Final evaluation, whose value is the key looked up in rewards.
        num_steps: Number of steps in the trajectory.
        rewards: Mapping from evaluation value to the reward of the last step.
        discount: Discount factor forwarded to discounted_returns.

    Returns:
        The list produced by discounted_returns for a reward sequence that is
        zero on every step but the last, which is also the only terminated step.
    """
    num_leading_steps = num_steps - 1
    # paper-qa has no intermediary rewards, so only the last step carries one
    step_rewards = [0] * num_leading_steps + [rewards[evaluation.value]]
    step_terminated = [False] * num_leading_steps + [True]
    return discounted_returns(
        step_rewards,
        terminated=step_terminated,
        discount=discount,
    )
51
+
52
+
53
class GradablePaperQAEnvironment(PaperQAEnvironment, Generic[TEvaluation]):
    """
    Extended environment that can grade answers.

    When a step reports done and the query is a MultipleChoiceQuestion, the
    session's answer is graded and the reward mapped from that evaluation is
    added to the reward returned by the parent environment's step.
    """

    def __init__(
        self,
        query: str | MultipleChoiceQuestion,
        settings: Settings,
        docs: Docs,
        llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
        summary_llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
        embedding_model: EmbeddingModel | None = POPULATE_FROM_SETTINGS,
        session_id: UUID | None = None,
        sources: str | list[str] | None = None,
        rewards: Mapping[str, float] = DEFAULT_REWARD_MAPPING,
        evaluation_callback: Callable[[TEvaluation], Awaitable] | None = None,
        **env_kwargs,
    ):
        """
        Initialize the environment.

        Args:
            query: Plain question text or a multiple choice question. Only a
                multiple choice question gets graded by step.
            settings: Forwarded to the parent environment.
            docs: Forwarded to the parent environment.
            llm_model: Forwarded to the parent environment.
            summary_llm_model: Forwarded to the parent environment.
            embedding_model: Forwarded to the parent environment.
            session_id: Forwarded to the parent environment.
            sources: Optional source(s) that validate_sources checks against
                file names or DOIs. A single string is wrapped into a list.
            rewards: Mapping from evaluation value to the reward added by step
                once the environment is done.
            evaluation_callback: Optional async callback awaited with the
                evaluation after grading.
            **env_kwargs: Extra keyword arguments forwarded to the parent
                environment.
        """
        super().__init__(
            query,
            settings,
            docs,
            llm_model,
            summary_llm_model,
            embedding_model,
            session_id,
            **env_kwargs,
        )
        # Enables checking an Index has the right DOI(s)
        self.sources: list[str] | None = (
            [sources] if isinstance(sources, str) else sources
        )
        self._evaluation_callback = evaluation_callback
        self._rewards = rewards

    async def validate_sources(
        self, manifest_or_index: dict[str, dict[str, Any]] | SearchIndex | None = None
    ) -> None:
        """
        Validate the sources can be found in the input manifest or index.

        Args:
            manifest_or_index: Manifest mapping (keyed by file name) or search
                index to check against. If None, the manifest is loaded from
                the settings' index manifest file.

        Raises:
            ValueError: If any source matches neither a file name nor (for a
                manifest only) a DOI, compared case-insensitively.
        """
        if not self.sources:
            return
        if manifest_or_index is None:  # Let's try to load in the manifest
            manifest_or_index = await maybe_get_manifest(
                filename=await self._settings.agent.index.finalize_manifest_file()
            )
        if isinstance(manifest_or_index, SearchIndex):
            entity: str = "index"
            file_names: set[str] = {k for k in await manifest_or_index.index_files if k}
            # An index only exposes file names here, so there are no DOIs to match
            lowercased_dois: set[str] = set()
        else:
            entity = "manifest"
            file_names = {k for k in manifest_or_index if k}
            # Use .get so a manifest row without a "doi" key is skipped
            # instead of raising a KeyError
            lowercased_dois = {
                doi.lower()
                for v in manifest_or_index.values()
                if (doi := v.get("doi"))
            }
        if not file_names:  # File names being empty means something's wrong
            logger.warning(
                f"Can't validate sources {self.sources} without a correctly specified"
                f" {entity}."
            )
            return
        not_found = [
            s
            for s in self.sources
            if s not in file_names and s.lower() not in lowercased_dois
        ]
        if not_found:
            question = (
                self._query
                if isinstance(self._query, str)
                else self._query.question_prompt
            )
            raise ValueError(
                f"Sources {not_found} of {self.sources} not found in the {entity},"
                f" the corresponding query was {question!r}."
            )

    async def _evaluate_answer(self) -> TEvaluation:
        """Grade the session's answer, storing the graded answer on the session."""
        # If the ensuing evaluation fails (e.g. due to OpenAI being down), we can:
        # - Suppress the exception and declare the evaluation as incorrect, which can
        #   negatively reward what otherwise was a good trajectory containing a correct
        #   answer. We don't want "bad" offline data, so it's not what we do.
        # - Suppress the exception and just give super()'s reward, but again this could
        #   incorrectly reward what otherwise was a good trajectory.
        # - Don't suppress the exception, which leads to the trajectory failing, and
        #   removes it from the learnable pool. This is the only safe default behavior.
        evaluation, self.state.session.graded_answer = await cast(
            "MultipleChoiceQuestion", self._query
        ).grade(self.state.session.answer)
        return evaluation  # type: ignore[return-value]

    async def step(
        self, action: ToolRequestMessage
    ) -> tuple[Messages, float, bool, bool]:
        """
        Take a step in the parent environment, grading the answer once done.

        Returns:
            Four-tuple of messages, reward, done, and truncated. The reward is
            the parent's reward, plus the reward mapped from the evaluation
            when done and the query is a multiple choice question.
        """
        messages, reward, done, truncated = await super().step(action)
        # Grading is only possible at the end, and only for multiple choice questions
        if not done or not isinstance(self._query, MultipleChoiceQuestion):
            return messages, reward, done, truncated
        evaluation = await self._evaluate_answer()
        if evaluation_callback := self._evaluation_callback:
            await evaluation_callback(evaluation)

        return (
            messages,
            reward + self._rewards[cast("MultipleChoiceEvaluation", evaluation).value],
            done,
            truncated,
        )

    async def get_id(self) -> str:
        """
        Get the question ID of the query as a string.

        Raises:
            ValueError: If the query is a plain string, or the question ID is
                still the default of MultipleChoiceQuestion's question_id field.
        """
        if (
            isinstance(self._query, str)
            or self._query.question_id
            == MultipleChoiceQuestion.model_fields["question_id"].default
        ):
            details = (
                ", as just a question was configured"
                if isinstance(self._query, str)
                else ", as the default ID remains present"
            )
            raise ValueError(f"No question ID was configured{details}.")
        return str(self._query.question_id)

    def __deepcopy__(self, memo) -> Self:
        """Copy the environment, deep copying state and settings but not the models."""
        copy_state = deepcopy(self.state, memo)
        # We don't know the side effects of deep copying a litellm.Router,
        # so we force a shallow copy of these LiteLLMModels
        env_model_kwargs: dict[str, Any] = {
            name: model if model is None else type(model)(**model.model_dump())
            for name, model in (
                ("llm_model", self._llm_model),
                ("summary_llm_model", self._summary_llm_model),
                ("embedding_model", self._embedding_model),
            )
        }
        copy_self = type(self)(
            query=self._query,  # No need to copy since we read only
            settings=deepcopy(self._settings, memo),  # Deepcopy just to be safe
            docs=copy_state.docs,
            sources=self.sources,
            rewards=self._rewards,
            evaluation_callback=self._evaluation_callback,
            **env_model_kwargs,
        )
        copy_self.state = copy_state
        # Because we shallow copied the LiteLLMModels, we need to re-make the
        # tool functions within the tools
        copy_self.tools = copy_self.make_tools()
        return copy_self
200
+
201
+
202
# Register GradablePaperQAEnvironment under the "paperqa-local" name, stored as
# a (module path, class name) pair
ENV_REGISTRY["paperqa-local"] = (
    GradablePaperQAEnvironment.__module__,
    GradablePaperQAEnvironment.__name__,
)
206
+
207
+
208
class ImageQAEnvironment(GradablePaperQAEnvironment):
    """Image question-answer environment useful for LAB-Bench's FigQA and TableQA."""

    @classmethod
    def make_base_settings(cls, **kwargs) -> Settings:
        """Build a settings object accounting for image-based QA restrictions."""
        # PaperQA doesn't support image embeddings yet, so embedding is deferred.
        # Doc details are disabled since inputs are just images (not a PDF with
        # metadata)
        image_parsing = ParsingSettings(defer_embedding=True, use_doc_details=False)
        image_answer = AnswerSettings(evidence_retrieval=False)
        return Settings(parsing=image_parsing, answer=image_answer, **kwargs)

    def __init__(
        self,
        *args,
        images: "bytes | Image | Sequence[bytes | Image]",
        image_paths: str | Sequence[str],
        **kwargs,
    ):
        """
        Initialize the environment with its pre-added image(s).

        Args:
            *args: Positional arguments forwarded to the parent environment.
            images: One image, or a sequence of images matching image_paths.
            image_paths: One path (a string), or a sequence of paths of the
                same length as images. Only each path's file name is kept.
            **kwargs: Keyword arguments forwarded to the parent environment.

        Raises:
            TypeError: If the query is not a MultipleChoiceQuestion.
        """
        super().__init__(*args, **kwargs)
        if not isinstance(self._query, MultipleChoiceQuestion):
            raise TypeError(
                f"{type(self).__name__} requires a {MultipleChoiceQuestion.__name__}"
                f" as the query, not {type(self._query)}."
            )
        # FigQA has 1 image with paths, TableQA has 1+ images with paths
        if isinstance(image_paths, str):  # Assume FigQA
            self._images_with_names: "list[tuple[bytes | Image, str]]" = [  # noqa: UP037
                (cast("bytes | Image", images), Path(image_paths).name)
            ]
        else:  # Assume TableQA
            image_path_pairs = zip(
                cast("Sequence[bytes | Image]", images), image_paths, strict=True
            )
            self._images_with_names = [
                (each_image, Path(each_path).name)
                for each_image, each_path in image_path_pairs
            ]

    def get_images(self) -> "list[bytes | Image]":
        """
        Return the environment's image(s) without their names.

        Helpful for recall measurement.

        NOTE: FigQA has 1 image with paths, TableQA has 1+ images with paths.
        """
        return [image_and_name[0] for image_and_name in self._images_with_names]

    async def _reset_docs(self) -> None:
        """Clear the docs, then re-add each image, when creating the initial state."""
        self._docs.clear_docs()

        # Each image is materialized as a file in a scratch directory so it
        # can be added to the docs by path
        with tempfile.TemporaryDirectory() as scratch_dir:
            scratch_root = Path(scratch_dir)
            for image, image_name in self._images_with_names:
                image_file = scratch_root / image_name
                if not isinstance(image, bytes):
                    image.save(image_file)
                else:
                    image_file.write_bytes(image)
                if isinstance(self._query, MultipleChoiceQuestion):
                    citation = (
                        f"Row ID {self._query.question_id} filename {image_file.name}"
                    )
                else:
                    citation = f"Filename {image_file.name}"
                await self._docs.aadd(
                    image_file,
                    citation=citation,
                    settings=self._settings,
                )