query-agent-benchmarking 0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- query_agent_benchmarking-0.1/LICENSE +27 -0
- query_agent_benchmarking-0.1/PKG-INFO +17 -0
- query_agent_benchmarking-0.1/README.md +15 -0
- query_agent_benchmarking-0.1/pyproject.toml +26 -0
- query_agent_benchmarking-0.1/query_agent_benchmarking/__init__.py +22 -0
- query_agent_benchmarking-0.1/query_agent_benchmarking/agent.py +176 -0
- query_agent_benchmarking-0.1/query_agent_benchmarking/benchmark_run.py +306 -0
- query_agent_benchmarking-0.1/query_agent_benchmarking/config.py +21 -0
- query_agent_benchmarking-0.1/query_agent_benchmarking/create_benchmark.py +147 -0
- query_agent_benchmarking-0.1/query_agent_benchmarking/database.py +211 -0
- query_agent_benchmarking-0.1/query_agent_benchmarking/dataset.py +232 -0
- query_agent_benchmarking-0.1/query_agent_benchmarking/metrics/ir_metrics.py +252 -0
- query_agent_benchmarking-0.1/query_agent_benchmarking/models.py +24 -0
- query_agent_benchmarking-0.1/query_agent_benchmarking/populate_db.py +31 -0
- query_agent_benchmarking-0.1/query_agent_benchmarking/query_agent_benchmark.py +378 -0
- query_agent_benchmarking-0.1/query_agent_benchmarking/utils.py +49 -0
- query_agent_benchmarking-0.1/query_agent_benchmarking.egg-info/PKG-INFO +17 -0
- query_agent_benchmarking-0.1/query_agent_benchmarking.egg-info/SOURCES.txt +21 -0
- query_agent_benchmarking-0.1/query_agent_benchmarking.egg-info/dependency_links.txt +1 -0
- query_agent_benchmarking-0.1/query_agent_benchmarking.egg-info/requires.txt +11 -0
- query_agent_benchmarking-0.1/query_agent_benchmarking.egg-info/top_level.txt +1 -0
- query_agent_benchmarking-0.1/setup.cfg +4 -0
- query_agent_benchmarking-0.1/setup.py +18 -0
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
Copyright (c) 2020-2024, Weaviate B.V.
|
|
2
|
+
All rights reserved.
|
|
3
|
+
|
|
4
|
+
Redistribution and use in source and binary forms, with or without
|
|
5
|
+
modification, are permitted provided that the following conditions are met:
|
|
6
|
+
|
|
7
|
+
1. Redistributions of source code must retain the above copyright notice, this
|
|
8
|
+
list of conditions and the following disclaimer.
|
|
9
|
+
|
|
10
|
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
|
11
|
+
this list of conditions and the following disclaimer in the documentation
|
|
12
|
+
and/or other materials provided with the distribution.
|
|
13
|
+
|
|
14
|
+
3. Neither the name of the copyright holder nor the names of its
|
|
15
|
+
contributors may be used to endorse or promote products derived from
|
|
16
|
+
this software without specific prior written permission.
|
|
17
|
+
|
|
18
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
19
|
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
20
|
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
21
|
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
22
|
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
23
|
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
24
|
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
25
|
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
26
|
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
27
|
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: query-agent-benchmarking
|
|
3
|
+
Version: 0.1
|
|
4
|
+
Requires-Python: >=3.9
|
|
5
|
+
License-File: LICENSE
|
|
6
|
+
Requires-Dist: dspy>=2.6.27
|
|
7
|
+
Requires-Dist: sentence-transformers>=5.0.0
|
|
8
|
+
Requires-Dist: weaviate-client[agents]>=4.15.4
|
|
9
|
+
Requires-Dist: weaviate-agents>=1.0.0
|
|
10
|
+
Requires-Dist: pandas>=2.3.1
|
|
11
|
+
Requires-Dist: datasets>=4.0.0
|
|
12
|
+
Requires-Dist: ir-datasets>=0.5.11
|
|
13
|
+
Requires-Dist: pip>=25.2
|
|
14
|
+
Requires-Dist: setuptools>=80.9.0
|
|
15
|
+
Requires-Dist: wheel>=0.45.1
|
|
16
|
+
Requires-Dist: twine>=6.2.0
|
|
17
|
+
Dynamic: license-file
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# Query Agent Benchmarking
|
|
2
|
+
|
|
3
|
+
This repo contains a tool for benchmarking the performance of the Weaviate Query Agent.
|
|
4
|
+
|
|
5
|
+
Populate Weaviate with benchmark data:
|
|
6
|
+
```
|
|
7
|
+
uv run python benchmarker/populate-db.py
|
|
8
|
+
```
|
|
9
|
+
|
|
10
|
+
Run eval:
|
|
11
|
+
```
|
|
12
|
+
uv run python benchmarker/benchmark-run.py
|
|
13
|
+
```
|
|
14
|
+
|
|
15
|
+
See `benchmarker/config.yml` to change the dataset populated in your Weaviate instance, to ablate `hybrid-search` against `query-agent-search-only`, and to adjust the number of samples and concurrency parameters.
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=70", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "query-agent-benchmarking"
|
|
7
|
+
version = "0.1"
|
|
8
|
+
requires-python = ">=3.9"
|
|
9
|
+
dependencies = [
|
|
10
|
+
"dspy>=2.6.27",
|
|
11
|
+
"sentence-transformers>=5.0.0",
|
|
12
|
+
"weaviate-client[agents]>=4.15.4",
|
|
13
|
+
"weaviate-agents>=1.0.0",
|
|
14
|
+
"pandas>=2.3.1",
|
|
15
|
+
"datasets>=4.0.0",
|
|
16
|
+
"ir-datasets>=0.5.11",
|
|
17
|
+
"pip>=25.2",
|
|
18
|
+
"setuptools>=80.9.0",
|
|
19
|
+
"wheel>=0.45.1",
|
|
20
|
+
"twine>=6.2.0",
|
|
21
|
+
]
|
|
22
|
+
|
|
23
|
+
[tool.setuptools.packages.find]
|
|
24
|
+
where = ["."]
|
|
25
|
+
include = ["query_agent_benchmarking*"]
|
|
26
|
+
exclude = ["notebooks*", "results*", "locust*", "ci*"]
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from .benchmark_run import run_eval
|
|
2
|
+
from .models import (
|
|
3
|
+
DocsCollection,
|
|
4
|
+
QueriesCollection,
|
|
5
|
+
InMemoryQuery,
|
|
6
|
+
ObjectID,
|
|
7
|
+
QueryResult,
|
|
8
|
+
)
|
|
9
|
+
from .create_benchmark import create_benchmark
|
|
10
|
+
from .config import print_supported_datasets
|
|
11
|
+
|
|
12
|
+
__all__ = [
|
|
13
|
+
"run_eval",
|
|
14
|
+
"DocsCollection",
|
|
15
|
+
"QueriesCollection",
|
|
16
|
+
"InMemoryQuery",
|
|
17
|
+
"ObjectID",
|
|
18
|
+
"QueryResult",
|
|
19
|
+
"create_benchmark",
|
|
20
|
+
"print_supported_datasets",
|
|
21
|
+
]
|
|
22
|
+
__version__ = "0.1"
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
import weaviate
|
|
5
|
+
from weaviate.agents.query import QueryAgent, AsyncQueryAgent
|
|
6
|
+
from weaviate.auth import Auth
|
|
7
|
+
from query_agent_benchmarking.models import ObjectID, DocsCollection
|
|
8
|
+
|
|
9
|
+
class AgentBuilder:
    """
    Wraps a retrieval backend behind a uniform `run` / `run_async` interface
    for benchmarking.

    * `agent_name == "query-agent-search-only"` ➜ Wraps the Weaviate QueryAgent in Search Only Mode.
    * `agent_name == "hybrid-search"` ➜ Wraps Weaviate Hybrid Search.

    Exactly one of `dataset_name` (a built-in benchmark dataset) or
    `docs_collection` (a custom collection description) must be supplied.

    Connection credentials are read from the `WEAVIATE_URL`,
    `WEAVIATE_API_KEY` and `OPENAI_API_KEY` environment variables.

    When `use_async=True`, construction does NOT connect: callers must
    `await agent.initialize_async()` before use (and `await
    agent.close_async()` afterwards).
    """

    def __init__(
        self,
        agent_name: str,
        dataset_name: Optional[str] = None,
        docs_collection: Optional[DocsCollection] = None,
        agents_host: Optional[str] = None,
        use_async: bool = False,
    ):
        self.use_async = use_async
        self.agent = None
        self.weaviate_client = None

        self.cluster_url = os.getenv("WEAVIATE_URL")
        self.api_key = os.getenv("WEAVIATE_API_KEY")
        self.openai_api_key = os.getenv("OPENAI_API_KEY")

        # Require either dataset_name or docs_collection, but not both
        if dataset_name and docs_collection:
            raise ValueError("Cannot specify both dataset_name and docs_collection")
        if not dataset_name and not docs_collection:
            raise ValueError("Must specify either dataset_name or docs_collection")

        # Handle custom DocsCollection
        if docs_collection:
            self.collection = docs_collection.collection_name
            self.target_property_name = docs_collection.content_key
            self.id_property = docs_collection.id_key

        # Handle built-in datasets: map each dataset name to the Weaviate
        # collection it was ingested into, plus the property holding the
        # document text and the property holding the external dataset id.
        elif dataset_name == "enron":
            self.collection = "EnronEmails"
            self.target_property_name = ""
            self.id_property = "dataset_id"
        elif dataset_name == "wixqa":
            self.collection = "WixKB"
            self.target_property_name = "contents"
            self.id_property = "dataset_id"
        elif dataset_name.startswith("freshstack-"):
            subset = dataset_name.split("-")[1].capitalize()
            self.collection = f"Freshstack{subset}"
            self.target_property_name = "docs_text"
            self.id_property = "dataset_id"
        elif dataset_name.startswith("beir/"):
            self.collection = f"Beir{dataset_name.split('beir/')[1].replace('-', '_').replace('/', '_').capitalize()}"
            self.target_property_name = "content"
            self.id_property = "dataset_id"
        elif dataset_name.startswith("lotte/"):
            lotte_subset = dataset_name.split("/")[1]
            self.collection = f"Lotte{lotte_subset.capitalize()}"
            self.target_property_name = "content"
            self.id_property = "dataset_id"
        elif dataset_name.startswith("bright/"):
            self.collection = f"Bright{dataset_name.split('/')[1].capitalize()}"
            self.target_property_name = "content"
            self.id_property = "dataset_id"
        else:
            raise ValueError(f"Unknown dataset: {dataset_name}")

        self.agent_name = agent_name
        self.agents_host = agents_host or "https://api.agents.weaviate.io"

        if not use_async:
            self.initialize_sync()
        # BUGFIX: the previous code called `self.initialize_async()` here
        # without awaiting it, which only created an un-awaited coroutine
        # (RuntimeWarning) and never actually connected. Async callers must
        # `await self.initialize_async()` explicitly — `_run_eval_async`
        # already does this before every trial.

    def initialize_sync(self):
        """Connect synchronously and build the requested agent wrapper."""
        print(f"Initializing sync connection to {self.cluster_url}")

        self.weaviate_client = weaviate.connect_to_weaviate_cloud(
            cluster_url=self.cluster_url,
            auth_credentials=weaviate.auth.AuthApiKey(self.api_key),
        )
        if self.agent_name == "query-agent-search-only":
            self.agent = QueryAgent(
                client=self.weaviate_client,
                collections=[self.collection],
                agents_host=self.agents_host,
            )
        elif self.agent_name == "hybrid-search":
            self.weaviate_collection = self.weaviate_client.collections.get(self.collection)
        else:
            raise ValueError(f"Unknown agent_name: {self.agent_name}. Must be 'query-agent-search-only' or 'hybrid-search'")

    async def initialize_async(self):
        """Connect asynchronously and build the requested agent wrapper.

        Must be awaited by the caller before `run_async` is used.
        """
        print(f"Initializing async connection to {self.cluster_url}")

        try:
            self.weaviate_client = weaviate.use_async_with_weaviate_cloud(
                cluster_url=self.cluster_url,
                auth_credentials=Auth.api_key(self.api_key),
            )

            await self.weaviate_client.connect()
            print("Async Weaviate client connected successfully")

            if self.agent_name == "query-agent-search-only":
                self.agent = AsyncQueryAgent(
                    client=self.weaviate_client,
                    collections=[self.collection],
                    agents_host=self.agents_host
                )
                print(f"AsyncQueryAgent initialized for collection: {self.collection}")
                print(f"Using agents host: {self.agents_host}")
            elif self.agent_name == "hybrid-search":
                self.weaviate_collection = self.weaviate_client.collections.get(self.collection)
            else:
                raise ValueError(f"Unknown agent_name: {self.agent_name}. Must be 'query-agent-search-only' or 'hybrid-search'")

        except Exception as e:
            print(f"Failed to initialize async agent: {str(e)}")
            import traceback
            traceback.print_exc()
            raise

    async def close_async(self):
        """Close the async client if one was opened; errors are logged, not raised."""
        if self.use_async and self.weaviate_client:
            try:
                await self.weaviate_client.close()
                print("Async connection closed successfully")
            except Exception as e:
                print(f"Warning: Error closing async connection: {str(e)}")

    def run(self, query: str) -> list[ObjectID]:
        """Execute `query` synchronously and return the retrieved object ids."""
        if self.agent_name == "query-agent-search-only":
            # TODO: Interface `retrieved_k` instead of hardcoding `20`
            response = self.agent.search(query, limit=20)
            results = []
            for obj in response.search_results.objects:
                results.append(ObjectID(object_id=obj.properties[self.id_property]))
            return results

        if self.agent_name == "hybrid-search":
            response = self.weaviate_collection.query.hybrid(
                query=query,
                limit=20
            )
            results = []
            for obj in response.objects:
                results.append(ObjectID(object_id=str(obj.properties[self.id_property])))
            return results

    async def run_async(self, query: str):
        """Execute `query` via the async client and return the retrieved object ids.

        Requires `initialize_async()` to have been awaited first.
        """
        try:
            if self.agent_name == "query-agent-search-only":
                # TODO: Interface `retrieved_k` instead of hardcoding `20`
                response = await self.agent.search(query, limit=20)
                results = []
                for obj in response.search_results.objects:
                    results.append(ObjectID(object_id=obj.properties[self.id_property]))
                return results
            elif self.agent_name == "hybrid-search":
                response = await self.weaviate_collection.query.hybrid(
                    query=query,
                    limit=20
                )
                results = []
                for obj in response.objects:
                    results.append(ObjectID(object_id=str(obj.properties[self.id_property])))
                return results
        except Exception as e:
            print(f"Query '{query[:50]}...' failed with error: {str(e)}")
            raise
