aviary.labbench 0.30.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aviary_labbench-0.30.0/PKG-INFO +155 -0
- aviary_labbench-0.30.0/README.md +121 -0
- aviary_labbench-0.30.0/pyproject.toml +56 -0
- aviary_labbench-0.30.0/setup.cfg +4 -0
- aviary_labbench-0.30.0/src/aviary/envs/labbench/__init__.py +31 -0
- aviary_labbench-0.30.0/src/aviary/envs/labbench/env.py +276 -0
- aviary_labbench-0.30.0/src/aviary/envs/labbench/py.typed +0 -0
- aviary_labbench-0.30.0/src/aviary/envs/labbench/task.py +546 -0
- aviary_labbench-0.30.0/src/aviary/envs/labbench/version.py +34 -0
- aviary_labbench-0.30.0/src/aviary.labbench.egg-info/PKG-INFO +155 -0
- aviary_labbench-0.30.0/src/aviary.labbench.egg-info/SOURCES.txt +19 -0
- aviary_labbench-0.30.0/src/aviary.labbench.egg-info/dependency_links.txt +1 -0
- aviary_labbench-0.30.0/src/aviary.labbench.egg-info/requires.txt +23 -0
- aviary_labbench-0.30.0/src/aviary.labbench.egg-info/top_level.txt +1 -0
- aviary_labbench-0.30.0/tests/cassettes/TestPaperQATaskDataset.test_tool_failure.yaml +313 -0
- aviary_labbench-0.30.0/tests/conftest.py +61 -0
- aviary_labbench-0.30.0/tests/stub_data/bates.txt +838 -0
- aviary_labbench-0.30.0/tests/stub_data/flag_day.html +3047 -0
- aviary_labbench-0.30.0/tests/stub_data/paper.pdf +0 -0
- aviary_labbench-0.30.0/tests/test_envs.py +188 -0
- aviary_labbench-0.30.0/tests/test_tasks.py +633 -0
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: aviary.labbench
|
|
3
|
+
Version: 0.30.0
|
|
4
|
+
Summary: LAB-Bench environments implemented with aviary
|
|
5
|
+
Author-email: FutureHouse technical staff <hello@futurehouse.org>
|
|
6
|
+
Classifier: Intended Audience :: Developers
|
|
7
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
8
|
+
Classifier: Operating System :: OS Independent
|
|
9
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
14
|
+
Classifier: Programming Language :: Python
|
|
15
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
16
|
+
Requires-Python: >=3.11
|
|
17
|
+
Description-Content-Type: text/markdown
|
|
18
|
+
Requires-Dist: fhaviary>=0.14
|
|
19
|
+
Requires-Dist: fhlmi
|
|
20
|
+
Requires-Dist: ldp>=0.25.2
|
|
21
|
+
Requires-Dist: paper-qa[pymupdf]>=2025
|
|
22
|
+
Requires-Dist: pydantic~=2.0
|
|
23
|
+
Requires-Dist: tenacity
|
|
24
|
+
Requires-Dist: typing-extensions; python_version <= "3.12"
|
|
25
|
+
Provides-Extra: datasets
|
|
26
|
+
Requires-Dist: datasets>=2.15; extra == "datasets"
|
|
27
|
+
Provides-Extra: dev
|
|
28
|
+
Requires-Dist: aviary.labbench[datasets,typing]; extra == "dev"
|
|
29
|
+
Requires-Dist: pandas; extra == "dev"
|
|
30
|
+
Requires-Dist: paper-qa>=5.29.1; extra == "dev"
|
|
31
|
+
Requires-Dist: tantivy>=0.25.0; python_version >= "3.14" and extra == "dev"
|
|
32
|
+
Provides-Extra: typing
|
|
33
|
+
Requires-Dist: pillow; extra == "typing"
|
|
34
|
+
|
|
35
|
+
# aviary.labbench
|
|
36
|
+
|
|
37
|
+
LAB-Bench environments implemented with aviary,
|
|
38
|
+
allowing agents to perform question answering on scientific tasks.
|
|
39
|
+
|
|
40
|
+
## Installation
|
|
41
|
+
|
|
42
|
+
To install the LAB-Bench environment, run:
|
|
43
|
+
|
|
44
|
+
```bash
|
|
45
|
+
pip install 'fhaviary[labbench]'
|
|
46
|
+
```
|
|
47
|
+
|
|
48
|
+
## Usage
|
|
49
|
+
|
|
50
|
+
In [`labbench/env.py`](src/aviary/envs/labbench/env.py), you will find:
|
|
51
|
+
|
|
52
|
+
- `GradablePaperQAEnvironment`: a PaperQA-backed environment
|
|
53
|
+
that can grade answers given an evaluation function.
|
|
54
|
+
- `ImageQAEnvironment`: a `GradablePaperQAEnvironment`
|
|
55
|
+
subclass for QA where image(s) are pre-added.
|
|
56
|
+
|
|
57
|
+
And in [`labbench/task.py`](src/aviary/envs/labbench/task.py), you will find:
|
|
58
|
+
|
|
59
|
+
- `TextQATaskDataset`: a task dataset designed to
|
|
60
|
+
pull down FigQA, LitQA2, or TableQA from Hugging Face,
|
|
61
|
+
and create one `GradablePaperQAEnvironment` per question.
|
|
62
|
+
- `ImageQATaskDataset`: a task dataset that pairs with `ImageQAEnvironment`
|
|
63
|
+
for FigQA or TableQA.
|
|
64
|
+
|
|
65
|
+
Here is an example of how to use them:
|
|
66
|
+
|
|
67
|
+
```python
|
|
68
|
+
import os
|
|
69
|
+
|
|
70
|
+
from ldp.agent import SimpleAgent
|
|
71
|
+
from ldp.alg import Evaluator, EvaluatorConfig, MeanMetricsCallback
|
|
72
|
+
from paperqa import Settings
|
|
73
|
+
|
|
74
|
+
from aviary.env import TaskDataset
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
async def evaluate(folder_of_litqa_v2_papers: str | os.PathLike) -> None:
|
|
78
|
+
settings = Settings(paper_directory=folder_of_litqa_v2_papers)
|
|
79
|
+
dataset = TaskDataset.from_name("litqa2", settings=settings)
|
|
80
|
+
metrics_callback = MeanMetricsCallback(eval_dataset=dataset)
|
|
81
|
+
|
|
82
|
+
evaluator = Evaluator(
|
|
83
|
+
config=EvaluatorConfig(batch_size=3),
|
|
84
|
+
agent=SimpleAgent(),
|
|
85
|
+
dataset=dataset,
|
|
86
|
+
callbacks=[metrics_callback],
|
|
87
|
+
)
|
|
88
|
+
await evaluator.evaluate()
|
|
89
|
+
print(metrics_callback.eval_means)
|
|
90
|
+
```
|
|
91
|
+
|
|
92
|
+
### Image Question-Answer
|
|
93
|
+
|
|
94
|
+
This is an environment/dataset for giving PaperQA a `Docs` object with
|
|
95
|
+
the image(s) for one LAB-Bench question.
|
|
96
|
+
It's designed to be a comparison with zero-shotting the question to an LLM,
|
|
97
|
+
but instead of a single prompt, the image is put through the PaperQA agent loop.
|
|
98
|
+
|
|
99
|
+
```python
|
|
100
|
+
from typing import cast
|
|
101
|
+
|
|
102
|
+
import litellm
|
|
103
|
+
import pytest
|
|
104
|
+
from ldp.agent import Agent
|
|
105
|
+
from ldp.alg import (
|
|
106
|
+
Evaluator,
|
|
107
|
+
EvaluatorConfig,
|
|
108
|
+
MeanMetricsCallback,
|
|
109
|
+
StoreTrajectoriesCallback,
|
|
110
|
+
)
|
|
111
|
+
from paperqa.settings import AgentSettings, IndexSettings
|
|
112
|
+
|
|
113
|
+
from aviary.envs.labbench import (
|
|
114
|
+
ImageQAEnvironment,
|
|
115
|
+
ImageQATaskDataset,
|
|
116
|
+
LABBenchDatasets,
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@pytest.mark.asyncio
|
|
121
|
+
async def test_image_qa(tmp_path) -> None:
|
|
122
|
+
litellm.num_retries = 8 # Mitigate connection-related failures
|
|
123
|
+
settings = ImageQAEnvironment.make_base_settings()
|
|
124
|
+
settings.agent = AgentSettings(
|
|
125
|
+
agent_type="ldp.agent.SimpleAgent",
|
|
126
|
+
index=IndexSettings(paper_directory=tmp_path),
|
|
127
|
+
# TODO: add image support for paper_search
|
|
128
|
+
tool_names={"gather_evidence", "gen_answer", "complete", "reset"},
|
|
129
|
+
agent_evidence_n=3, # Bumped up to collect several perspectives
|
|
130
|
+
)
|
|
131
|
+
dataset = ImageQATaskDataset(dataset=LABBenchDatasets.TABLE_QA, settings=settings)
|
|
132
|
+
t_cb = StoreTrajectoriesCallback()
|
|
133
|
+
m_cb = MeanMetricsCallback(eval_dataset=dataset, track_tool_usage=True)
|
|
134
|
+
evaluator = Evaluator(
|
|
135
|
+
config=EvaluatorConfig(
|
|
136
|
+
batch_size=256, # Use batch size greater than FigQA size and TableQA size
|
|
137
|
+
max_rollout_steps=18, # Match aviary paper's PaperQA setting
|
|
138
|
+
),
|
|
139
|
+
agent=cast(Agent, await settings.make_ldp_agent(settings.agent.agent_type)),
|
|
140
|
+
dataset=dataset,
|
|
141
|
+
callbacks=[t_cb, m_cb],
|
|
142
|
+
)
|
|
143
|
+
await evaluator.evaluate()
|
|
144
|
+
print(m_cb.eval_means)
|
|
145
|
+
```
|
|
146
|
+
|
|
147
|
+
## References
|
|
148
|
+
|
|
149
|
+
[1] Skarlinski et al.
|
|
150
|
+
[Language agents achieve superhuman synthesis of scientific knowledge](https://arxiv.org/abs/2409.13740).
|
|
151
|
+
ArXiv:2409.13740, 2024.
|
|
152
|
+
|
|
153
|
+
[2] Laurent et al.
|
|
154
|
+
[LAB-Bench: Measuring Capabilities of Language Models for Biology Research](https://arxiv.org/abs/2407.10362).
|
|
155
|
+
ArXiv:2407.10362, 2024.
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
# aviary.labbench
|
|
2
|
+
|
|
3
|
+
LAB-Bench environments implemented with aviary,
|
|
4
|
+
allowing agents to perform question answering on scientific tasks.
|
|
5
|
+
|
|
6
|
+
## Installation
|
|
7
|
+
|
|
8
|
+
To install the LAB-Bench environment, run:
|
|
9
|
+
|
|
10
|
+
```bash
|
|
11
|
+
pip install 'fhaviary[labbench]'
|
|
12
|
+
```
|
|
13
|
+
|
|
14
|
+
## Usage
|
|
15
|
+
|
|
16
|
+
In [`labbench/env.py`](src/aviary/envs/labbench/env.py), you will find:
|
|
17
|
+
|
|
18
|
+
- `GradablePaperQAEnvironment`: a PaperQA-backed environment
|
|
19
|
+
that can grade answers given an evaluation function.
|
|
20
|
+
- `ImageQAEnvironment`: a `GradablePaperQAEnvironment`
|
|
21
|
+
subclass for QA where image(s) are pre-added.
|
|
22
|
+
|
|
23
|
+
And in [`labbench/task.py`](src/aviary/envs/labbench/task.py), you will find:
|
|
24
|
+
|
|
25
|
+
- `TextQATaskDataset`: a task dataset designed to
|
|
26
|
+
pull down FigQA, LitQA2, or TableQA from Hugging Face,
|
|
27
|
+
and create one `GradablePaperQAEnvironment` per question.
|
|
28
|
+
- `ImageQATaskDataset`: a task dataset that pairs with `ImageQAEnvironment`
|
|
29
|
+
for FigQA or TableQA.
|
|
30
|
+
|
|
31
|
+
Here is an example of how to use them:
|
|
32
|
+
|
|
33
|
+
```python
|
|
34
|
+
import os
|
|
35
|
+
|
|
36
|
+
from ldp.agent import SimpleAgent
|
|
37
|
+
from ldp.alg import Evaluator, EvaluatorConfig, MeanMetricsCallback
|
|
38
|
+
from paperqa import Settings
|
|
39
|
+
|
|
40
|
+
from aviary.env import TaskDataset
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
async def evaluate(folder_of_litqa_v2_papers: str | os.PathLike) -> None:
|
|
44
|
+
settings = Settings(paper_directory=folder_of_litqa_v2_papers)
|
|
45
|
+
dataset = TaskDataset.from_name("litqa2", settings=settings)
|
|
46
|
+
metrics_callback = MeanMetricsCallback(eval_dataset=dataset)
|
|
47
|
+
|
|
48
|
+
evaluator = Evaluator(
|
|
49
|
+
config=EvaluatorConfig(batch_size=3),
|
|
50
|
+
agent=SimpleAgent(),
|
|
51
|
+
dataset=dataset,
|
|
52
|
+
callbacks=[metrics_callback],
|
|
53
|
+
)
|
|
54
|
+
await evaluator.evaluate()
|
|
55
|
+
print(metrics_callback.eval_means)
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
### Image Question-Answer
|
|
59
|
+
|
|
60
|
+
This is an environment/dataset for giving PaperQA a `Docs` object with
|
|
61
|
+
the image(s) for one LAB-Bench question.
|
|
62
|
+
It's designed to be a comparison with zero-shotting the question to an LLM,
|
|
63
|
+
but instead of a single prompt, the image is put through the PaperQA agent loop.
|
|
64
|
+
|
|
65
|
+
```python
|
|
66
|
+
from typing import cast
|
|
67
|
+
|
|
68
|
+
import litellm
|
|
69
|
+
import pytest
|
|
70
|
+
from ldp.agent import Agent
|
|
71
|
+
from ldp.alg import (
|
|
72
|
+
Evaluator,
|
|
73
|
+
EvaluatorConfig,
|
|
74
|
+
MeanMetricsCallback,
|
|
75
|
+
StoreTrajectoriesCallback,
|
|
76
|
+
)
|
|
77
|
+
from paperqa.settings import AgentSettings, IndexSettings
|
|
78
|
+
|
|
79
|
+
from aviary.envs.labbench import (
|
|
80
|
+
ImageQAEnvironment,
|
|
81
|
+
ImageQATaskDataset,
|
|
82
|
+
LABBenchDatasets,
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@pytest.mark.asyncio
|
|
87
|
+
async def test_image_qa(tmp_path) -> None:
|
|
88
|
+
litellm.num_retries = 8 # Mitigate connection-related failures
|
|
89
|
+
settings = ImageQAEnvironment.make_base_settings()
|
|
90
|
+
settings.agent = AgentSettings(
|
|
91
|
+
agent_type="ldp.agent.SimpleAgent",
|
|
92
|
+
index=IndexSettings(paper_directory=tmp_path),
|
|
93
|
+
# TODO: add image support for paper_search
|
|
94
|
+
tool_names={"gather_evidence", "gen_answer", "complete", "reset"},
|
|
95
|
+
agent_evidence_n=3, # Bumped up to collect several perspectives
|
|
96
|
+
)
|
|
97
|
+
dataset = ImageQATaskDataset(dataset=LABBenchDatasets.TABLE_QA, settings=settings)
|
|
98
|
+
t_cb = StoreTrajectoriesCallback()
|
|
99
|
+
m_cb = MeanMetricsCallback(eval_dataset=dataset, track_tool_usage=True)
|
|
100
|
+
evaluator = Evaluator(
|
|
101
|
+
config=EvaluatorConfig(
|
|
102
|
+
batch_size=256, # Use batch size greater than FigQA size and TableQA size
|
|
103
|
+
max_rollout_steps=18, # Match aviary paper's PaperQA setting
|
|
104
|
+
),
|
|
105
|
+
agent=cast(Agent, await settings.make_ldp_agent(settings.agent.agent_type)),
|
|
106
|
+
dataset=dataset,
|
|
107
|
+
callbacks=[t_cb, m_cb],
|
|
108
|
+
)
|
|
109
|
+
await evaluator.evaluate()
|
|
110
|
+
print(m_cb.eval_means)
|
|
111
|
+
```
|
|
112
|
+
|
|
113
|
+
## References
|
|
114
|
+
|
|
115
|
+
[1] Skarlinski et al.
|
|
116
|
+
[Language agents achieve superhuman synthesis of scientific knowledge](https://arxiv.org/abs/2409.13740).
|
|
117
|
+
ArXiv:2409.13740, 2024.
|
|
118
|
+
|
|
119
|
+
[2] Laurent et al.
|
|
120
|
+
[LAB-Bench: Measuring Capabilities of Language Models for Biology Research](https://arxiv.org/abs/2407.10362).
|
|
121
|
+
ArXiv:2407.10362, 2024.
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
build-backend = "setuptools.build_meta"
|
|
3
|
+
requires = ["setuptools>=64", "setuptools_scm>=8"]
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
authors = [
|
|
7
|
+
{email = "hello@futurehouse.org", name = "FutureHouse technical staff"},
|
|
8
|
+
]
|
|
9
|
+
classifiers = [
|
|
10
|
+
"Intended Audience :: Developers",
|
|
11
|
+
"License :: OSI Approved :: Apache Software License",
|
|
12
|
+
"Operating System :: OS Independent",
|
|
13
|
+
"Programming Language :: Python :: 3 :: Only",
|
|
14
|
+
"Programming Language :: Python :: 3.11",
|
|
15
|
+
"Programming Language :: Python :: 3.12",
|
|
16
|
+
"Programming Language :: Python :: 3.13",
|
|
17
|
+
"Programming Language :: Python :: 3.14",
|
|
18
|
+
"Programming Language :: Python",
|
|
19
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
20
|
+
]
|
|
21
|
+
dependencies = [
|
|
22
|
+
"fhaviary>=0.14", # For MultipleChoiceQuestion
|
|
23
|
+
"fhlmi",
|
|
24
|
+
"ldp>=0.25.2", # Pin for lmi migration
|
|
25
|
+
"paper-qa[pymupdf]>=2025", # Pin for multimodal
|
|
26
|
+
"pydantic~=2.0",
|
|
27
|
+
"tenacity",
|
|
28
|
+
"typing-extensions; python_version <= '3.12'", # For TypeVar default
|
|
29
|
+
]
|
|
30
|
+
description = "LAB-Bench environments implemented with aviary"
|
|
31
|
+
dynamic = ["version"]
|
|
32
|
+
name = "aviary.labbench"
|
|
33
|
+
readme = "README.md"
|
|
34
|
+
requires-python = ">=3.11"
|
|
35
|
+
|
|
36
|
+
[project.optional-dependencies]
|
|
37
|
+
datasets = [
|
|
38
|
+
"datasets>=2.15", # Lower pin for https://github.com/huggingface/datasets/pull/6404
|
|
39
|
+
]
|
|
40
|
+
dev = [
|
|
41
|
+
"aviary.labbench[datasets,typing]",
|
|
42
|
+
"pandas",
|
|
43
|
+
"paper-qa>=5.29.1", # Pin for gen_answer's EmptyDocsError, with fix
|
|
44
|
+
"tantivy>=0.25.0; python_version >= '3.14'", # For Python 3.14 support
|
|
45
|
+
]
|
|
46
|
+
typing = ["pillow"]
|
|
47
|
+
|
|
48
|
+
[tool.ruff]
|
|
49
|
+
extend = "../../pyproject.toml"
|
|
50
|
+
|
|
51
|
+
[tool.setuptools.packages.find]
|
|
52
|
+
where = ["src"]
|
|
53
|
+
|
|
54
|
+
[tool.setuptools_scm]
|
|
55
|
+
root = "../.."
|
|
56
|
+
version_file = "src/aviary/envs/labbench/version.py"
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from .env import (
|
|
2
|
+
DEFAULT_REWARD_MAPPING,
|
|
3
|
+
GradablePaperQAEnvironment,
|
|
4
|
+
ImageQAEnvironment,
|
|
5
|
+
make_discounted_returns,
|
|
6
|
+
)
|
|
7
|
+
from .task import (
|
|
8
|
+
DEFAULT_AVIARY_PAPER_HF_HUB_NAME,
|
|
9
|
+
DEFAULT_LABBENCH_HF_HUB_NAME,
|
|
10
|
+
ImageQATaskDataset,
|
|
11
|
+
LABBenchDatasets,
|
|
12
|
+
PaperQATaskDataset,
|
|
13
|
+
TextQATaskDataset,
|
|
14
|
+
TextQATaskSplit,
|
|
15
|
+
read_ds_from_hub,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
# Public API of `aviary.envs.labbench`, re-exported from the `env` and `task`
# submodules so callers can import everything from the package root
__all__ = [
    "DEFAULT_AVIARY_PAPER_HF_HUB_NAME",
    "DEFAULT_LABBENCH_HF_HUB_NAME",
    "DEFAULT_REWARD_MAPPING",
    "GradablePaperQAEnvironment",
    "ImageQAEnvironment",
    "ImageQATaskDataset",
    "LABBenchDatasets",
    "PaperQATaskDataset",
    "TextQATaskDataset",
    "TextQATaskSplit",
    "make_discounted_returns",
    "read_ds_from_hub",
]
|
|
@@ -0,0 +1,276 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import sys
|
|
3
|
+
import tempfile
|
|
4
|
+
from collections.abc import Awaitable, Callable, Mapping, Sequence
|
|
5
|
+
from copy import deepcopy
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Generic, Self, cast
|
|
8
|
+
from uuid import UUID
|
|
9
|
+
|
|
10
|
+
from aviary.core import (
|
|
11
|
+
Messages,
|
|
12
|
+
MultipleChoiceEvaluation,
|
|
13
|
+
MultipleChoiceQuestion,
|
|
14
|
+
ToolRequestMessage,
|
|
15
|
+
)
|
|
16
|
+
from aviary.env import ENV_REGISTRY
|
|
17
|
+
from ldp.utils import discounted_returns
|
|
18
|
+
from lmi import EmbeddingModel, LiteLLMModel
|
|
19
|
+
from paperqa.agents.env import POPULATE_FROM_SETTINGS, PaperQAEnvironment
|
|
20
|
+
from paperqa.agents.search import SearchIndex, maybe_get_manifest
|
|
21
|
+
from paperqa.docs import Docs
|
|
22
|
+
from paperqa.settings import AnswerSettings, ParsingSettings, Settings
|
|
23
|
+
|
|
24
|
+
if TYPE_CHECKING:
|
|
25
|
+
from PIL.Image import Image
|
|
26
|
+
|
|
27
|
+
if sys.version_info >= (3, 13):
|
|
28
|
+
from typing import TypeVar
|
|
29
|
+
else:
|
|
30
|
+
from typing_extensions import TypeVar # For TypeVar.default backport
|
|
31
|
+
|
|
32
|
+
logger = logging.getLogger(__name__)

# Type of the evaluation produced by grading, passed to `evaluation_callback`.
# Defaults to aviary's MultipleChoiceEvaluation; the `default=` argument needs
# Python 3.13+ `typing.TypeVar` or the `typing_extensions` backport.
TEvaluation = TypeVar("TEvaluation", default=MultipleChoiceEvaluation)

# Terminal reward per grading outcome, keyed by `MultipleChoiceEvaluation.value`
DEFAULT_REWARD_MAPPING = {"correct": 1.0, "unsure": 0.1, "incorrect": -1.0}
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def make_discounted_returns(
    evaluation: MultipleChoiceEvaluation,
    num_steps: int,
    rewards: Mapping[str, float] = DEFAULT_REWARD_MAPPING,
    discount: float = 1.0,
) -> list[float]:
    """Compute discounted returns for a PaperQA trajectory given its final grade.

    PaperQA has no intermediary rewards, so every step but the last gets a reward
    of 0, and the last (terminal) step gets the reward mapped from the evaluation.

    Args:
        evaluation: Grading outcome of the trajectory's final answer.
        num_steps: Number of steps in the trajectory, must be at least 1.
        rewards: Mapping from evaluation value to terminal reward.
        discount: Discount factor forwarded to ldp's `discounted_returns`.

    Returns:
        The result of ldp's `discounted_returns` over the `num_steps` rewards
        built here.

    Raises:
        ValueError: If `num_steps` is less than 1.
    """
    if num_steps < 1:
        # Without this guard, num_steps <= 0 would make the zero-reward prefix
        # empty and silently produce a single terminal return
        raise ValueError(f"num_steps must be at least 1, got {num_steps}.")
    num_intermediary = num_steps - 1
    return discounted_returns(
        # paper-qa has no intermediary rewards
        [0.0] * num_intermediary + [rewards[evaluation.value]],
        terminated=[False] * num_intermediary + [True],
        discount=discount,
    )
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class GradablePaperQAEnvironment(PaperQAEnvironment, Generic[TEvaluation]):
    """Extended environment that can grade answers.

    When the query is a `MultipleChoiceQuestion`, the final step's reward has the
    reward for the grading outcome (looked up in `rewards`) added to it.
    """

    def __init__(
        self,
        query: str | MultipleChoiceQuestion,
        settings: Settings,
        docs: Docs,
        llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
        summary_llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
        embedding_model: EmbeddingModel | None = POPULATE_FROM_SETTINGS,
        session_id: UUID | None = None,
        sources: str | list[str] | None = None,
        rewards: Mapping[str, float] = DEFAULT_REWARD_MAPPING,
        evaluation_callback: Callable[[TEvaluation], Awaitable] | None = None,
        **env_kwargs,
    ):
        """Initialize the environment.

        Args:
            query: Question to answer. Only a `MultipleChoiceQuestion` can be graded.
            settings: PaperQA settings.
            docs: Docs object backing the environment.
            llm_model: Optional LLM, defaults to the `POPULATE_FROM_SETTINGS` sentinel.
            summary_llm_model: Optional summary LLM, defaults to the
                `POPULATE_FROM_SETTINGS` sentinel.
            embedding_model: Optional embedding model, defaults to the
                `POPULATE_FROM_SETTINGS` sentinel.
            session_id: Optional session ID forwarded to `PaperQAEnvironment`.
            sources: Optional file name(s) or DOI(s) expected to be present in the
                manifest or index, checked by `validate_sources`.
            rewards: Mapping from evaluation value to the reward added on the
                final step.
            evaluation_callback: Optional async callback awaited with the
                evaluation once the episode is done.
            **env_kwargs: Extra keyword arguments forwarded to `PaperQAEnvironment`,
                also re-applied when deep copying.
        """
        super().__init__(
            query,
            settings,
            docs,
            llm_model,
            summary_llm_model,
            embedding_model,
            session_id,
            **env_kwargs,
        )
        # Enables checking an Index has the right DOI(s)
        self.sources: list[str] | None = (
            [sources] if isinstance(sources, str) else sources
        )
        self._evaluation_callback = evaluation_callback
        self._rewards = rewards
        # Kept so __deepcopy__ can rebuild an equivalently-configured environment
        self._env_kwargs = env_kwargs

    async def validate_sources(
        self, manifest_or_index: dict[str, dict[str, Any]] | SearchIndex | None = None
    ) -> None:
        """Validate the sources can be found in the input manifest or index.

        A source is found if it equals a file name in the manifest or index, or
        (for a manifest only) case-insensitively equals one of its DOIs.

        Args:
            manifest_or_index: Manifest mapping or search index to check against.
                If None, the manifest file from the settings is loaded.

        Raises:
            ValueError: If any source is not found. If the manifest or index has
                no file names, a warning is logged and validation is skipped.
        """
        if not self.sources:
            return
        if manifest_or_index is None:  # Let's try to load in the manifest
            manifest_or_index = await maybe_get_manifest(
                filename=await self._settings.agent.index.finalize_manifest_file()
            )
        if isinstance(manifest_or_index, SearchIndex):
            entity: str = "index"
            file_names: set[str] = {k for k in await manifest_or_index.index_files if k}
            lowercased_dois: set[str] = set()
        else:
            entity = "manifest"
            file_names = {k for k in manifest_or_index if k}
            # Use .get so entries lacking a "doi" field don't raise a KeyError
            lowercased_dois = {
                v["doi"].lower() for v in manifest_or_index.values() if v.get("doi")
            }
        if not file_names:  # File names being empty means something's wrong
            logger.warning(
                f"Can't validate sources {self.sources} without a correctly specified"
                f" {entity}."
            )
            return
        not_found = [
            s
            for s in self.sources
            if s not in file_names and s.lower() not in lowercased_dois
        ]
        if not_found:
            question = (
                self._query
                if isinstance(self._query, str)
                else self._query.question_prompt
            )
            raise ValueError(
                f"Sources {not_found} of {self.sources} not found in the {entity},"
                f" the corresponding query was {question!r}."
            )

    async def _evaluate_answer(self) -> TEvaluation:
        """Grade the session's answer, storing the graded answer on the session."""
        # If the ensuring evaluation fails (e.g. due to OpenAI being down), we can:
        # - Suppress the exception and declare the evaluation as incorrect, which can
        #   negatively reward what otherwise was a good trajectory containing a correct
        #   answer. We don't want "bad" offline data, so it's not what we do.
        # - Suppress the exception and just give super()'s reward, but again this could
        #   incorrectly reward what otherwise was a good trajectory.
        # - Don't suppress the exception, which leads to the trajectory failing, and
        #   removes it from the learnable pool. This is the only safe default behavior.
        evaluation, self.state.session.graded_answer = await cast(
            "MultipleChoiceQuestion", self._query
        ).grade(self.state.session.answer)
        return evaluation  # type: ignore[return-value]

    async def step(
        self, action: ToolRequestMessage
    ) -> tuple[Messages, float, bool, bool]:
        """Step the environment, grading the answer once the episode is done.

        When done and the query is a `MultipleChoiceQuestion`, the answer is graded,
        the evaluation callback (if any) is awaited, and the reward mapped from the
        evaluation is added to the parent's reward.
        """
        messages, reward, done, truncated = await super().step(action)
        if not done or not isinstance(self._query, MultipleChoiceQuestion):
            return messages, reward, done, truncated
        evaluation = await self._evaluate_answer()
        if evaluation_callback := self._evaluation_callback:
            await evaluation_callback(evaluation)

        return (
            messages,
            reward + self._rewards[cast("MultipleChoiceEvaluation", evaluation).value],
            done,
            truncated,
        )

    async def get_id(self) -> str:
        """Get the question ID of the query.

        Raises:
            ValueError: If the query is a plain string, or if it's a
                `MultipleChoiceQuestion` still holding the default question ID.
        """
        if (
            isinstance(self._query, str)
            or self._query.question_id
            == MultipleChoiceQuestion.model_fields["question_id"].default
        ):
            details = (
                ", as just a question was configured"
                if isinstance(self._query, str)
                else ", as the default ID remains present"
            )
            raise ValueError(f"No question ID was configured{details}.")
        return str(self._query.question_id)

    def __deepcopy__(self, memo) -> Self:
        """Deep copy the state and settings, shallow copying the LLM/embedding models."""
        copy_state = deepcopy(self.state, memo)
        # We don't know the side effects of deep copying a litellm.Router,
        # so we force a shallow copy of these LiteLLMModels
        env_model_kwargs: dict[str, Any] = {
            name: model if model is None else type(model)(**model.model_dump())
            for name, model in (
                ("llm_model", self._llm_model),
                ("summary_llm_model", self._summary_llm_model),
                ("embedding_model", self._embedding_model),
            )
        }
        copy_self = type(self)(
            query=self._query,  # No need to copy since we read only
            settings=deepcopy(self._settings, memo),  # Deepcopy just to be safe
            docs=copy_state.docs,
            sources=self.sources,
            rewards=self._rewards,
            evaluation_callback=self._evaluation_callback,
            **env_model_kwargs,
            # Re-apply the extra constructor kwargs so the copy is configured alike
            **self._env_kwargs,
        )
        copy_self.state = copy_state
        # Because we shallow copied the LiteLLMModels, we need to re-make the
        # tool functions within the tools
        copy_self.tools = copy_self.make_tools()
        return copy_self
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
# Register GradablePaperQAEnvironment in aviary's ENV_REGISTRY under the name
# "paperqa-local", as a (module path, class name) pair. NOTE(review): presumably
# this enables name-based lookup in aviary (e.g. `Environment.from_name`); verify.
ENV_REGISTRY["paperqa-local"] = (
    GradablePaperQAEnvironment.__module__,
    GradablePaperQAEnvironment.__name__,
)
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
class ImageQAEnvironment(GradablePaperQAEnvironment):
    """Image question-answer environment useful for LAB-Bench's FigQA and TableQA.

    The question's image(s) are written to a temporary directory and added to the
    `Docs` in the `_reset_docs` hook, after clearing any docs already present.
    """

    @classmethod
    def make_base_settings(cls, **kwargs) -> Settings:
        """Make a settings object that takes into account image-based QA restrictions.

        Args:
            **kwargs: Extra keyword arguments forwarded to `Settings`. Passing
                `parsing` or `answer` raises a `TypeError`, as they're set here.

        Returns:
            Settings with embedding deferred, doc details disabled, and evidence
            retrieval disabled.
        """
        return Settings(
            # PaperQA doesn't support image embeddings yet, so disable embedding
            # Disable doc details since we just have images here (not a PDF with metadata)
            parsing=ParsingSettings(defer_embedding=True, use_doc_details=False),
            answer=AnswerSettings(evidence_retrieval=False),
            **kwargs,
        )

    def __init__(
        self,
        *args,
        images: "bytes | Image | Sequence[bytes | Image]",
        image_paths: str | Sequence[str],
        **kwargs,
    ):
        """Initialize with the image(s) to pre-add to the docs.

        Args:
            *args: Positional arguments forwarded to `GradablePaperQAEnvironment`.
            images: One image (FigQA) or a sequence of images (TableQA), each as
                raw bytes or a PIL image.
            image_paths: Path(s) of the image(s), of which only the file name is
                kept. A single string pairs with a single image, and a sequence
                pairs element-wise with a sequence of images.
            **kwargs: Keyword arguments forwarded to `GradablePaperQAEnvironment`.

        Raises:
            TypeError: If the query is not a `MultipleChoiceQuestion`.
            ValueError: If `images` and `image_paths` sequences differ in length.
        """
        super().__init__(*args, **kwargs)
        if not isinstance(self._query, MultipleChoiceQuestion):
            raise TypeError(
                f"{type(self).__name__} requires a {MultipleChoiceQuestion.__name__}"
                f" as the query, not {type(self._query)}."
            )
        # FigQA has 1 image with paths, TableQA has 1+ images with paths
        if not isinstance(image_paths, str):  # Assume TableQA
            self._images_with_names: "list[tuple[bytes | Image, str]]" = [  # noqa: UP037
                (image, Path(image_path).name)
                for image, image_path in zip(
                    cast("Sequence[bytes | Image]", images), image_paths, strict=True
                )
            ]
        else:  # Assume FigQA
            self._images_with_names = [
                (cast("bytes | Image", images), Path(image_paths).name)
            ]

    def get_images(self) -> "list[bytes | Image]":
        """
        Get the image(s) used in the environment, helpful for recall measurement.

        NOTE: FigQA has 1 image with paths, TableQA has 1+ images with paths.
        """
        return [image for image, _ in self._images_with_names]

    def __deepcopy__(self, memo) -> Self:
        """Deep copy the environment, preserving its image(s).

        The inherited `__deepcopy__` re-constructs via `type(self)(...)` without
        `images`/`image_paths`, which are required keyword-only arguments here. So
        this override re-supplies them from the stored (image, name) pairs. The
        image objects themselves are shared with the copy, not duplicated.
        """
        copy_state = deepcopy(self.state, memo)
        # Mirror the parent: we don't know the side effects of deep copying a
        # litellm.Router, so force a shallow copy of the models
        env_model_kwargs: dict[str, Any] = {
            name: model if model is None else type(model)(**model.model_dump())
            for name, model in (
                ("llm_model", self._llm_model),
                ("summary_llm_model", self._summary_llm_model),
                ("embedding_model", self._embedding_model),
            )
        }
        copy_self = type(self)(
            query=self._query,  # No need to copy since we read only
            settings=deepcopy(self._settings, memo),  # Deepcopy just to be safe
            docs=copy_state.docs,
            sources=self.sources,
            rewards=self._rewards,
            evaluation_callback=self._evaluation_callback,
            # Passing sequences always takes __init__'s multi-image branch, which
            # reproduces the same (image, name) pairs for both FigQA and TableQA
            images=[image for image, _ in self._images_with_names],
            image_paths=[name for _, name in self._images_with_names],
            **env_model_kwargs,
        )
        copy_self.state = copy_state
        # Because we shallow copied the models, re-make the tool functions
        copy_self.tools = copy_self.make_tools()
        return copy_self

    async def _reset_docs(self) -> None:
        """Hook to reset the docs when creating the initial state."""
        self._docs.clear_docs()

        # Now add the image(s) to the docs
        with tempfile.TemporaryDirectory() as tmpdir:
            for image, image_name in self._images_with_names:
                tmp_image_path = Path(tmpdir) / image_name
                if isinstance(image, bytes):
                    tmp_image_path.write_bytes(image)
                else:
                    image.save(tmp_image_path)
                await self._docs.aadd(
                    tmp_image_path,
                    citation=(
                        f"Row ID {self._query.question_id} filename {tmp_image_path.name}"
                        if isinstance(self._query, MultipleChoiceQuestion)
                        else f"Filename {tmp_image_path.name}"
                    ),
                    settings=self._settings,
                )
|
|
File without changes
|