haiku.rag 0.3.1__tar.gz → 0.3.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/PKG-INFO +2 -1
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/README.md +1 -0
- haiku_rag-0.3.3/docs/benchmarks.md +27 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/mkdocs.yml +1 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/pyproject.toml +1 -1
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/qa/__init__.py +2 -5
- haiku_rag-0.3.3/src/haiku/rag/qa/anthropic.py +106 -0
- haiku_rag-0.3.3/src/haiku/rag/qa/ollama.py +64 -0
- haiku_rag-0.3.3/src/haiku/rag/qa/openai.py +100 -0
- haiku_rag-0.3.3/src/haiku/rag/qa/prompts.py +20 -0
- haiku_rag-0.3.3/tests/generate_benchmark_db.py +151 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/llm_judge.py +1 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/uv.lock +1 -1
- haiku_rag-0.3.1/BENCHMARKS.md +0 -13
- haiku_rag-0.3.1/src/haiku/rag/qa/anthropic.py +0 -112
- haiku_rag-0.3.1/src/haiku/rag/qa/ollama.py +0 -67
- haiku_rag-0.3.1/src/haiku/rag/qa/openai.py +0 -101
- haiku_rag-0.3.1/src/haiku/rag/qa/prompts.py +0 -7
- haiku_rag-0.3.1/tests/generate_benchmark_db.py +0 -129
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/.github/FUNDING.yml +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/.github/workflows/build-docs.yml +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/.github/workflows/build-publish.yml +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/.gitignore +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/.pre-commit-config.yaml +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/.python-version +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/LICENSE +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/docs/cli.md +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/docs/configuration.md +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/docs/index.md +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/docs/installation.md +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/docs/mcp.md +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/docs/python.md +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/docs/server.md +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/__init__.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/app.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/chunker.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/cli.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/client.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/config.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/embeddings/__init__.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/embeddings/base.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/embeddings/ollama.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/embeddings/openai.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/embeddings/voyageai.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/logging.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/mcp.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/monitor.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/qa/base.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/reader.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/store/__init__.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/store/engine.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/store/models/__init__.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/store/models/chunk.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/store/models/document.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/store/repositories/__init__.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/store/repositories/base.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/store/repositories/chunk.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/store/repositories/document.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/utils.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/__init__.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/conftest.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/test_app.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/test_chunk.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/test_chunker.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/test_cli.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/test_client.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/test_document.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/test_embedder.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/test_monitor.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/test_qa.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/test_rebuild.py +0 -0
- {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/test_search.py +0 -0
{haiku_rag-0.3.1 → haiku_rag-0.3.3}/PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: haiku.rag
-Version: 0.3.1
+Version: 0.3.3
 Summary: Retrieval Augmented Generation (RAG) with SQLite
 Author-email: Yiorgis Gozadinos <ggozadinos@gmail.com>
 License: MIT
```

```diff
@@ -116,3 +116,4 @@ Full documentation at: https://ggozad.github.io/haiku.rag/
 - [Configuration](https://ggozad.github.io/haiku.rag/configuration/) - Environment variables
 - [CLI](https://ggozad.github.io/haiku.rag/cli/) - Command reference
 - [Python API](https://ggozad.github.io/haiku.rag/python/) - Complete API docs
+- [Benchmarks](https://ggozad.github.io/haiku.rag/benchmarks/) - Performance Benchmarks
```

{haiku_rag-0.3.1 → haiku_rag-0.3.3}/README.md

```diff
@@ -77,3 +77,4 @@ Full documentation at: https://ggozad.github.io/haiku.rag/
 - [Configuration](https://ggozad.github.io/haiku.rag/configuration/) - Environment variables
 - [CLI](https://ggozad.github.io/haiku.rag/cli/) - Command reference
 - [Python API](https://ggozad.github.io/haiku.rag/python/) - Complete API docs
+- [Benchmarks](https://ggozad.github.io/haiku.rag/benchmarks/) - Performance Benchmarks
```

haiku_rag-0.3.3/docs/benchmarks.md

```diff
@@ -0,0 +1,27 @@
+# Benchmarks
+
+We use the [repliqa](https://huggingface.co/datasets/ServiceNow/repliqa) dataset for the evaluation of `haiku.rag`.
+
+You can perform your own evaluations using as example the script found at
+`tests/generate_benchmark_db.py`.
+
+## Recall
+
+In order to calculate recall, we load the `News Stories` from `repliqa_3` which is 1035 documents and index them in a sqlite db. Subsequently, we run a search over the `question` field for each row of the dataset and check whether we match the document that answers the question.
+
+
+The recall obtained is ~0.73 for matching in the top result, raising to ~0.75 for the top 3 results.
+
+| Model                                  | Document in top 1 | Document in top 3 |
+|----------------------------------------|-------------------|-------------------|
+| Ollama / `mxbai-embed-large`           | 0.73              | 0.75              |
+| OpenAI / `text-embeddings-3-small`     | 0.75              | 0.88              |
+
+## Question/Answer evaluation
+
+Again using the same dataset, we use a QA agent to answer the question. In addition we use an LLM judge (using the Ollama `qwen3`) to evaluate whether the answer is correct or not. The obtained accuracy is as follows:
+
+| Embedding Model              | QA Model                          | Accuracy |
+|------------------------------|-----------------------------------|----------|
+| Ollama / `mxbai-embed-large` | Ollama / `qwen3`                  | 0.64     |
+| Ollama / `mxbai-embed-large` | Anthropic / `Claude Sonnet 3.7`   | 0.79     |
```

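As an aside, the recall figures above are obtained by checking, for each `question`, whether the document that answers it appears among the top-k search results. Below is a minimal sketch of that computation, written against the same `HaikuRAG` search API the bundled `tests/generate_benchmark_db.py` script uses; the `rows` structure with `question` and `document_id` fields mirrors the repliqa records and is purely illustrative.

```python
# Illustrative sketch of the recall@k computation described above.
# Assumes `rag.search()` returns ranked (chunk, score) pairs and that
# `rag.get_document_by_id()` resolves a chunk back to its source document,
# as in tests/generate_benchmark_db.py.
async def recall_at_k(rag, rows, k: int = 3) -> float:
    hits = 0
    for row in rows:  # each row carries a "question" and the id of the answering document
        matches = await rag.search(query=row["question"], limit=k)
        for chunk, _score in matches:
            doc = await rag.get_document_by_id(chunk.document_id)
            if doc and doc.uri == row["document_id"]:
                hits += 1
                break
    return hits / len(rows)
```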
{haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/qa/__init__.py

```diff
@@ -8,7 +8,6 @@ def get_qa_agent(client: HaikuRAG, model: str = "") -> QuestionAnswerAgentBase:
     """
     Factory function to get the appropriate QA agent based on the configuration.
     """
-
     if Config.QA_PROVIDER == "ollama":
         return QuestionAnswerOllamaAgent(client, model or Config.QA_MODEL)
 
@@ -21,7 +20,7 @@ def get_qa_agent(client: HaikuRAG, model: str = "") -> QuestionAnswerAgentBase:
                 "Please install haiku.rag with the 'openai' extra:"
                 "uv pip install haiku.rag --extra openai"
             )
-        return QuestionAnswerOpenAIAgent(client, model or
+        return QuestionAnswerOpenAIAgent(client, model or Config.QA_MODEL)
 
     if Config.QA_PROVIDER == "anthropic":
         try:
@@ -32,8 +31,6 @@ def get_qa_agent(client: HaikuRAG, model: str = "") -> QuestionAnswerAgentBase:
                 "Please install haiku.rag with the 'anthropic' extra:"
                 "uv pip install haiku.rag --extra anthropic"
             )
-        return QuestionAnswerAnthropicAgent(
-            client, model or "claude-3-5-haiku-20241022"
-        )
+        return QuestionAnswerAnthropicAgent(client, model or Config.QA_MODEL)
 
     raise ValueError(f"Unsupported QA provider: {Config.QA_PROVIDER}")
```

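With this change, all three providers fall back to `Config.QA_MODEL` rather than hard-coded model names when no model is passed in. A brief usage sketch of the factory, mirroring how `tests/generate_benchmark_db.py` drives it; the database path and the question are placeholder values.

```python
# Sketch only: mirrors the usage in tests/generate_benchmark_db.py.
# "qa.sqlite" and the question are placeholders; provider and default model
# come from Config.QA_PROVIDER / Config.QA_MODEL.
import asyncio

from haiku.rag.client import HaikuRAG
from haiku.rag.qa import get_qa_agent


async def ask(question: str) -> str:
    async with HaikuRAG("qa.sqlite") as rag:
        qa = get_qa_agent(rag)  # picks the Ollama/OpenAI/Anthropic agent from the config
        return await qa.answer(question)


print(asyncio.run(ask("What does the knowledge base say about solar power?")))
```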
haiku_rag-0.3.3/src/haiku/rag/qa/anthropic.py

```diff
@@ -0,0 +1,106 @@
+from collections.abc import Sequence
+
+try:
+    from anthropic import AsyncAnthropic
+    from anthropic.types import MessageParam, TextBlock, ToolParam, ToolUseBlock
+
+    from haiku.rag.client import HaikuRAG
+    from haiku.rag.qa.base import QuestionAnswerAgentBase
+
+    class QuestionAnswerAnthropicAgent(QuestionAnswerAgentBase):
+        def __init__(self, client: HaikuRAG, model: str = "claude-3-5-haiku-20241022"):
+            super().__init__(client, model or self._model)
+            self.tools: Sequence[ToolParam] = [
+                ToolParam(
+                    name="search_documents",
+                    description="Search the knowledge base for relevant documents",
+                    input_schema={
+                        "type": "object",
+                        "properties": {
+                            "query": {
+                                "type": "string",
+                                "description": "The search query to find relevant documents",
+                            },
+                            "limit": {
+                                "type": "integer",
+                                "description": "Maximum number of results to return",
+                                "default": 3,
+                            },
+                        },
+                        "required": ["query"],
+                    },
+                )
+            ]
+
+        async def answer(self, question: str) -> str:
+            anthropic_client = AsyncAnthropic()
+
+            messages: list[MessageParam] = [{"role": "user", "content": question}]
+
+            max_rounds = 5  # Prevent infinite loops
+
+            for _ in range(max_rounds):
+                response = await anthropic_client.messages.create(
+                    model=self._model,
+                    max_tokens=4096,
+                    system=self._system_prompt,
+                    messages=messages,
+                    tools=self.tools,
+                    temperature=0.0,
+                )
+
+                if response.stop_reason == "tool_use":
+                    messages.append({"role": "assistant", "content": response.content})
+
+                    # Process tool calls
+                    tool_results = []
+                    for content_block in response.content:
+                        if isinstance(content_block, ToolUseBlock):
+                            if content_block.name == "search_documents":
+                                args = content_block.input
+                                query = (
+                                    args.get("query", question)
+                                    if isinstance(args, dict)
+                                    else question
+                                )
+                                limit = (
+                                    int(args.get("limit", 3))
+                                    if isinstance(args, dict)
+                                    else 3
+                                )
+
+                                search_results = await self._client.search(
+                                    query, limit=limit
+                                )
+
+                                context_chunks = []
+                                for chunk, score in search_results:
+                                    context_chunks.append(
+                                        f"Content: {chunk.content}\nScore: {score:.4f}"
+                                    )
+
+                                context = "\n\n".join(context_chunks)
+
+                                tool_results.append(
+                                    {
+                                        "type": "tool_result",
+                                        "tool_use_id": content_block.id,
+                                        "content": context,
+                                    }
+                                )
+
+                    if tool_results:
+                        messages.append({"role": "user", "content": tool_results})
+                else:
+                    # No tool use, return the response
+                    if response.content:
+                        first_content = response.content[0]
+                        if isinstance(first_content, TextBlock):
+                            return first_content.text
+                    return ""
+
+            # If we've exhausted max rounds, return empty string
+            return ""
+
+except ImportError:
+    pass
```

haiku_rag-0.3.3/src/haiku/rag/qa/ollama.py

```diff
@@ -0,0 +1,64 @@
+from ollama import AsyncClient
+
+from haiku.rag.client import HaikuRAG
+from haiku.rag.config import Config
+from haiku.rag.qa.base import QuestionAnswerAgentBase
+
+OLLAMA_OPTIONS = {"temperature": 0.0, "seed": 42, "num_ctx": 64000}
+
+
+class QuestionAnswerOllamaAgent(QuestionAnswerAgentBase):
+    def __init__(self, client: HaikuRAG, model: str = Config.QA_MODEL):
+        super().__init__(client, model or self._model)
+
+    async def answer(self, question: str) -> str:
+        ollama_client = AsyncClient(host=Config.OLLAMA_BASE_URL)
+
+        messages = [
+            {"role": "system", "content": self._system_prompt},
+            {"role": "user", "content": question},
+        ]
+
+        max_rounds = 5  # Prevent infinite loops
+
+        for _ in range(max_rounds):
+            response = await ollama_client.chat(
+                model=self._model,
+                messages=messages,
+                tools=self.tools,
+                options=OLLAMA_OPTIONS,
+                think=False,
+            )
+
+            if response.get("message", {}).get("tool_calls"):
+                messages.append(response["message"])
+
+                for tool_call in response["message"]["tool_calls"]:
+                    if tool_call["function"]["name"] == "search_documents":
+                        args = tool_call["function"]["arguments"]
+                        query = args.get("query", question)
+                        limit = int(args.get("limit", 3))
+
+                        search_results = await self._client.search(query, limit=limit)
+
+                        context_chunks = []
+                        for chunk, score in search_results:
+                            context_chunks.append(
+                                f"Content: {chunk.content}\nScore: {score:.4f}"
+                            )
+
+                        context = "\n\n".join(context_chunks)
+
+                        messages.append(
+                            {
+                                "role": "tool",
+                                "content": context,
+                                "tool_call_id": tool_call.get("id", "search_tool"),
+                            }
+                        )
+            else:
+                # No tool calls, return the response
+                return response["message"]["content"]
+
+        # If we've exhausted max rounds, return empty string
+        return ""
```

haiku_rag-0.3.3/src/haiku/rag/qa/openai.py

```diff
@@ -0,0 +1,100 @@
+from collections.abc import Sequence
+
+try:
+    from openai import AsyncOpenAI
+    from openai.types.chat import (
+        ChatCompletionAssistantMessageParam,
+        ChatCompletionMessageParam,
+        ChatCompletionSystemMessageParam,
+        ChatCompletionToolMessageParam,
+        ChatCompletionUserMessageParam,
+    )
+    from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
+
+    from haiku.rag.client import HaikuRAG
+    from haiku.rag.qa.base import QuestionAnswerAgentBase
+
+    class QuestionAnswerOpenAIAgent(QuestionAnswerAgentBase):
+        def __init__(self, client: HaikuRAG, model: str = "gpt-4o-mini"):
+            super().__init__(client, model or self._model)
+            self.tools: Sequence[ChatCompletionToolParam] = [
+                ChatCompletionToolParam(tool) for tool in self.tools
+            ]
+
+        async def answer(self, question: str) -> str:
+            openai_client = AsyncOpenAI()
+
+            messages: list[ChatCompletionMessageParam] = [
+                ChatCompletionSystemMessageParam(
+                    role="system", content=self._system_prompt
+                ),
+                ChatCompletionUserMessageParam(role="user", content=question),
+            ]
+
+            max_rounds = 5  # Prevent infinite loops
+
+            for _ in range(max_rounds):
+                response = await openai_client.chat.completions.create(
+                    model=self._model,
+                    messages=messages,
+                    tools=self.tools,
+                    temperature=0.0,
+                )
+
+                response_message = response.choices[0].message
+
+                if response_message.tool_calls:
+                    messages.append(
+                        ChatCompletionAssistantMessageParam(
+                            role="assistant",
+                            content=response_message.content,
+                            tool_calls=[
+                                {
+                                    "id": tc.id,
+                                    "type": "function",
+                                    "function": {
+                                        "name": tc.function.name,
+                                        "arguments": tc.function.arguments,
+                                    },
+                                }
+                                for tc in response_message.tool_calls
+                            ],
+                        )
+                    )
+
+                    for tool_call in response_message.tool_calls:
+                        if tool_call.function.name == "search_documents":
+                            import json
+
+                            args = json.loads(tool_call.function.arguments)
+                            query = args.get("query", question)
+                            limit = int(args.get("limit", 3))
+
+                            search_results = await self._client.search(
+                                query, limit=limit
+                            )
+
+                            context_chunks = []
+                            for chunk, score in search_results:
+                                context_chunks.append(
+                                    f"Content: {chunk.content}\nScore: {score:.4f}"
+                                )
+
+                            context = "\n\n".join(context_chunks)
+
+                            messages.append(
+                                ChatCompletionToolMessageParam(
+                                    role="tool",
+                                    content=context,
+                                    tool_call_id=tool_call.id,
+                                )
+                            )
+                else:
+                    # No tool calls, return the response
+                    return response_message.content or ""
+
+            # If we've exhausted max rounds, return empty string
+            return ""
+
+except ImportError:
+    pass
```

haiku_rag-0.3.3/src/haiku/rag/qa/prompts.py

```diff
@@ -0,0 +1,20 @@
+SYSTEM_PROMPT = """
+You are a knowledgeable assistant that helps users find information from a document knowledge base.
+
+Your process:
+1. When a user asks a question, use the search_documents tool to find relevant information
+2. Search with specific keywords and phrases from the user's question
+3. Review the search results and their relevance scores
+4. If you need additional context, perform follow-up searches with different keywords
+5. Provide a comprehensive answer based only on the retrieved documents
+
+Guidelines:
+- Base your answers strictly on the provided document content
+- Quote or reference specific information when possible
+- If multiple documents contain relevant information, synthesize them coherently
+- Indicate when information is incomplete or when you need to search for additional context
+- If the retrieved documents don't contain sufficient information, clearly state: "I cannot find enough information in the knowledge base to answer this question."
+- For complex questions, consider breaking them down and performing multiple searches
+
+Be concise, and always maintain accuracy over completeness. Prefer short, direct answers that are well-supported by the documents.
+"""
```

haiku_rag-0.3.3/tests/generate_benchmark_db.py

```diff
@@ -0,0 +1,151 @@
+import asyncio
+from pathlib import Path
+
+from datasets import Dataset, load_dataset
+from llm_judge import LLMJudge
+from rich.console import Console
+from rich.progress import Progress
+
+from haiku.rag.client import HaikuRAG
+from haiku.rag.qa import get_qa_agent
+
+console = Console()
+
+db_path = Path(__file__).parent / "data" / "benchmark.sqlite"
+
+
+async def populate_db():
+    ds: Dataset = load_dataset("ServiceNow/repliqa")["repliqa_3"]  # type: ignore
+    corpus = ds.filter(lambda doc: doc["document_topic"] == "News Stories")
+
+    with Progress() as progress:
+        task = progress.add_task("[green]Populating database...", total=len(corpus))
+
+        async with HaikuRAG(db_path) as rag:
+            for doc in corpus:
+                uri = doc["document_id"]  # type: ignore
+                existing_doc = await rag.get_document_by_uri(uri)
+                if existing_doc is not None:
+                    progress.advance(task)
+                    continue
+
+                await rag.create_document(
+                    content=doc["document_extracted"],  # type: ignore
+                    uri=uri,
+                )
+                progress.advance(task)
+
+
+async def run_match_benchmark():
+    ds: Dataset = load_dataset("ServiceNow/repliqa")["repliqa_3"]  # type: ignore
+    corpus = ds.filter(lambda doc: doc["document_topic"] == "News Stories")
+
+    correct_at_1 = 0
+    correct_at_2 = 0
+    correct_at_3 = 0
+    total_queries = 0
+
+    with Progress() as progress:
+        task = progress.add_task(
+            "[blue]Running retrieval benchmark...", total=len(corpus)
+        )
+
+        async with HaikuRAG(db_path) as rag:
+            for doc in corpus:
+                doc_id = doc["document_id"]  # type: ignore
+                matches = await rag.search(
+                    query=doc["question"],  # type: ignore
+                    limit=3,
+                )
+
+                total_queries += 1
+
+                # Check position of correct document in results
+                for position, (chunk, _) in enumerate(matches):
+                    retrieved = await rag.get_document_by_id(chunk.document_id)
+                    if retrieved and retrieved.uri == doc_id:
+                        if position == 0:  # First position
+                            correct_at_1 += 1
+                            correct_at_2 += 1
+                            correct_at_3 += 1
+                        elif position == 1:  # Second position
+                            correct_at_2 += 1
+                            correct_at_3 += 1
+                        elif position == 2:  # Third position
+                            correct_at_3 += 1
+                        break
+
+                progress.advance(task)
+
+    # Calculate recall metrics
+    recall_at_1 = correct_at_1 / total_queries
+    recall_at_2 = correct_at_2 / total_queries
+    recall_at_3 = correct_at_3 / total_queries
+
+    console.print("\n=== Retrieval Benchmark Results ===", style="bold cyan")
+    console.print(f"Total queries: {total_queries}")
+    console.print(f"Recall@1: {recall_at_1:.4f}")
+    console.print(f"Recall@2: {recall_at_2:.4f}")
+    console.print(f"Recall@3: {recall_at_3:.4f}")
+
+    return {"recall@1": recall_at_1, "recall@2": recall_at_2, "recall@3": recall_at_3}
+
+
+async def run_qa_benchmark(k: int | None = None):
+    """Run QA benchmarking on the corpus."""
+    ds: Dataset = load_dataset("ServiceNow/repliqa")["repliqa_3"]  # type: ignore
+    corpus = ds.filter(lambda doc: doc["document_topic"] == "News Stories")
+
+    if k is not None:
+        corpus = corpus.select(range(min(k, len(corpus))))
+
+    judge = LLMJudge()
+    correct_answers = 0
+    total_questions = 0
+
+    with Progress() as progress:
+        task = progress.add_task("[yellow]Running QA benchmark...", total=len(corpus))
+
+        async with HaikuRAG(db_path) as rag:
+            qa = get_qa_agent(rag)
+
+            for doc in corpus:
+                question = doc["question"]  # type: ignore
+                expected_answer = doc["answer"]  # type: ignore
+
+                generated_answer = await qa.answer(question)
+                is_equivalent = await judge.judge_answers(
+                    question, generated_answer, expected_answer
+                )
+                console.print(f"Question: {question}")
+                console.print(f"Expected: {expected_answer}")
+                console.print(f"Generated: {generated_answer}")
+                console.print(f"Equivalent: {is_equivalent}\n")
+
+                if is_equivalent:
+                    correct_answers += 1
+                total_questions += 1
+                console.print("Current score:", correct_answers, "/", total_questions)
+
+                progress.advance(task)
+
+    accuracy = correct_answers / total_questions if total_questions > 0 else 0
+
+    console.print("\n=== QA Benchmark Results ===", style="bold cyan")
+    console.print(f"Total questions: {total_questions}")
+    console.print(f"Correct answers: {correct_answers}")
+    console.print(f"QA Accuracy: {accuracy:.4f} ({accuracy * 100:.2f}%)")
+
+
+async def main():
+    await populate_db()
+
+    console.print("Running retrieval benchmarks...", style="bold blue")
+    await run_match_benchmark()
+
+    console.print("\nRunning QA benchmarks...", style="bold yellow")
+    await run_qa_benchmark()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
```

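Compared to the 0.3.1 script (deleted below), this version reports progress with `rich`, skips documents that are already indexed, selects the QA agent via `get_qa_agent`, and prints a running score. A small sketch of reusing its helpers for a quick QA-only run; the value `k=50` is an arbitrary example.

```python
# Sketch: reuses the helpers defined in tests/generate_benchmark_db.py above.
# k=50 is an arbitrary example; omit it to evaluate the full corpus.
import asyncio

from generate_benchmark_db import populate_db, run_qa_benchmark


async def small_eval() -> None:
    await populate_db()           # documents already in the benchmark db are skipped
    await run_qa_benchmark(k=50)  # judge only the first 50 questions


asyncio.run(small_eval())
```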
{haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/llm_judge.py

```diff
@@ -49,6 +49,7 @@ class LLMJudge:
 1. Do both answers provide the same answer?
 2. Do both answers directly address the question asked?
 3. Minor differences in wording or style are acceptable if the meaning of the answer is the same.
+4. If one answer is more detailed but the other is correct, they can still be considered equivalent.
 
 Be strict but fair in your evaluation. Focus on factual correctness and whether both answers would satisfy someone asking the question."""
 
```

haiku_rag-0.3.1/BENCHMARKS.md DELETED

```diff
@@ -1,13 +0,0 @@
-# `haiku.rag` benchmarks
-
-We use [repliqa](https://huggingface.co/datasets/ServiceNow/repliqa) for the evaluation of `haiku.rag`
-
-* Recall
-
-We load the `News Stories` from `repliqa_3` which is 1035 documents, using `tests/generate_benchmark_db.py`, using the `mxbai-embed-large` Ollama embeddings.
-
-Subsequently, we run a search over the `question` for each row of the dataset and check whether we match the document that answers the question. The recall obtained is ~0.75 for matching in the top result, raising to ~0.75 for the top 3 results.
-
-* Question/Answer evaluation
-
-We use the `News Stories` from `repliqa_3` using the `mxbai-embed-large` Ollama embeddings, with a QA agent also using Ollama with the `qwen3` model (8b). For each story we ask the `question` and use an LLM judge (also `qwen3`) to evaluate whether the answer is correct or not. Thus we obtain accuracy of ~0.54.
```

haiku_rag-0.3.1/src/haiku/rag/qa/anthropic.py DELETED

```diff
@@ -1,112 +0,0 @@
-from collections.abc import Sequence
-
-try:
-    from anthropic import AsyncAnthropic
-    from anthropic.types import MessageParam, TextBlock, ToolParam, ToolUseBlock
-
-    from haiku.rag.client import HaikuRAG
-    from haiku.rag.qa.base import QuestionAnswerAgentBase
-
-    class QuestionAnswerAnthropicAgent(QuestionAnswerAgentBase):
-        def __init__(self, client: HaikuRAG, model: str = "claude-3-5-haiku-20241022"):
-            super().__init__(client, model or self._model)
-            self.tools: Sequence[ToolParam] = [
-                ToolParam(
-                    name="search_documents",
-                    description="Search the knowledge base for relevant documents",
-                    input_schema={
-                        "type": "object",
-                        "properties": {
-                            "query": {
-                                "type": "string",
-                                "description": "The search query to find relevant documents",
-                            },
-                            "limit": {
-                                "type": "integer",
-                                "description": "Maximum number of results to return",
-                                "default": 3,
-                            },
-                        },
-                        "required": ["query"],
-                    },
-                )
-            ]
-
-        async def answer(self, question: str) -> str:
-            anthropic_client = AsyncAnthropic()
-
-            messages: list[MessageParam] = [{"role": "user", "content": question}]
-
-            response = await anthropic_client.messages.create(
-                model=self._model,
-                max_tokens=4096,
-                system=self._system_prompt,
-                messages=messages,
-                tools=self.tools,
-                temperature=0.0,
-            )
-
-            if response.stop_reason == "tool_use":
-                messages.append({"role": "assistant", "content": response.content})
-
-                # Process tool calls
-                tool_results = []
-                for content_block in response.content:
-                    if isinstance(content_block, ToolUseBlock):
-                        if content_block.name == "search_documents":
-                            args = content_block.input
-                            query = (
-                                args.get("query", question)
-                                if isinstance(args, dict)
-                                else question
-                            )
-                            limit = (
-                                int(args.get("limit", 3))
-                                if isinstance(args, dict)
-                                else 3
-                            )
-
-                            search_results = await self._client.search(
-                                query, limit=limit
-                            )
-
-                            context_chunks = []
-                            for chunk, score in search_results:
-                                context_chunks.append(
-                                    f"Content: {chunk.content}\nScore: {score:.4f}"
-                                )
-
-                            context = "\n\n".join(context_chunks)
-
-                            tool_results.append(
-                                {
-                                    "type": "tool_result",
-                                    "tool_use_id": content_block.id,
-                                    "content": context,
-                                }
-                            )
-
-                if tool_results:
-                    messages.append({"role": "user", "content": tool_results})
-
-                final_response = await anthropic_client.messages.create(
-                    model=self._model,
-                    max_tokens=4096,
-                    system=self._system_prompt,
-                    messages=messages,
-                    temperature=0.0,
-                )
-                if final_response.content:
-                    first_content = final_response.content[0]
-                    if isinstance(first_content, TextBlock):
-                        return first_content.text
-                return ""
-
-            if response.content:
-                first_content = response.content[0]
-                if isinstance(first_content, TextBlock):
-                    return first_content.text
-            return ""
-
-except ImportError:
-    pass
```

haiku_rag-0.3.1/src/haiku/rag/qa/ollama.py DELETED

```diff
@@ -1,67 +0,0 @@
-from ollama import AsyncClient
-
-from haiku.rag.client import HaikuRAG
-from haiku.rag.config import Config
-from haiku.rag.qa.base import QuestionAnswerAgentBase
-
-OLLAMA_OPTIONS = {"temperature": 0.0, "seed": 42, "num_ctx": 64000}
-
-
-class QuestionAnswerOllamaAgent(QuestionAnswerAgentBase):
-    def __init__(self, client: HaikuRAG, model: str = Config.QA_MODEL):
-        super().__init__(client, model or self._model)
-
-    async def answer(self, question: str) -> str:
-        ollama_client = AsyncClient(host=Config.OLLAMA_BASE_URL)
-
-        # Define the search tool
-
-        messages = [
-            {"role": "system", "content": self._system_prompt},
-            {"role": "user", "content": question},
-        ]
-
-        # Initial response with tool calling
-        response = await ollama_client.chat(
-            model=self._model,
-            messages=messages,
-            tools=self.tools,
-            options=OLLAMA_OPTIONS,
-            think=False,
-        )
-
-        if response.get("message", {}).get("tool_calls"):
-            for tool_call in response["message"]["tool_calls"]:
-                if tool_call["function"]["name"] == "search_documents":
-                    args = tool_call["function"]["arguments"]
-                    query = args.get("query", question)
-                    limit = int(args.get("limit", 3))
-
-                    search_results = await self._client.search(query, limit=limit)
-
-                    context_chunks = []
-                    for chunk, score in search_results:
-                        context_chunks.append(
-                            f"Content: {chunk.content}\nScore: {score:.4f}"
-                        )
-
-                    context = "\n\n".join(context_chunks)
-
-                    messages.append(response["message"])
-                    messages.append(
-                        {
-                            "role": "tool",
-                            "content": context,
-                            "tool_call_id": tool_call.get("id", "search_tool"),
-                        }
-                    )
-
-            final_response = await ollama_client.chat(
-                model=self._model,
-                messages=messages,
-                think=False,
-                options=OLLAMA_OPTIONS,
-            )
-            return final_response["message"]["content"]
-        else:
-            return response["message"]["content"]
```

haiku_rag-0.3.1/src/haiku/rag/qa/openai.py DELETED

```diff
@@ -1,101 +0,0 @@
-from collections.abc import Sequence
-
-try:
-    from openai import AsyncOpenAI
-    from openai.types.chat import (
-        ChatCompletionAssistantMessageParam,
-        ChatCompletionMessageParam,
-        ChatCompletionSystemMessageParam,
-        ChatCompletionToolMessageParam,
-        ChatCompletionUserMessageParam,
-    )
-    from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
-
-    from haiku.rag.client import HaikuRAG
-    from haiku.rag.qa.base import QuestionAnswerAgentBase
-
-    class QuestionAnswerOpenAIAgent(QuestionAnswerAgentBase):
-        def __init__(self, client: HaikuRAG, model: str = "gpt-4o-mini"):
-            super().__init__(client, model or self._model)
-            self.tools: Sequence[ChatCompletionToolParam] = [
-                ChatCompletionToolParam(tool) for tool in self.tools
-            ]
-
-        async def answer(self, question: str) -> str:
-            openai_client = AsyncOpenAI()
-
-            # Define the search tool
-
-            messages: list[ChatCompletionMessageParam] = [
-                ChatCompletionSystemMessageParam(
-                    role="system", content=self._system_prompt
-                ),
-                ChatCompletionUserMessageParam(role="user", content=question),
-            ]
-
-            # Initial response with tool calling
-            response = await openai_client.chat.completions.create(
-                model=self._model,
-                messages=messages,
-                tools=self.tools,
-                temperature=0.0,
-            )
-
-            response_message = response.choices[0].message
-
-            if response_message.tool_calls:
-                messages.append(
-                    ChatCompletionAssistantMessageParam(
-                        role="assistant",
-                        content=response_message.content,
-                        tool_calls=[
-                            {
-                                "id": tc.id,
-                                "type": "function",
-                                "function": {
-                                    "name": tc.function.name,
-                                    "arguments": tc.function.arguments,
-                                },
-                            }
-                            for tc in response_message.tool_calls
-                        ],
-                    )
-                )
-
-                for tool_call in response_message.tool_calls:
-                    if tool_call.function.name == "search_documents":
-                        import json
-
-                        args = json.loads(tool_call.function.arguments)
-                        query = args.get("query", question)
-                        limit = int(args.get("limit", 3))
-
-                        search_results = await self._client.search(query, limit=limit)
-
-                        context_chunks = []
-                        for chunk, score in search_results:
-                            context_chunks.append(
-                                f"Content: {chunk.content}\nScore: {score:.4f}"
-                            )
-
-                        context = "\n\n".join(context_chunks)
-
-                        messages.append(
-                            ChatCompletionToolMessageParam(
-                                role="tool",
-                                content=context,
-                                tool_call_id=tool_call.id,
-                            )
-                        )
-
-                final_response = await openai_client.chat.completions.create(
-                    model=self._model,
-                    messages=messages,
-                    temperature=0.0,
-                )
-                return final_response.choices[0].message.content or ""
-            else:
-                return response_message.content or ""
-
-except ImportError:
-    pass
```

haiku_rag-0.3.1/src/haiku/rag/qa/prompts.py DELETED

```diff
@@ -1,7 +0,0 @@
-SYSTEM_PROMPT = """
-You are a helpful assistant that uses a RAG library to answer the user's prompt.
-Your task is to provide a concise and accurate answer based on the provided context.
-You should ask the provided tools to find relevant documents and then use the content of those documents to answer the question.
-Never make up information, always use the context to answer the question.
-If the context does not contain enough information to answer the question, respond with "I cannot answer that based on the provided context."
-"""
```

haiku_rag-0.3.1/tests/generate_benchmark_db.py DELETED

```diff
@@ -1,129 +0,0 @@
-import asyncio
-from pathlib import Path
-
-from datasets import Dataset, load_dataset
-from llm_judge import LLMJudge
-from tqdm import tqdm
-
-from haiku.rag.client import HaikuRAG
-from haiku.rag.qa.ollama import QuestionAnswerOllamaAgent
-
-db_path = Path(__file__).parent / "data" / "benchmark.sqlite"
-
-
-async def populate_db():
-    if (db_path).exists():
-        print("Benchmark database already exists. Skipping creation.")
-        return
-
-    ds: Dataset = load_dataset("ServiceNow/repliqa")["repliqa_3"]  # type: ignore
-    corpus = ds.filter(lambda doc: doc["document_topic"] == "News Stories")
-
-    async with HaikuRAG(db_path) as rag:
-        for i, doc in enumerate(tqdm(corpus)):
-            await rag.create_document(
-                content=doc["document_extracted"],  # type: ignore
-                uri=doc["document_id"],  # type: ignore
-            )
-
-
-async def run_match_benchmark():
-    ds: Dataset = load_dataset("ServiceNow/repliqa")["repliqa_3"]  # type: ignore
-    corpus = ds.filter(lambda doc: doc["document_topic"] == "News Stories")
-
-    correct_at_1 = 0
-    correct_at_2 = 0
-    correct_at_3 = 0
-    total_queries = 0
-
-    async with HaikuRAG(db_path) as rag:
-        for i, doc in enumerate(tqdm(corpus)):
-            doc_id = doc["document_id"]  # type: ignore
-            matches = await rag.search(
-                query=doc["question"],  # type: ignore
-                limit=3,
-            )
-
-            total_queries += 1
-
-            # Check position of correct document in results
-            for position, (chunk, _) in enumerate(matches):
-                retrieved = await rag.get_document_by_id(chunk.document_id)
-                if retrieved and retrieved.uri == doc_id:
-                    if position == 0:  # First position
-                        correct_at_1 += 1
-                        correct_at_2 += 1
-                        correct_at_3 += 1
-                    elif position == 1:  # Second position
-                        correct_at_2 += 1
-                        correct_at_3 += 1
-                    elif position == 2:  # Third position
-                        correct_at_3 += 1
-                    break
-
-    # Calculate recall metrics
-    recall_at_1 = correct_at_1 / total_queries
-    recall_at_2 = correct_at_2 / total_queries
-    recall_at_3 = correct_at_3 / total_queries
-
-    print("\n=== Retrieval Benchmark Results ===")
-    print(f"Total queries: {total_queries}")
-    print(f"Recall@1: {recall_at_1:.4f}")
-    print(f"Recall@2: {recall_at_2:.4f}")
-    print(f"Recall@3: {recall_at_3:.4f}")
-
-    return {"recall@1": recall_at_1, "recall@2": recall_at_2, "recall@3": recall_at_3}
-
-
-async def run_qa_benchmark(k: int | None = None):
-    """Run QA benchmarking on the corpus."""
-    ds: Dataset = load_dataset("ServiceNow/repliqa")["repliqa_3"]  # type: ignore
-    corpus = ds.filter(lambda doc: doc["document_topic"] == "News Stories")
-
-    if k is not None:
-        corpus = corpus.select(range(min(k, len(corpus))))
-
-    judge = LLMJudge()
-    correct_answers = 0
-    total_questions = 0
-
-    async with HaikuRAG(db_path) as rag:
-        qa = QuestionAnswerOllamaAgent(rag)
-
-        for i, doc in enumerate(tqdm(corpus, desc="QA Benchmarking")):
-            question = doc["question"]  # type: ignore
-            expected_answer = doc["answer"]  # type: ignore
-
-            generated_answer = await qa.answer(question)
-            is_equivalent = await judge.judge_answers(
-                question, generated_answer, expected_answer
-            )
-            print(f"Question: {question}")
-            print(f"Expected: {expected_answer}")
-            print(f"Generated: {generated_answer}")
-            print(f"Equivalent: {is_equivalent}\n")
-
-            if is_equivalent:
-                correct_answers += 1
-            total_questions += 1
-
-    accuracy = correct_answers / total_questions if total_questions > 0 else 0
-
-    print("\n=== QA Benchmark Results ===")
-    print(f"Total questions: {total_questions}")
-    print(f"Correct answers: {correct_answers}")
-    print(f"QA Accuracy: {accuracy:.4f} ({accuracy * 100:.2f}%)")
-
-
-async def main():
-    await populate_db()
-
-    print("Running retrieval benchmarks...")
-    await run_match_benchmark()
-
-    print("\nRunning QA benchmarks...")
-    await run_qa_benchmark()
-
-
-if __name__ == "__main__":
-    asyncio.run(main())
```