|
|
@@ -0,0 +1,306 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from datetime import datetime
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Optional, Dict, Any, Union, List
|
|
7
|
+
import yaml
|
|
8
|
+
|
|
9
|
+
import weaviate
|
|
10
|
+
|
|
11
|
+
from query_agent_benchmarking.agent import AgentBuilder
|
|
12
|
+
from query_agent_benchmarking.dataset import (
|
|
13
|
+
in_memory_dataset_loader,
|
|
14
|
+
load_queries_from_weaviate_collection,
|
|
15
|
+
)
|
|
16
|
+
from query_agent_benchmarking.models import (
|
|
17
|
+
DocsCollection,
|
|
18
|
+
QueriesCollection,
|
|
19
|
+
InMemoryQuery,
|
|
20
|
+
)
|
|
21
|
+
from query_agent_benchmarking.query_agent_benchmark import (
|
|
22
|
+
run_queries,
|
|
23
|
+
run_queries_async,
|
|
24
|
+
analyze_results,
|
|
25
|
+
aggregate_metrics
|
|
26
|
+
)
|
|
27
|
+
from query_agent_benchmarking.utils import pretty_print_in_memory_query
|
|
28
|
+
from query_agent_benchmarking.config import supported_datasets
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
DEFAULT_CONFIG_PATH = Path(__file__).parent / "benchmark-config.yml"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def load_config(config_path: str) -> Dict[str, Any]:
    """Read a YAML configuration file and return its parsed contents."""
    with open(config_path) as handle:
        return yaml.safe_load(handle)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def merge_configs(file_config: Dict[str, Any], override_config: Dict[str, Any]) -> Dict[str, Any]:
    """Merge file-based config with programmatic overrides.

    Overrides whose value is ``None`` are ignored (treated as "not set").
    If a ``docs_collection`` override is present, any ``dataset`` entry from
    the file config is dropped, since the two modes are mutually exclusive.
    """
    merged = dict(file_config)
    effective = {key: value for key, value in override_config.items() if value is not None}

    # A custom docs_collection supersedes a built-in dataset from the file.
    if "docs_collection" in effective:
        merged.pop("dataset", None)

    merged.update(effective)
    return merged
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
async def _run_eval_async(config: Dict[str, Any]) -> Dict[str, Any]:
    """Core benchmark loop: load queries, run trials, aggregate and save metrics.

    Expects a fully merged config dict (see `run_eval`). Returns the
    aggregated metrics dict, which is also written to a JSON file.
    Raises ValueError for unsupported datasets or malformed query inputs.
    """
    agents_host = config.get("agents_host", "https://api.agents.weaviate.io")
    use_async = config.get("use_async", True)

    # Determine if using built-in dataset or custom collection
    dataset_name = config.get("dataset")
    docs_collection = config.get("docs_collection")
    queries_input = config.get("queries")

    if dataset_name:
        # Built-in dataset paths
        if dataset_name not in supported_datasets:
            raise ValueError(
                f"Dataset {dataset_name} is not supported. "
                f"Supported datasets are: {supported_datasets}"
            )

        _, queries = in_memory_dataset_loader(dataset_name)
        dataset_identifier = dataset_name

    elif docs_collection and queries_input:
        # Custom collection path
        if not isinstance(docs_collection, DocsCollection):
            raise ValueError("docs_collection must be a DocsCollection object")

        # Load queries based on the type of queries_input
        if isinstance(queries_input, QueriesCollection):
            # Load from Weaviate
            queries = load_queries_from_weaviate_collection(
                collection_name=queries_input.collection_name,
                query_content_key=queries_input.query_content_key,
                gold_ids_key=queries_input.gold_ids_key,
            )
        elif isinstance(queries_input, list):
            # Verify all items are InMemoryQuery
            if not queries_input:
                raise ValueError("Queries list cannot be empty")
            if not all(isinstance(q, InMemoryQuery) for q in queries_input):
                raise ValueError(
                    "All queries must be InMemoryQuery objects. "
                    f"Found: {set(type(q).__name__ for q in queries_input)}"
                )
            queries = queries_input
        else:
            raise ValueError(
                f"Queries must be either QueriesCollection or List[InMemoryQuery]. "
                f"Got: {type(queries_input)}"
            )

        dataset_identifier = docs_collection.collection_name

    else:
        raise ValueError(
            "Must provide either 'dataset' (for built-in) or "
            "'docs_collection' + 'queries' (for custom)"
        )

    print(f"There are \033[92m{len(queries)}\033[0m total queries in this dataset.\n")
    print("\033[92mFirst Query\033[0m")
    pretty_print_in_memory_query(queries[0])

    # Handle subset if requested
    if config.get("use_subset", False):
        import random
        random.seed(config.get("random_seed", 24))
        # BUGFIX: copy before shuffling so a caller-supplied query list is
        # not reordered in place. Same seed + shuffle keeps subset selection
        # identical to the previous behavior.
        queries = list(queries)
        random.shuffle(queries)
        queries = queries[:config["num_samples"]]
        print(f"Using a subset of {config['num_samples']} queries.")

    # Build agent for whichever mode (built-in dataset vs custom collection) applies
    if dataset_name:
        query_agent = AgentBuilder(
            agent_name=config["agent_name"],
            dataset_name=dataset_name,
            agents_host=agents_host,
            use_async=use_async,
        )
    else:
        query_agent = AgentBuilder(
            agent_name=config["agent_name"],
            docs_collection=docs_collection,
            agents_host=agents_host,
            use_async=use_async,
        )

    num_trials = config.get("num_trials", 1)
    metrics_across_trials = []

    # Run trials
    for trial in range(num_trials):
        print(f"\033[92mRunning trial {trial+1}/{num_trials}\033[0m")

        if use_async:
            print("\033[92mRunning queries async!\033[0m")
            await query_agent.initialize_async()

            try:
                results = await run_queries_async(
                    queries=queries,
                    query_agent=query_agent,
                    batch_size=config.get("batch_size", 10),
                    max_concurrent=config.get("max_concurrent", 5)
                )
            finally:
                await query_agent.close_async()
        else:
            print("\n\033[94mRunning synchronous benchmark\033[0m")
            results = run_queries(
                queries=queries,
                query_agent=query_agent,
            )

        # NOTE(review): this client is opened and closed without being passed
        # to analyze_results — it looks unused. Kept for behavioral parity;
        # confirm whether analysis depends on an open connection before removing.
        weaviate_client = weaviate.connect_to_weaviate_cloud(
            cluster_url=os.getenv("WEAVIATE_URL"),
            auth_credentials=weaviate.auth.AuthApiKey(os.getenv("WEAVIATE_API_KEY")),
        )

        metrics = await analyze_results(
            results=results,
            ground_truths=queries,
            dataset_name=dataset_name,
        )
        print(metrics)

        weaviate_client.close()
        metrics_across_trials.append(metrics)

    # Aggregate and save results
    aggregated_metrics = aggregate_metrics(metrics_across_trials)
    aggregated_metrics["timestamp"] = datetime.now().isoformat()

    # Save results
    output_path = config.get("output_path")
    if output_path is None:
        dataset_name_for_file = dataset_identifier.replace("/", "-")
        output_path = f"{dataset_name_for_file}-{config['agent_name']}-{num_trials}-results.json"

    with open(output_path, "w") as f:
        json.dump(aggregated_metrics, f, indent=2)

    print(f"\n\033[92mResults saved to {output_path}\033[0m")

    return aggregated_metrics
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def run_eval(
    config_path: Optional[str] = None,
    dataset: Optional[str] = None,
    docs_collection: Optional[DocsCollection] = None,
    queries: Optional[Union[QueriesCollection, List[InMemoryQuery]]] = None,
    agent_name: Optional[str] = None,
    num_trials: Optional[int] = None,
    use_subset: Optional[bool] = None,
    num_samples: Optional[int] = None,
    batch_size: Optional[int] = None,
    max_concurrent: Optional[int] = None,
    use_async: Optional[bool] = None,
    agents_host: Optional[str] = None,
    output_path: Optional[str] = None,
    random_seed: Optional[int] = None,
    **kwargs
) -> Dict[str, Any]:
    """
    Run evaluation benchmark for query agents.

    Works with both built-in datasets and custom collections. Configuration is
    read from a YAML file, and any parameter passed here (non-None) overrides
    the corresponding file entry.

    Args:
        config_path: Path to YAML config file (default: benchmark-config.yml in package dir)

        # Built-in dataset mode:
        dataset: Dataset name (e.g., "bright/biology")

        # Custom collection mode:
        docs_collection: DocsCollection object specifying the document collection
        queries: Queries in one of two formats:
            - QueriesCollection: Loads from Weaviate
            - List[InMemoryQuery]: Pre-built query objects

        # Common parameters:
        agent_name: Name of the agent to benchmark
        num_trials: Number of trials to run
        use_subset: Whether to use a subset of queries
        num_samples: Number of samples to use if use_subset=True
        batch_size: Batch size for async queries
        max_concurrent: Max concurrent requests for async queries
        use_async: Whether to use async mode
        agents_host: URL of the agents host
        output_path: Custom path for output JSON file
        random_seed: Random seed for subset selection
        **kwargs: Additional config parameters

    Returns:
        Dict containing aggregated metrics

    Examples:
        # Built-in dataset with default config
        >>> run_eval()

        # Built-in dataset with overrides
        >>> run_eval(dataset="bright/biology", num_trials=3)

        # Custom collection with in-memory queries
        >>> docs = DocsCollection(name="MyDocs")
        >>> queries = [
        ...     InMemoryQuery(question="What is X?", query_id="q1", dataset_ids=["doc1"]),
        ... ]
        >>> run_eval(docs_collection=docs, queries=queries, agent_name="my-agent")
    """
    resolved_path = DEFAULT_CONFIG_PATH if config_path is None else config_path

    # Base configuration comes from the YAML file.
    file_config = load_config(resolved_path)

    # Explicit parameters become overrides; None values are dropped inside
    # merge_configs, so only caller-supplied settings take effect.
    overrides: Dict[str, Any] = dict(
        dataset=dataset,
        docs_collection=docs_collection,
        queries=queries,
        agent_name=agent_name,
        num_trials=num_trials,
        use_subset=use_subset,
        num_samples=num_samples,
        batch_size=batch_size,
        max_concurrent=max_concurrent,
        use_async=use_async,
        agents_host=agents_host,
        output_path=output_path,
        random_seed=random_seed,
    )
    overrides.update(kwargs)

    final_config = merge_configs(file_config, overrides)

    # The benchmark itself is async; drive it to completion here.
    return asyncio.run(_run_eval_async(final_config))
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
# Datasets the benchmarker knows how to load and map to Weaviate collections.
supported_datasets = (
    "beir/fiqa/test",
    "beir/nq",
    "beir/scifact/test",
    "bright/biology",
    "bright/earth_science",
    "bright/economics",
    "bright/psychology",
    "bright/robotics",
    "enron",
    "lotte/lifestyle/test/forum",
    "lotte/lifestyle/test/search",
    "lotte/recreation/test/forum",
    "lotte/recreation/test/search",
    "wixqa"
)

def print_supported_datasets():
    """Print the supported dataset identifiers, one per line."""
    listing = ["Supported datasets:"]
    listing.extend(f"- {dataset}" for dataset in supported_datasets)
    print("\n".join(listing))
|