query-agent-benchmarking 0.2__tar.gz → 0.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- query_agent_benchmarking-0.4/PKG-INFO +41 -0
- query_agent_benchmarking-0.4/README.md +21 -0
- {query_agent_benchmarking-0.2 → query_agent_benchmarking-0.4}/pyproject.toml +5 -5
- query_agent_benchmarking-0.4/query_agent_benchmarking/__init__.py +113 -0
- query_agent_benchmarking-0.4/query_agent_benchmarking/agent/__init__.py +10 -0
- query_agent_benchmarking-0.4/query_agent_benchmarking/agent/ask_agent.py +189 -0
- query_agent_benchmarking-0.4/query_agent_benchmarking/agent/base.py +111 -0
- query_agent_benchmarking-0.4/query_agent_benchmarking/agent/search_agent.py +192 -0
- query_agent_benchmarking-0.4/query_agent_benchmarking/ask_benchmark_run.py +340 -0
- query_agent_benchmarking-0.4/query_agent_benchmarking/compare_embeddings.py +196 -0
- query_agent_benchmarking-0.4/query_agent_benchmarking/config.py +50 -0
- query_agent_benchmarking-0.4/query_agent_benchmarking/database/__init__.py +19 -0
- query_agent_benchmarking-0.4/query_agent_benchmarking/database/database_loader.py +216 -0
- query_agent_benchmarking-0.4/query_agent_benchmarking/database/database_registry.py +78 -0
- query_agent_benchmarking-0.4/query_agent_benchmarking/database/property_builder.py +311 -0
- query_agent_benchmarking-0.4/query_agent_benchmarking/database/spec.py +25 -0
- {query_agent_benchmarking-0.2 → query_agent_benchmarking-0.4}/query_agent_benchmarking/dataset.py +181 -29
- query_agent_benchmarking-0.4/query_agent_benchmarking/experimental/add_hard_negatives.py +161 -0
- {query_agent_benchmarking-0.2/query_agent_benchmarking → query_agent_benchmarking-0.4/query_agent_benchmarking/experimental}/create_benchmark.py +10 -43
- query_agent_benchmarking-0.4/query_agent_benchmarking/experimental/hard_negative_prompts.py +5 -0
- query_agent_benchmarking-0.4/query_agent_benchmarking/experimental/query_gen_prompts.py +62 -0
- query_agent_benchmarking-0.4/query_agent_benchmarking/metrics/__init__.py +29 -0
- query_agent_benchmarking-0.4/query_agent_benchmarking/metrics/exact_match.py +33 -0
- {query_agent_benchmarking-0.2 → query_agent_benchmarking-0.4}/query_agent_benchmarking/metrics/ir_metrics.py +22 -8
- query_agent_benchmarking-0.4/query_agent_benchmarking/metrics/lmjudge_alignment.py +183 -0
- query_agent_benchmarking-0.4/query_agent_benchmarking/models.py +102 -0
- query_agent_benchmarking-0.4/query_agent_benchmarking/qa_system_prompt_registry.py +19 -0
- query_agent_benchmarking-0.4/query_agent_benchmarking/query_agent_benchmark.py +681 -0
- query_agent_benchmarking-0.4/query_agent_benchmarking/result_serialization.py +192 -0
- query_agent_benchmarking-0.4/query_agent_benchmarking/search_benchmark_run.py +357 -0
- query_agent_benchmarking-0.4/query_agent_benchmarking/utils.py +197 -0
- query_agent_benchmarking-0.4/query_agent_benchmarking.egg-info/PKG-INFO +41 -0
- query_agent_benchmarking-0.4/query_agent_benchmarking.egg-info/SOURCES.txt +37 -0
- {query_agent_benchmarking-0.2 → query_agent_benchmarking-0.4}/query_agent_benchmarking.egg-info/requires.txt +3 -3
- {query_agent_benchmarking-0.2 → query_agent_benchmarking-0.4}/setup.py +1 -1
- query_agent_benchmarking-0.2/PKG-INFO +0 -35
- query_agent_benchmarking-0.2/README.md +0 -15
- query_agent_benchmarking-0.2/query_agent_benchmarking/__init__.py +0 -22
- query_agent_benchmarking-0.2/query_agent_benchmarking/agent.py +0 -176
- query_agent_benchmarking-0.2/query_agent_benchmarking/benchmark_run.py +0 -306
- query_agent_benchmarking-0.2/query_agent_benchmarking/config.py +0 -21
- query_agent_benchmarking-0.2/query_agent_benchmarking/database.py +0 -211
- query_agent_benchmarking-0.2/query_agent_benchmarking/models.py +0 -24
- query_agent_benchmarking-0.2/query_agent_benchmarking/populate_db.py +0 -31
- query_agent_benchmarking-0.2/query_agent_benchmarking/query_agent_benchmark.py +0 -378
- query_agent_benchmarking-0.2/query_agent_benchmarking/utils.py +0 -49
- query_agent_benchmarking-0.2/query_agent_benchmarking.egg-info/PKG-INFO +0 -35
- query_agent_benchmarking-0.2/query_agent_benchmarking.egg-info/SOURCES.txt +0 -21
- {query_agent_benchmarking-0.2 → query_agent_benchmarking-0.4}/LICENSE +0 -0
- {query_agent_benchmarking-0.2 → query_agent_benchmarking-0.4}/query_agent_benchmarking.egg-info/dependency_links.txt +0 -0
- {query_agent_benchmarking-0.2 → query_agent_benchmarking-0.4}/query_agent_benchmarking.egg-info/top_level.txt +0 -0
- {query_agent_benchmarking-0.2 → query_agent_benchmarking-0.4}/setup.cfg +0 -0
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: query-agent-benchmarking
|
|
3
|
+
Version: 0.4
|
|
4
|
+
Summary: A Python library for benchmarking Weaviate's Query Agent!
|
|
5
|
+
Requires-Python: >=3.10
|
|
6
|
+
Description-Content-Type: text/markdown
|
|
7
|
+
License-File: LICENSE
|
|
8
|
+
Requires-Dist: dspy>=3.0.4
|
|
9
|
+
Requires-Dist: sentence-transformers>=5.0.0
|
|
10
|
+
Requires-Dist: weaviate-client>=4.19.2
|
|
11
|
+
Requires-Dist: weaviate-agents>=1.1.0
|
|
12
|
+
Requires-Dist: pandas>=2.3.1
|
|
13
|
+
Requires-Dist: datasets>=4.0.0
|
|
14
|
+
Requires-Dist: ir-datasets>=0.5.11
|
|
15
|
+
Requires-Dist: pip>=25.2
|
|
16
|
+
Requires-Dist: setuptools>=80.9.0
|
|
17
|
+
Requires-Dist: wheel>=0.45.1
|
|
18
|
+
Requires-Dist: twine>=6.2.0
|
|
19
|
+
Dynamic: license-file
|
|
20
|
+
|
|
21
|
+
# Query Agent Benchmarking
|
|
22
|
+
|
|
23
|
+
This repo contains a package for benchmarking the performance of Weaviate's Query Agent.
|
|
24
|
+
|
|
25
|
+
## News 📯
|
|
26
|
+
|
|
27
|
+
[9/25] 📊 Search Mode Benchmarking is live on the [Weaviate Blog](https://weaviate.io/blog/search-mode-benchmarking).
|
|
28
|
+
|
|
29
|
+
## How to Run 🧰
|
|
30
|
+
|
|
31
|
+
Populate Weaviate with benchmark data:
|
|
32
|
+
```
|
|
33
|
+
uv run python3 scripts/populate-db.py
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
Run eval:
|
|
37
|
+
```
|
|
38
|
+
uv run python3 scripts/run-search-benchmark.py
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
See `query_agent_benchmarking/benchmark-config.yml` to change the dataset populated in your Weaviate instance, as well as ablate `hybrid-search` or `query-agent-search-only`, as well as the number of samples and concurrency parameters.
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
# Query Agent Benchmarking
|
|
2
|
+
|
|
3
|
+
This repo contains a package for benchmarking the performance of Weaviate's Query Agent.
|
|
4
|
+
|
|
5
|
+
## News 📯
|
|
6
|
+
|
|
7
|
+
[9/25] 📊 Search Mode Benchmarking is live on the [Weaviate Blog](https://weaviate.io/blog/search-mode-benchmarking).
|
|
8
|
+
|
|
9
|
+
## How to Run 🧰
|
|
10
|
+
|
|
11
|
+
Populate Weaviate with benchmark data:
|
|
12
|
+
```
|
|
13
|
+
uv run python3 scripts/populate-db.py
|
|
14
|
+
```
|
|
15
|
+
|
|
16
|
+
Run the search benchmark:
|
|
17
|
+
```
|
|
18
|
+
uv run python3 scripts/run-search-benchmark.py
|
|
19
|
+
```
|
|
20
|
+
|
|
21
|
+
Edit `query_agent_benchmarking/benchmark-config.yml` to change which dataset is loaded into your Weaviate instance and to choose between `hybrid-search` and `query-agent-search-only` for ablation. The same file also sets the number of samples and the concurrency parameters.
|
|
@@ -4,15 +4,15 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "query-agent-benchmarking"
|
|
7
|
-
version = "0.
|
|
7
|
+
version = "0.4"
|
|
8
8
|
description="A Python library for benchmarking Weaviate's Query Agent!"
|
|
9
9
|
readme="README.md"
|
|
10
|
-
requires-python = ">=3.
|
|
10
|
+
requires-python = ">=3.10"
|
|
11
11
|
dependencies = [
|
|
12
|
-
"dspy>=
|
|
12
|
+
"dspy>=3.0.4",
|
|
13
13
|
"sentence-transformers>=5.0.0",
|
|
14
|
-
"weaviate-client
|
|
15
|
-
"weaviate-agents>=1.
|
|
14
|
+
"weaviate-client>=4.19.2",
|
|
15
|
+
"weaviate-agents>=1.1.0",
|
|
16
16
|
"pandas>=2.3.1",
|
|
17
17
|
"datasets>=4.0.0",
|
|
18
18
|
"ir-datasets>=0.5.11",
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
from .experimental.add_hard_negatives import add_hard_negatives
|
|
2
|
+
|
|
3
|
+
# Search benchmark exports
|
|
4
|
+
from .search_benchmark_run import run_search_eval, run_search_evals
|
|
5
|
+
|
|
6
|
+
# Ask benchmark exports
|
|
7
|
+
from .ask_benchmark_run import run_ask_eval
|
|
8
|
+
|
|
9
|
+
from .compare_embeddings import compare_embeddings
|
|
10
|
+
from .database import database_loader
|
|
11
|
+
from .dataset import (
|
|
12
|
+
in_memory_dataset_loader,
|
|
13
|
+
in_memory_ask_dataset_loader,
|
|
14
|
+
load_ask_queries_from_weaviate,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
# Models
|
|
18
|
+
from .models import (
|
|
19
|
+
DocsCollection,
|
|
20
|
+
QueriesCollection,
|
|
21
|
+
InMemoryQuery,
|
|
22
|
+
ObjectID,
|
|
23
|
+
QueryResult,
|
|
24
|
+
# Search-specific
|
|
25
|
+
InMemorySearchQuery,
|
|
26
|
+
SearchResult,
|
|
27
|
+
# Ask-specific
|
|
28
|
+
InMemoryAskQuery,
|
|
29
|
+
AskResult,
|
|
30
|
+
AskQueriesCollection,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
# Agent exports
|
|
34
|
+
from .agent import (
|
|
35
|
+
SearchAgentBuilder,
|
|
36
|
+
AskAgentBuilder,
|
|
37
|
+
BaseAgentBuilder,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
# Metrics
|
|
41
|
+
from .metrics import (
|
|
42
|
+
# IR Metrics
|
|
43
|
+
calculate_recall_at_k,
|
|
44
|
+
calculate_success_at_k,
|
|
45
|
+
calculate_nDCG_at_k,
|
|
46
|
+
calculate_coverage,
|
|
47
|
+
calculate_alpha_ndcg,
|
|
48
|
+
# LLM Judge
|
|
49
|
+
LMJudge,
|
|
50
|
+
calculate_alignment_score,
|
|
51
|
+
# Exact Match
|
|
52
|
+
calculate_exact_match,
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
from .experimental.create_benchmark import create_benchmark
|
|
56
|
+
from .config import (
|
|
57
|
+
print_supported_datasets,
|
|
58
|
+
print_supported_ask_datasets,
|
|
59
|
+
supported_search_datasets,
|
|
60
|
+
supported_ask_datasets,
|
|
61
|
+
)
|
|
62
|
+
from .result_serialization import save_trial_results, save_ask_trial_results, save_trial_metrics, save_aggregated_results
|
|
63
|
+
|
|
64
|
+
__all__ = [
|
|
65
|
+
# Main entry points
|
|
66
|
+
"run_search_eval",
|
|
67
|
+
"run_search_evals",
|
|
68
|
+
"run_ask_eval",
|
|
69
|
+
# Utilities
|
|
70
|
+
"add_hard_negatives",
|
|
71
|
+
"database_loader",
|
|
72
|
+
"in_memory_dataset_loader",
|
|
73
|
+
"in_memory_ask_dataset_loader",
|
|
74
|
+
"load_ask_queries_from_weaviate",
|
|
75
|
+
"compare_embeddings",
|
|
76
|
+
"create_benchmark",
|
|
77
|
+
"print_supported_datasets",
|
|
78
|
+
"print_supported_ask_datasets",
|
|
79
|
+
"supported_search_datasets",
|
|
80
|
+
"supported_ask_datasets",
|
|
81
|
+
# Models
|
|
82
|
+
"DocsCollection",
|
|
83
|
+
"QueriesCollection",
|
|
84
|
+
"InMemoryQuery",
|
|
85
|
+
"ObjectID",
|
|
86
|
+
"QueryResult",
|
|
87
|
+
"InMemorySearchQuery",
|
|
88
|
+
"SearchResult",
|
|
89
|
+
"InMemoryAskQuery",
|
|
90
|
+
"AskResult",
|
|
91
|
+
"AskQueriesCollection",
|
|
92
|
+
# Agents
|
|
93
|
+
"SearchAgentBuilder",
|
|
94
|
+
"AskAgentBuilder",
|
|
95
|
+
"BaseAgentBuilder",
|
|
96
|
+
# Metrics - IR
|
|
97
|
+
"calculate_recall_at_k",
|
|
98
|
+
"calculate_success_at_k",
|
|
99
|
+
"calculate_nDCG_at_k",
|
|
100
|
+
"calculate_coverage",
|
|
101
|
+
"calculate_alpha_ndcg",
|
|
102
|
+
# Metrics - LLM Judge
|
|
103
|
+
"LMJudge",
|
|
104
|
+
"calculate_alignment_score",
|
|
105
|
+
# Metrics - Exact Match
|
|
106
|
+
"calculate_exact_match",
|
|
107
|
+
# Result serialization
|
|
108
|
+
"save_trial_results",
|
|
109
|
+
"save_ask_trial_results",
|
|
110
|
+
"save_trial_metrics",
|
|
111
|
+
"save_aggregated_results",
|
|
112
|
+
]
|
|
113
|
+
__version__ = "0.5"
|
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
from typing import Optional, Any
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
|
|
4
|
+
import httpx
|
|
5
|
+
from weaviate.agents.query import QueryAgent, AsyncQueryAgent
|
|
6
|
+
|
|
7
|
+
from query_agent_benchmarking.agent.base import BaseAgentBuilder
|
|
8
|
+
from query_agent_benchmarking.models import DocsCollection
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class AskResponse:
    """Outcome of one ask-mode query.

    Attributes:
        final_answer: Answer text produced for the question.
        raw_response: Untouched reply the answer was taken from. This is the
            QueryAgent response object, or the parsed JSON body when an
            external service was queried.
    """

    final_answer: str
    raw_response: Any  # kept verbatim so callers can inspect extra fields
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class AskAgentBuilder(BaseAgentBuilder):
    """
    Agent builder for ask mode operations.

    Supports two agent types:
    * `agent_name == "query-agent-ask"` → Wraps the Weaviate QueryAgent in Ask Mode.
    * `agent_name == "external_service"` → Sends requests to an external host for RAG evaluation.

    The "external_service" mode allows you to bring your own retrieval + generation system
    and use the ask infrastructure for evaluation. It sends HTTP POST requests to
    `external_service_host` with a JSON body containing `question` and, when given,
    `oracle_context_id`. The answer is read from the `answer` key of the JSON reply
    (empty string if the key is absent).

    With ``use_async=False`` the agent is initialized in the constructor. With
    ``use_async=True`` the caller must ``await initialize_async()`` before
    calling ``run_async``.
    """

    # Timeout (seconds) passed to httpx for each external-service request.
    _EXTERNAL_SERVICE_TIMEOUT_S = 300.0

    def __init__(
        self,
        agent_name: str,
        dataset_name: Optional[str] = None,
        docs_collection: Optional[DocsCollection] = None,
        agents_host: Optional[str] = None,
        use_async: bool = False,
        embedding_model: Optional[str] = None,
        external_service_host: Optional[str] = None,
        system_prompt: Optional[str] = None,
    ):
        super().__init__(
            dataset_name=dataset_name,
            docs_collection=docs_collection,
            agents_host=agents_host,
            use_async=use_async,
            embedding_model=embedding_model,
            system_prompt=system_prompt,
        )

        self.agent_name = agent_name
        self.external_service_host = external_service_host
        self.weaviate_collection = None

        if not use_async:
            self.initialize_sync()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _unknown_agent_error(self) -> ValueError:
        """Build the error used for an unsupported ``agent_name``."""
        return ValueError(
            f"Unknown agent_name: {self.agent_name}. "
            "Must be 'query-agent-ask' or 'external_service'"
        )

    def _query_agent_kwargs(self) -> dict[str, Any]:
        """Keyword arguments shared by ``QueryAgent`` and ``AsyncQueryAgent``."""
        agent_kwargs: dict[str, Any] = dict(
            client=self.weaviate_client,
            collections=[self.collection],
            agents_host=self.agents_host,
        )
        # Only forward a system prompt when one was given (presumably leaving the
        # agent's own default in place otherwise).
        if self.system_prompt:
            agent_kwargs["system_prompt"] = self.system_prompt
        return agent_kwargs

    def _init_external_service(self) -> None:
        """Validate external-service settings; no Weaviate connection is opened."""
        if not self.external_service_host:
            raise ValueError("external_service_host is required for external_service mode")
        print(f"External service mode initialized with host: {self.external_service_host}")

    def _require_agent(self, init_method: str) -> None:
        """Raise a clear error if the Query Agent has not been initialized yet.

        Raises:
            RuntimeError: if ``self.agent`` is still ``None``.
        """
        if self.agent is None:
            raise RuntimeError(
                f"Query agent is not initialized; call {init_method}() first"
            )

    @staticmethod
    def _external_payload(query: str, oracle_context_id: Optional[str]) -> dict[str, str]:
        """Build the JSON body sent to the external service."""
        payload = {"question": query}
        if oracle_context_id is not None:
            payload["oracle_context_id"] = oracle_context_id
        return payload

    @staticmethod
    def _external_response(data: Any) -> AskResponse:
        """Wrap the external service's decoded JSON reply in an ``AskResponse``."""
        return AskResponse(
            final_answer=data.get("answer", ""),
            raw_response=data,
        )

    def _discard_sync_client(self) -> None:
        """Best-effort close of a sync client after a failed initialization."""
        client, self.weaviate_client = self.weaviate_client, None
        if client is None:
            return
        try:
            client.close()
        except Exception as close_err:
            print(f"Warning: Error closing sync connection: {close_err}")

    async def _discard_async_client(self) -> None:
        """Best-effort close of an async client after a failed initialization."""
        client, self.weaviate_client = self.weaviate_client, None
        if client is None:
            return
        try:
            await client.close()
        except Exception as close_err:
            print(f"Warning: Error closing async connection: {close_err}")

    # ------------------------------------------------------------------
    # Initialization
    # ------------------------------------------------------------------

    def initialize_sync(self):
        """Set up the agent for synchronous use.

        Raises:
            ValueError: for an unknown ``agent_name`` or a missing
                ``external_service_host`` in external_service mode.
        """
        if self.agent_name == "query-agent-ask":
            self.weaviate_client = self._connect_sync()
            try:
                self.agent = QueryAgent(**self._query_agent_kwargs())
            except Exception:
                # Don't leak the connection just opened if the agent can't be built.
                self._discard_sync_client()
                raise
        elif self.agent_name == "external_service":
            self._init_external_service()
        else:
            raise self._unknown_agent_error()

    async def initialize_async(self):
        """Set up the agent for asynchronous use (connects the async client).

        On failure the error is printed with its traceback, any partially
        opened Weaviate client is closed, and the exception is re-raised.
        """
        try:
            if self.agent_name == "query-agent-ask":
                self.weaviate_client = self._connect_async()
                await self.weaviate_client.connect()
                print("Async Weaviate client connected successfully")
                self.agent = AsyncQueryAgent(**self._query_agent_kwargs())
                print(f"AsyncQueryAgent (ask mode) initialized for collection: {self.collection}")
                print(f"Using agents host: {self.agents_host}")
            elif self.agent_name == "external_service":
                self._init_external_service()
            else:
                raise self._unknown_agent_error()

        except Exception as e:
            print(f"Failed to initialize async agent: {str(e)}")
            import traceback
            traceback.print_exc()
            # Close any half-initialized connection so it is not leaked.
            await self._discard_async_client()
            raise

    # ------------------------------------------------------------------
    # Querying
    # ------------------------------------------------------------------

    def run(
        self,
        query: str,
        oracle_context_id: Optional[str] = None
    ) -> AskResponse:
        """
        Run synchronous ask query.

        Args:
            query: The question to ask.
            oracle_context_id: Optional context ID to send to external host.

        Returns:
            An ``AskResponse`` with the answer text and the raw reply.

        Raises:
            RuntimeError: if the Query Agent has not been initialized.
            ValueError: for an unknown ``agent_name``.
            httpx.HTTPStatusError: if the external service returns an error status.
        """
        if self.agent_name == "query-agent-ask":
            self._require_agent("initialize_sync")
            response = self.agent.ask(query)
            return AskResponse(
                final_answer=response.final_answer,
                raw_response=response,
            )

        if self.agent_name == "external_service":
            payload = self._external_payload(query, oracle_context_id)
            with httpx.Client(timeout=self._EXTERNAL_SERVICE_TIMEOUT_S) as client:
                response = client.post(self.external_service_host, json=payload)
                response.raise_for_status()
                data = response.json()
            return self._external_response(data)

        # Previously this fell through and returned None despite the annotation.
        raise self._unknown_agent_error()

    async def run_async(
        self,
        query: str,
        oracle_context_id: Optional[str] = None
    ) -> AskResponse:
        """
        Run asynchronous ask query.

        Args:
            query: The question to ask.
            oracle_context_id: Optional context ID to send to external host.

        Returns:
            An ``AskResponse`` with the answer text and the raw reply.

        Raises:
            Any error from the agent or the external service is printed and
            re-raised. This includes ``RuntimeError`` if ``initialize_async``
            was not awaited, and ``ValueError`` for an unknown ``agent_name``.
        """
        try:
            if self.agent_name == "query-agent-ask":
                self._require_agent("initialize_async")
                response = await self.agent.ask(query)
                return AskResponse(
                    final_answer=response.final_answer,
                    raw_response=response,
                )

            if self.agent_name == "external_service":
                payload = self._external_payload(query, oracle_context_id)
                async with httpx.AsyncClient(timeout=self._EXTERNAL_SERVICE_TIMEOUT_S) as client:
                    response = await client.post(self.external_service_host, json=payload)
                    response.raise_for_status()
                    data = response.json()
                return self._external_response(data)

            raise self._unknown_agent_error()

        except Exception as e:
            print(f"Ask query '{query[:50]}...' failed with error: {str(e)}")
            raise
|
|
189
|
+
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Optional
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
|
|
5
|
+
import weaviate
|
|
6
|
+
from weaviate.auth import Auth
|
|
7
|
+
from weaviate.config import AdditionalConfig, Timeout
|
|
8
|
+
|
|
9
|
+
from query_agent_benchmarking.models import DocsCollection
|
|
10
|
+
from query_agent_benchmarking.database.database_registry import resolve_spec
|
|
11
|
+
from query_agent_benchmarking.utils import get_provider_headers, parse_embedding_model
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class BaseAgentBuilder(ABC):
    """
    Base class for agent builders that handles common Weaviate connection logic
    and dataset-to-collection mapping.

    Connection settings are read from the environment when the builder is
    created: ``WEAVIATE_URL`` and ``WEAVIATE_API_KEY`` (validated only when a
    Weaviate connection is actually opened) and ``OPENAI_API_KEY``. Exactly one
    of ``dataset_name`` or ``docs_collection`` must be given to select the
    target collection. Subclasses implement ``initialize_sync`` /
    ``initialize_async`` to build the concrete agent.
    """

    # Query timeout passed to ``Timeout(query=...)`` for async connections.
    _ASYNC_QUERY_TIMEOUT = 6000

    def __init__(
        self,
        dataset_name: Optional[str] = None,
        docs_collection: Optional[DocsCollection] = None,
        agents_host: Optional[str] = None,
        use_async: bool = False,
        embedding_model: Optional[str] = None,
        system_prompt: Optional[str] = None,
    ):
        self.use_async = use_async
        self.agent = None
        self.weaviate_client = None
        self.system_prompt = system_prompt

        self.cluster_url = os.getenv("WEAVIATE_URL")
        self.api_key = os.getenv("WEAVIATE_API_KEY")
        self.openai_api_key = os.getenv("OPENAI_API_KEY")

        # Extra request headers from get_provider_headers() for third-party
        # embedding providers; empty when no embedding_model is given.
        self.headers: dict[str, str] = {}
        if embedding_model:
            provider, _ = parse_embedding_model(embedding_model)
            self.headers = get_provider_headers(provider)

        # Require either dataset_name or docs_collection, but not both.
        if dataset_name and docs_collection is not None:
            raise ValueError("Cannot specify both dataset_name and docs_collection")
        if not dataset_name and docs_collection is None:
            raise ValueError("Must specify either dataset_name or docs_collection")

        self.dataset_name = dataset_name

        if docs_collection is not None:
            # Custom collection: use its name and ID property as given.
            self.collection = docs_collection.collection_name
            self.id_property = docs_collection.id_key
        else:
            # Registered dataset: derive the collection name from its spec.
            spec = resolve_spec(dataset_name)
            self.collection = f"{spec.name_fn(dataset_name)}_Default"
            self.id_property = "dataset_id"

        self.agents_host = agents_host or "https://api.agents.weaviate.io"

    def _require_credentials(self) -> None:
        """Fail fast with a clear message if Weaviate credentials are missing.

        Raises:
            ValueError: if ``WEAVIATE_URL`` or ``WEAVIATE_API_KEY`` was unset or empty.
        """
        missing = [
            env_name
            for env_name, value in (
                ("WEAVIATE_URL", self.cluster_url),
                ("WEAVIATE_API_KEY", self.api_key),
            )
            if not value
        ]
        if missing:
            raise ValueError(
                "Missing required environment variable(s) for Weaviate connection: "
                + ", ".join(missing)
            )

    def _connect_sync(self) -> weaviate.WeaviateClient:
        """Create synchronous Weaviate connection."""
        self._require_credentials()
        print(f"Initializing sync connection to {self.cluster_url}")
        return weaviate.connect_to_weaviate_cloud(
            cluster_url=self.cluster_url,
            # Same credential helper as the async path, for consistency.
            auth_credentials=Auth.api_key(self.api_key),
            headers=self.headers,
        )

    def _connect_async(self):
        """Create async Weaviate connection (returns client, must be awaited to connect)."""
        self._require_credentials()
        print(f"Initializing async connection to {self.cluster_url}")
        return weaviate.use_async_with_weaviate_cloud(
            cluster_url=self.cluster_url,
            auth_credentials=Auth.api_key(self.api_key),
            headers=self.headers,
            additional_config=AdditionalConfig(
                timeout=Timeout(query=self._ASYNC_QUERY_TIMEOUT)
            ),
        )

    @abstractmethod
    def initialize_sync(self):
        """Initialize synchronous agent. Must be implemented by subclasses."""

    @abstractmethod
    async def initialize_async(self):
        """Initialize asynchronous agent. Must be implemented by subclasses."""

    async def close_async(self):
        """Close async connection (no-op in sync mode or when never connected)."""
        if self.use_async and self.weaviate_client:
            try:
                await self.weaviate_client.close()
                print("Async connection closed successfully")
            except Exception as e:
                # Closing is best-effort; report but don't propagate.
                print(f"Warning: Error closing async connection: {str(e)}")

    def close_sync(self):
        """Close sync connection (no-op in async mode or when never connected)."""
        if not self.use_async and self.weaviate_client:
            try:
                self.weaviate_client.close()
                print("Sync connection closed successfully")
            except Exception as e:
                # Closing is best-effort; report but don't propagate.
                print(f"Warning: Error closing sync connection: {str(e)}")
|
|
111
|
+
|