haiku.rag 0.3.1__tar.gz → 0.3.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of haiku.rag might be problematic. Click here for more details.

Files changed (72) hide show
  1. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/PKG-INFO +2 -1
  2. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/README.md +1 -0
  3. haiku_rag-0.3.3/docs/benchmarks.md +27 -0
  4. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/mkdocs.yml +1 -0
  5. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/pyproject.toml +1 -1
  6. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/qa/__init__.py +2 -5
  7. haiku_rag-0.3.3/src/haiku/rag/qa/anthropic.py +106 -0
  8. haiku_rag-0.3.3/src/haiku/rag/qa/ollama.py +64 -0
  9. haiku_rag-0.3.3/src/haiku/rag/qa/openai.py +100 -0
  10. haiku_rag-0.3.3/src/haiku/rag/qa/prompts.py +20 -0
  11. haiku_rag-0.3.3/tests/generate_benchmark_db.py +151 -0
  12. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/llm_judge.py +1 -0
  13. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/uv.lock +1 -1
  14. haiku_rag-0.3.1/BENCHMARKS.md +0 -13
  15. haiku_rag-0.3.1/src/haiku/rag/qa/anthropic.py +0 -112
  16. haiku_rag-0.3.1/src/haiku/rag/qa/ollama.py +0 -67
  17. haiku_rag-0.3.1/src/haiku/rag/qa/openai.py +0 -101
  18. haiku_rag-0.3.1/src/haiku/rag/qa/prompts.py +0 -7
  19. haiku_rag-0.3.1/tests/generate_benchmark_db.py +0 -129
  20. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/.github/FUNDING.yml +0 -0
  21. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/.github/workflows/build-docs.yml +0 -0
  22. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/.github/workflows/build-publish.yml +0 -0
  23. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/.gitignore +0 -0
  24. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/.pre-commit-config.yaml +0 -0
  25. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/.python-version +0 -0
  26. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/LICENSE +0 -0
  27. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/docs/cli.md +0 -0
  28. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/docs/configuration.md +0 -0
  29. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/docs/index.md +0 -0
  30. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/docs/installation.md +0 -0
  31. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/docs/mcp.md +0 -0
  32. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/docs/python.md +0 -0
  33. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/docs/server.md +0 -0
  34. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/__init__.py +0 -0
  35. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/app.py +0 -0
  36. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/chunker.py +0 -0
  37. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/cli.py +0 -0
  38. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/client.py +0 -0
  39. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/config.py +0 -0
  40. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/embeddings/__init__.py +0 -0
  41. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/embeddings/base.py +0 -0
  42. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/embeddings/ollama.py +0 -0
  43. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/embeddings/openai.py +0 -0
  44. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/embeddings/voyageai.py +0 -0
  45. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/logging.py +0 -0
  46. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/mcp.py +0 -0
  47. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/monitor.py +0 -0
  48. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/qa/base.py +0 -0
  49. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/reader.py +0 -0
  50. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/store/__init__.py +0 -0
  51. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/store/engine.py +0 -0
  52. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/store/models/__init__.py +0 -0
  53. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/store/models/chunk.py +0 -0
  54. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/store/models/document.py +0 -0
  55. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/store/repositories/__init__.py +0 -0
  56. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/store/repositories/base.py +0 -0
  57. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/store/repositories/chunk.py +0 -0
  58. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/store/repositories/document.py +0 -0
  59. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/src/haiku/rag/utils.py +0 -0
  60. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/__init__.py +0 -0
  61. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/conftest.py +0 -0
  62. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/test_app.py +0 -0
  63. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/test_chunk.py +0 -0
  64. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/test_chunker.py +0 -0
  65. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/test_cli.py +0 -0
  66. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/test_client.py +0 -0
  67. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/test_document.py +0 -0
  68. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/test_embedder.py +0 -0
  69. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/test_monitor.py +0 -0
  70. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/test_qa.py +0 -0
  71. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/test_rebuild.py +0 -0
  72. {haiku_rag-0.3.1 → haiku_rag-0.3.3}/tests/test_search.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: haiku.rag
3
- Version: 0.3.1
3
+ Version: 0.3.3
4
4
  Summary: Retrieval Augmented Generation (RAG) with SQLite
5
5
  Author-email: Yiorgis Gozadinos <ggozadinos@gmail.com>
6
6
  License: MIT
@@ -116,3 +116,4 @@ Full documentation at: https://ggozad.github.io/haiku.rag/
116
116
  - [Configuration](https://ggozad.github.io/haiku.rag/configuration/) - Environment variables
117
117
  - [CLI](https://ggozad.github.io/haiku.rag/cli/) - Command reference
118
118
  - [Python API](https://ggozad.github.io/haiku.rag/python/) - Complete API docs
119
+ - [Benchmarks](https://ggozad.github.io/haiku.rag/benchmarks/) - Performance Benchmarks
@@ -77,3 +77,4 @@ Full documentation at: https://ggozad.github.io/haiku.rag/
77
77
  - [Configuration](https://ggozad.github.io/haiku.rag/configuration/) - Environment variables
78
78
  - [CLI](https://ggozad.github.io/haiku.rag/cli/) - Command reference
79
79
  - [Python API](https://ggozad.github.io/haiku.rag/python/) - Complete API docs
80
+ - [Benchmarks](https://ggozad.github.io/haiku.rag/benchmarks/) - Performance Benchmarks
@@ -0,0 +1,27 @@
1
+ # Benchmarks
2
+
3
+ We use the [repliqa](https://huggingface.co/datasets/ServiceNow/repliqa) dataset for the evaluation of `haiku.rag`.
4
+
5
+ You can perform your own evaluations, using as an example the script found at
6
+ `tests/generate_benchmark_db.py`.
7
+
8
+ ## Recall
9
+
10
+ In order to calculate recall, we load the `News Stories` from `repliqa_3`, which comprises 1035 documents, and index them in a SQLite database. Subsequently, we run a search over the `question` field for each row of the dataset and check whether we match the document that answers the question.
11
+
12
+
13
+ The recall obtained with the Ollama `mxbai-embed-large` embeddings is ~0.73 for matching in the top result, rising to ~0.75 for the top 3 results.
14
+
15
+ | Model | Document in top 1 | Document in top 3 |
16
+ |---------------------------------------|-------------------|-------------------|
17
+ | Ollama / `mxbai-embed-large` | 0.73 | 0.75 |
18
+ | OpenAI / `text-embedding-3-small` | 0.75 | 0.88 |
19
+
20
+ ## Question/Answer evaluation
21
+
22
+ Again using the same dataset, we use a QA agent to answer each question. In addition, we use an LLM judge (the Ollama `qwen3` model) to evaluate whether the answer is correct or not. The obtained accuracy is as follows:
23
+
24
+ | Embedding Model | QA Model | Accuracy |
25
+ |------------------------------|-----------------------------------|-----------|
26
+ | Ollama / `mxbai-embed-large` | Ollama / `qwen3` | 0.64 |
27
+ | Ollama / `mxbai-embed-large` | Anthropic / `Claude Sonnet 3.7` | 0.79 |
@@ -63,6 +63,7 @@ nav:
63
63
  - Server: server.md
64
64
  - MCP: mcp.md
65
65
  - Python: python.md
66
+ - Benchmarks: benchmarks.md
66
67
  markdown_extensions:
67
68
  - admonition
68
69
  - attr_list
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "haiku.rag"
3
- version = "0.3.1"
3
+ version = "0.3.3"
4
4
  description = "Retrieval Augmented Generation (RAG) with SQLite"
5
5
  authors = [{ name = "Yiorgis Gozadinos", email = "ggozadinos@gmail.com" }]
6
6
  license = { text = "MIT" }
@@ -8,7 +8,6 @@ def get_qa_agent(client: HaikuRAG, model: str = "") -> QuestionAnswerAgentBase:
8
8
  """
9
9
  Factory function to get the appropriate QA agent based on the configuration.
10
10
  """
11
-
12
11
  if Config.QA_PROVIDER == "ollama":
13
12
  return QuestionAnswerOllamaAgent(client, model or Config.QA_MODEL)
14
13
 
@@ -21,7 +20,7 @@ def get_qa_agent(client: HaikuRAG, model: str = "") -> QuestionAnswerAgentBase:
21
20
  "Please install haiku.rag with the 'openai' extra:"
22
21
  "uv pip install haiku.rag --extra openai"
23
22
  )
24
- return QuestionAnswerOpenAIAgent(client, model or "gpt-4o-mini")
23
+ return QuestionAnswerOpenAIAgent(client, model or Config.QA_MODEL)
25
24
 
26
25
  if Config.QA_PROVIDER == "anthropic":
27
26
  try:
@@ -32,8 +31,6 @@ def get_qa_agent(client: HaikuRAG, model: str = "") -> QuestionAnswerAgentBase:
32
31
  "Please install haiku.rag with the 'anthropic' extra:"
33
32
  "uv pip install haiku.rag --extra anthropic"
34
33
  )
35
- return QuestionAnswerAnthropicAgent(
36
- client, model or "claude-3-5-haiku-20241022"
37
- )
34
+ return QuestionAnswerAnthropicAgent(client, model or Config.QA_MODEL)
38
35
 
39
36
  raise ValueError(f"Unsupported QA provider: {Config.QA_PROVIDER}")
@@ -0,0 +1,106 @@
1
+ from collections.abc import Sequence
2
+
3
+ try:
4
+ from anthropic import AsyncAnthropic
5
+ from anthropic.types import MessageParam, TextBlock, ToolParam, ToolUseBlock
6
+
7
+ from haiku.rag.client import HaikuRAG
8
+ from haiku.rag.qa.base import QuestionAnswerAgentBase
9
+
10
+ class QuestionAnswerAnthropicAgent(QuestionAnswerAgentBase):
11
+ def __init__(self, client: HaikuRAG, model: str = "claude-3-5-haiku-20241022"):
12
+ super().__init__(client, model or self._model)
13
+ self.tools: Sequence[ToolParam] = [
14
+ ToolParam(
15
+ name="search_documents",
16
+ description="Search the knowledge base for relevant documents",
17
+ input_schema={
18
+ "type": "object",
19
+ "properties": {
20
+ "query": {
21
+ "type": "string",
22
+ "description": "The search query to find relevant documents",
23
+ },
24
+ "limit": {
25
+ "type": "integer",
26
+ "description": "Maximum number of results to return",
27
+ "default": 3,
28
+ },
29
+ },
30
+ "required": ["query"],
31
+ },
32
+ )
33
+ ]
34
+
35
+ async def answer(self, question: str) -> str:
36
+ anthropic_client = AsyncAnthropic()
37
+
38
+ messages: list[MessageParam] = [{"role": "user", "content": question}]
39
+
40
+ max_rounds = 5 # Prevent infinite loops
41
+
42
+ for _ in range(max_rounds):
43
+ response = await anthropic_client.messages.create(
44
+ model=self._model,
45
+ max_tokens=4096,
46
+ system=self._system_prompt,
47
+ messages=messages,
48
+ tools=self.tools,
49
+ temperature=0.0,
50
+ )
51
+
52
+ if response.stop_reason == "tool_use":
53
+ messages.append({"role": "assistant", "content": response.content})
54
+
55
+ # Process tool calls
56
+ tool_results = []
57
+ for content_block in response.content:
58
+ if isinstance(content_block, ToolUseBlock):
59
+ if content_block.name == "search_documents":
60
+ args = content_block.input
61
+ query = (
62
+ args.get("query", question)
63
+ if isinstance(args, dict)
64
+ else question
65
+ )
66
+ limit = (
67
+ int(args.get("limit", 3))
68
+ if isinstance(args, dict)
69
+ else 3
70
+ )
71
+
72
+ search_results = await self._client.search(
73
+ query, limit=limit
74
+ )
75
+
76
+ context_chunks = []
77
+ for chunk, score in search_results:
78
+ context_chunks.append(
79
+ f"Content: {chunk.content}\nScore: {score:.4f}"
80
+ )
81
+
82
+ context = "\n\n".join(context_chunks)
83
+
84
+ tool_results.append(
85
+ {
86
+ "type": "tool_result",
87
+ "tool_use_id": content_block.id,
88
+ "content": context,
89
+ }
90
+ )
91
+
92
+ if tool_results:
93
+ messages.append({"role": "user", "content": tool_results})
94
+ else:
95
+ # No tool use, return the response
96
+ if response.content:
97
+ first_content = response.content[0]
98
+ if isinstance(first_content, TextBlock):
99
+ return first_content.text
100
+ return ""
101
+
102
+ # If we've exhausted max rounds, return empty string
103
+ return ""
104
+
105
+ except ImportError:
106
+ pass
@@ -0,0 +1,64 @@
1
+ from ollama import AsyncClient
2
+
3
+ from haiku.rag.client import HaikuRAG
4
+ from haiku.rag.config import Config
5
+ from haiku.rag.qa.base import QuestionAnswerAgentBase
6
+
7
+ OLLAMA_OPTIONS = {"temperature": 0.0, "seed": 42, "num_ctx": 64000}
8
+
9
+
10
+ class QuestionAnswerOllamaAgent(QuestionAnswerAgentBase):
11
+ def __init__(self, client: HaikuRAG, model: str = Config.QA_MODEL):
12
+ super().__init__(client, model or self._model)
13
+
14
+ async def answer(self, question: str) -> str:
15
+ ollama_client = AsyncClient(host=Config.OLLAMA_BASE_URL)
16
+
17
+ messages = [
18
+ {"role": "system", "content": self._system_prompt},
19
+ {"role": "user", "content": question},
20
+ ]
21
+
22
+ max_rounds = 5 # Prevent infinite loops
23
+
24
+ for _ in range(max_rounds):
25
+ response = await ollama_client.chat(
26
+ model=self._model,
27
+ messages=messages,
28
+ tools=self.tools,
29
+ options=OLLAMA_OPTIONS,
30
+ think=False,
31
+ )
32
+
33
+ if response.get("message", {}).get("tool_calls"):
34
+ messages.append(response["message"])
35
+
36
+ for tool_call in response["message"]["tool_calls"]:
37
+ if tool_call["function"]["name"] == "search_documents":
38
+ args = tool_call["function"]["arguments"]
39
+ query = args.get("query", question)
40
+ limit = int(args.get("limit", 3))
41
+
42
+ search_results = await self._client.search(query, limit=limit)
43
+
44
+ context_chunks = []
45
+ for chunk, score in search_results:
46
+ context_chunks.append(
47
+ f"Content: {chunk.content}\nScore: {score:.4f}"
48
+ )
49
+
50
+ context = "\n\n".join(context_chunks)
51
+
52
+ messages.append(
53
+ {
54
+ "role": "tool",
55
+ "content": context,
56
+ "tool_call_id": tool_call.get("id", "search_tool"),
57
+ }
58
+ )
59
+ else:
60
+ # No tool calls, return the response
61
+ return response["message"]["content"]
62
+
63
+ # If we've exhausted max rounds, return empty string
64
+ return ""
@@ -0,0 +1,100 @@
1
+ from collections.abc import Sequence
2
+
3
+ try:
4
+ from openai import AsyncOpenAI
5
+ from openai.types.chat import (
6
+ ChatCompletionAssistantMessageParam,
7
+ ChatCompletionMessageParam,
8
+ ChatCompletionSystemMessageParam,
9
+ ChatCompletionToolMessageParam,
10
+ ChatCompletionUserMessageParam,
11
+ )
12
+ from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
13
+
14
+ from haiku.rag.client import HaikuRAG
15
+ from haiku.rag.qa.base import QuestionAnswerAgentBase
16
+
17
+ class QuestionAnswerOpenAIAgent(QuestionAnswerAgentBase):
18
+ def __init__(self, client: HaikuRAG, model: str = "gpt-4o-mini"):
19
+ super().__init__(client, model or self._model)
20
+ self.tools: Sequence[ChatCompletionToolParam] = [
21
+ ChatCompletionToolParam(tool) for tool in self.tools
22
+ ]
23
+
24
+ async def answer(self, question: str) -> str:
25
+ openai_client = AsyncOpenAI()
26
+
27
+ messages: list[ChatCompletionMessageParam] = [
28
+ ChatCompletionSystemMessageParam(
29
+ role="system", content=self._system_prompt
30
+ ),
31
+ ChatCompletionUserMessageParam(role="user", content=question),
32
+ ]
33
+
34
+ max_rounds = 5 # Prevent infinite loops
35
+
36
+ for _ in range(max_rounds):
37
+ response = await openai_client.chat.completions.create(
38
+ model=self._model,
39
+ messages=messages,
40
+ tools=self.tools,
41
+ temperature=0.0,
42
+ )
43
+
44
+ response_message = response.choices[0].message
45
+
46
+ if response_message.tool_calls:
47
+ messages.append(
48
+ ChatCompletionAssistantMessageParam(
49
+ role="assistant",
50
+ content=response_message.content,
51
+ tool_calls=[
52
+ {
53
+ "id": tc.id,
54
+ "type": "function",
55
+ "function": {
56
+ "name": tc.function.name,
57
+ "arguments": tc.function.arguments,
58
+ },
59
+ }
60
+ for tc in response_message.tool_calls
61
+ ],
62
+ )
63
+ )
64
+
65
+ for tool_call in response_message.tool_calls:
66
+ if tool_call.function.name == "search_documents":
67
+ import json
68
+
69
+ args = json.loads(tool_call.function.arguments)
70
+ query = args.get("query", question)
71
+ limit = int(args.get("limit", 3))
72
+
73
+ search_results = await self._client.search(
74
+ query, limit=limit
75
+ )
76
+
77
+ context_chunks = []
78
+ for chunk, score in search_results:
79
+ context_chunks.append(
80
+ f"Content: {chunk.content}\nScore: {score:.4f}"
81
+ )
82
+
83
+ context = "\n\n".join(context_chunks)
84
+
85
+ messages.append(
86
+ ChatCompletionToolMessageParam(
87
+ role="tool",
88
+ content=context,
89
+ tool_call_id=tool_call.id,
90
+ )
91
+ )
92
+ else:
93
+ # No tool calls, return the response
94
+ return response_message.content or ""
95
+
96
+ # If we've exhausted max rounds, return empty string
97
+ return ""
98
+
99
+ except ImportError:
100
+ pass
@@ -0,0 +1,20 @@
1
+ SYSTEM_PROMPT = """
2
+ You are a knowledgeable assistant that helps users find information from a document knowledge base.
3
+
4
+ Your process:
5
+ 1. When a user asks a question, use the search_documents tool to find relevant information
6
+ 2. Search with specific keywords and phrases from the user's question
7
+ 3. Review the search results and their relevance scores
8
+ 4. If you need additional context, perform follow-up searches with different keywords
9
+ 5. Provide a comprehensive answer based only on the retrieved documents
10
+
11
+ Guidelines:
12
+ - Base your answers strictly on the provided document content
13
+ - Quote or reference specific information when possible
14
+ - If multiple documents contain relevant information, synthesize them coherently
15
+ - Indicate when information is incomplete or when you need to search for additional context
16
+ - If the retrieved documents don't contain sufficient information, clearly state: "I cannot find enough information in the knowledge base to answer this question."
17
+ - For complex questions, consider breaking them down and performing multiple searches
18
+
19
+ Be concise, and always maintain accuracy over completeness. Prefer short, direct answers that are well-supported by the documents.
20
+ """
@@ -0,0 +1,151 @@
1
+ import asyncio
2
+ from pathlib import Path
3
+
4
+ from datasets import Dataset, load_dataset
5
+ from llm_judge import LLMJudge
6
+ from rich.console import Console
7
+ from rich.progress import Progress
8
+
9
+ from haiku.rag.client import HaikuRAG
10
+ from haiku.rag.qa import get_qa_agent
11
+
12
+ console = Console()
13
+
14
+ db_path = Path(__file__).parent / "data" / "benchmark.sqlite"
15
+
16
+
17
+ async def populate_db():
18
+ ds: Dataset = load_dataset("ServiceNow/repliqa")["repliqa_3"] # type: ignore
19
+ corpus = ds.filter(lambda doc: doc["document_topic"] == "News Stories")
20
+
21
+ with Progress() as progress:
22
+ task = progress.add_task("[green]Populating database...", total=len(corpus))
23
+
24
+ async with HaikuRAG(db_path) as rag:
25
+ for doc in corpus:
26
+ uri = doc["document_id"] # type: ignore
27
+ existing_doc = await rag.get_document_by_uri(uri)
28
+ if existing_doc is not None:
29
+ progress.advance(task)
30
+ continue
31
+
32
+ await rag.create_document(
33
+ content=doc["document_extracted"], # type: ignore
34
+ uri=uri,
35
+ )
36
+ progress.advance(task)
37
+
38
+
39
+ async def run_match_benchmark():
40
+ ds: Dataset = load_dataset("ServiceNow/repliqa")["repliqa_3"] # type: ignore
41
+ corpus = ds.filter(lambda doc: doc["document_topic"] == "News Stories")
42
+
43
+ correct_at_1 = 0
44
+ correct_at_2 = 0
45
+ correct_at_3 = 0
46
+ total_queries = 0
47
+
48
+ with Progress() as progress:
49
+ task = progress.add_task(
50
+ "[blue]Running retrieval benchmark...", total=len(corpus)
51
+ )
52
+
53
+ async with HaikuRAG(db_path) as rag:
54
+ for doc in corpus:
55
+ doc_id = doc["document_id"] # type: ignore
56
+ matches = await rag.search(
57
+ query=doc["question"], # type: ignore
58
+ limit=3,
59
+ )
60
+
61
+ total_queries += 1
62
+
63
+ # Check position of correct document in results
64
+ for position, (chunk, _) in enumerate(matches):
65
+ retrieved = await rag.get_document_by_id(chunk.document_id)
66
+ if retrieved and retrieved.uri == doc_id:
67
+ if position == 0: # First position
68
+ correct_at_1 += 1
69
+ correct_at_2 += 1
70
+ correct_at_3 += 1
71
+ elif position == 1: # Second position
72
+ correct_at_2 += 1
73
+ correct_at_3 += 1
74
+ elif position == 2: # Third position
75
+ correct_at_3 += 1
76
+ break
77
+
78
+ progress.advance(task)
79
+
80
+ # Calculate recall metrics
81
+ recall_at_1 = correct_at_1 / total_queries
82
+ recall_at_2 = correct_at_2 / total_queries
83
+ recall_at_3 = correct_at_3 / total_queries
84
+
85
+ console.print("\n=== Retrieval Benchmark Results ===", style="bold cyan")
86
+ console.print(f"Total queries: {total_queries}")
87
+ console.print(f"Recall@1: {recall_at_1:.4f}")
88
+ console.print(f"Recall@2: {recall_at_2:.4f}")
89
+ console.print(f"Recall@3: {recall_at_3:.4f}")
90
+
91
+ return {"recall@1": recall_at_1, "recall@2": recall_at_2, "recall@3": recall_at_3}
92
+
93
+
94
+ async def run_qa_benchmark(k: int | None = None):
95
+ """Run QA benchmarking on the corpus."""
96
+ ds: Dataset = load_dataset("ServiceNow/repliqa")["repliqa_3"] # type: ignore
97
+ corpus = ds.filter(lambda doc: doc["document_topic"] == "News Stories")
98
+
99
+ if k is not None:
100
+ corpus = corpus.select(range(min(k, len(corpus))))
101
+
102
+ judge = LLMJudge()
103
+ correct_answers = 0
104
+ total_questions = 0
105
+
106
+ with Progress() as progress:
107
+ task = progress.add_task("[yellow]Running QA benchmark...", total=len(corpus))
108
+
109
+ async with HaikuRAG(db_path) as rag:
110
+ qa = get_qa_agent(rag)
111
+
112
+ for doc in corpus:
113
+ question = doc["question"] # type: ignore
114
+ expected_answer = doc["answer"] # type: ignore
115
+
116
+ generated_answer = await qa.answer(question)
117
+ is_equivalent = await judge.judge_answers(
118
+ question, generated_answer, expected_answer
119
+ )
120
+ console.print(f"Question: {question}")
121
+ console.print(f"Expected: {expected_answer}")
122
+ console.print(f"Generated: {generated_answer}")
123
+ console.print(f"Equivalent: {is_equivalent}\n")
124
+
125
+ if is_equivalent:
126
+ correct_answers += 1
127
+ total_questions += 1
128
+ console.print("Current score:", correct_answers, "/", total_questions)
129
+
130
+ progress.advance(task)
131
+
132
+ accuracy = correct_answers / total_questions if total_questions > 0 else 0
133
+
134
+ console.print("\n=== QA Benchmark Results ===", style="bold cyan")
135
+ console.print(f"Total questions: {total_questions}")
136
+ console.print(f"Correct answers: {correct_answers}")
137
+ console.print(f"QA Accuracy: {accuracy:.4f} ({accuracy * 100:.2f}%)")
138
+
139
+
140
+ async def main():
141
+ await populate_db()
142
+
143
+ console.print("Running retrieval benchmarks...", style="bold blue")
144
+ await run_match_benchmark()
145
+
146
+ console.print("\nRunning QA benchmarks...", style="bold yellow")
147
+ await run_qa_benchmark()
148
+
149
+
150
+ if __name__ == "__main__":
151
+ asyncio.run(main())
@@ -49,6 +49,7 @@ class LLMJudge:
49
49
  1. Do both answers provide the same answer?
50
50
  2. Do both answers directly address the question asked?
51
51
  3. Minor differences in wording or style are acceptable if the meaning of the answer is the same.
52
+ 4. If one answer is more detailed but the other is correct, they can still be considered equivalent.
52
53
 
53
54
  Be strict but fair in your evaluation. Focus on factual correctness and whether both answers would satisfy someone asking the question."""
54
55
 
@@ -816,7 +816,7 @@ wheels = [
816
816
 
817
817
  [[package]]
818
818
  name = "haiku-rag"
819
- version = "0.3.1"
819
+ version = "0.3.3"
820
820
  source = { editable = "." }
821
821
  dependencies = [
822
822
  { name = "fastmcp" },
@@ -1,13 +0,0 @@
1
- # `haiku.rag` benchmarks
2
-
3
- We use [repliqa](https://huggingface.co/datasets/ServiceNow/repliqa) for the evaluation of `haiku.rag`
4
-
5
- * Recall
6
-
7
- We load the `News Stories` from `repliqa_3` which is 1035 documents, using `tests/generate_benchmark_db.py`, using the `mxbai-embed-large` Ollama embeddings.
8
-
9
- Subsequently, we run a search over the `question` for each row of the dataset and check whether we match the document that answers the question. The recall obtained is ~0.75 for matching in the top result, raising to ~0.75 for the top 3 results.
10
-
11
- * Question/Answer evaluation
12
-
13
- We use the `News Stories` from `repliqa_3` using the `mxbai-embed-large` Ollama embeddings, with a QA agent also using Ollama with the `qwen3` model (8b). For each story we ask the `question` and use an LLM judge (also `qwen3`) to evaluate whether the answer is correct or not. Thus we obtain accuracy of ~0.54.
@@ -1,112 +0,0 @@
1
- from collections.abc import Sequence
2
-
3
- try:
4
- from anthropic import AsyncAnthropic
5
- from anthropic.types import MessageParam, TextBlock, ToolParam, ToolUseBlock
6
-
7
- from haiku.rag.client import HaikuRAG
8
- from haiku.rag.qa.base import QuestionAnswerAgentBase
9
-
10
- class QuestionAnswerAnthropicAgent(QuestionAnswerAgentBase):
11
- def __init__(self, client: HaikuRAG, model: str = "claude-3-5-haiku-20241022"):
12
- super().__init__(client, model or self._model)
13
- self.tools: Sequence[ToolParam] = [
14
- ToolParam(
15
- name="search_documents",
16
- description="Search the knowledge base for relevant documents",
17
- input_schema={
18
- "type": "object",
19
- "properties": {
20
- "query": {
21
- "type": "string",
22
- "description": "The search query to find relevant documents",
23
- },
24
- "limit": {
25
- "type": "integer",
26
- "description": "Maximum number of results to return",
27
- "default": 3,
28
- },
29
- },
30
- "required": ["query"],
31
- },
32
- )
33
- ]
34
-
35
- async def answer(self, question: str) -> str:
36
- anthropic_client = AsyncAnthropic()
37
-
38
- messages: list[MessageParam] = [{"role": "user", "content": question}]
39
-
40
- response = await anthropic_client.messages.create(
41
- model=self._model,
42
- max_tokens=4096,
43
- system=self._system_prompt,
44
- messages=messages,
45
- tools=self.tools,
46
- temperature=0.0,
47
- )
48
-
49
- if response.stop_reason == "tool_use":
50
- messages.append({"role": "assistant", "content": response.content})
51
-
52
- # Process tool calls
53
- tool_results = []
54
- for content_block in response.content:
55
- if isinstance(content_block, ToolUseBlock):
56
- if content_block.name == "search_documents":
57
- args = content_block.input
58
- query = (
59
- args.get("query", question)
60
- if isinstance(args, dict)
61
- else question
62
- )
63
- limit = (
64
- int(args.get("limit", 3))
65
- if isinstance(args, dict)
66
- else 3
67
- )
68
-
69
- search_results = await self._client.search(
70
- query, limit=limit
71
- )
72
-
73
- context_chunks = []
74
- for chunk, score in search_results:
75
- context_chunks.append(
76
- f"Content: {chunk.content}\nScore: {score:.4f}"
77
- )
78
-
79
- context = "\n\n".join(context_chunks)
80
-
81
- tool_results.append(
82
- {
83
- "type": "tool_result",
84
- "tool_use_id": content_block.id,
85
- "content": context,
86
- }
87
- )
88
-
89
- if tool_results:
90
- messages.append({"role": "user", "content": tool_results})
91
-
92
- final_response = await anthropic_client.messages.create(
93
- model=self._model,
94
- max_tokens=4096,
95
- system=self._system_prompt,
96
- messages=messages,
97
- temperature=0.0,
98
- )
99
- if final_response.content:
100
- first_content = final_response.content[0]
101
- if isinstance(first_content, TextBlock):
102
- return first_content.text
103
- return ""
104
-
105
- if response.content:
106
- first_content = response.content[0]
107
- if isinstance(first_content, TextBlock):
108
- return first_content.text
109
- return ""
110
-
111
- except ImportError:
112
- pass
@@ -1,67 +0,0 @@
1
- from ollama import AsyncClient
2
-
3
- from haiku.rag.client import HaikuRAG
4
- from haiku.rag.config import Config
5
- from haiku.rag.qa.base import QuestionAnswerAgentBase
6
-
7
- OLLAMA_OPTIONS = {"temperature": 0.0, "seed": 42, "num_ctx": 64000}
8
-
9
-
10
- class QuestionAnswerOllamaAgent(QuestionAnswerAgentBase):
11
- def __init__(self, client: HaikuRAG, model: str = Config.QA_MODEL):
12
- super().__init__(client, model or self._model)
13
-
14
- async def answer(self, question: str) -> str:
15
- ollama_client = AsyncClient(host=Config.OLLAMA_BASE_URL)
16
-
17
- # Define the search tool
18
-
19
- messages = [
20
- {"role": "system", "content": self._system_prompt},
21
- {"role": "user", "content": question},
22
- ]
23
-
24
- # Initial response with tool calling
25
- response = await ollama_client.chat(
26
- model=self._model,
27
- messages=messages,
28
- tools=self.tools,
29
- options=OLLAMA_OPTIONS,
30
- think=False,
31
- )
32
-
33
- if response.get("message", {}).get("tool_calls"):
34
- for tool_call in response["message"]["tool_calls"]:
35
- if tool_call["function"]["name"] == "search_documents":
36
- args = tool_call["function"]["arguments"]
37
- query = args.get("query", question)
38
- limit = int(args.get("limit", 3))
39
-
40
- search_results = await self._client.search(query, limit=limit)
41
-
42
- context_chunks = []
43
- for chunk, score in search_results:
44
- context_chunks.append(
45
- f"Content: {chunk.content}\nScore: {score:.4f}"
46
- )
47
-
48
- context = "\n\n".join(context_chunks)
49
-
50
- messages.append(response["message"])
51
- messages.append(
52
- {
53
- "role": "tool",
54
- "content": context,
55
- "tool_call_id": tool_call.get("id", "search_tool"),
56
- }
57
- )
58
-
59
- final_response = await ollama_client.chat(
60
- model=self._model,
61
- messages=messages,
62
- think=False,
63
- options=OLLAMA_OPTIONS,
64
- )
65
- return final_response["message"]["content"]
66
- else:
67
- return response["message"]["content"]
@@ -1,101 +0,0 @@
1
- from collections.abc import Sequence
2
-
3
- try:
4
- from openai import AsyncOpenAI
5
- from openai.types.chat import (
6
- ChatCompletionAssistantMessageParam,
7
- ChatCompletionMessageParam,
8
- ChatCompletionSystemMessageParam,
9
- ChatCompletionToolMessageParam,
10
- ChatCompletionUserMessageParam,
11
- )
12
- from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
13
-
14
- from haiku.rag.client import HaikuRAG
15
- from haiku.rag.qa.base import QuestionAnswerAgentBase
16
-
17
- class QuestionAnswerOpenAIAgent(QuestionAnswerAgentBase):
18
- def __init__(self, client: HaikuRAG, model: str = "gpt-4o-mini"):
19
- super().__init__(client, model or self._model)
20
- self.tools: Sequence[ChatCompletionToolParam] = [
21
- ChatCompletionToolParam(tool) for tool in self.tools
22
- ]
23
-
24
- async def answer(self, question: str) -> str:
25
- openai_client = AsyncOpenAI()
26
-
27
- # Define the search tool
28
-
29
- messages: list[ChatCompletionMessageParam] = [
30
- ChatCompletionSystemMessageParam(
31
- role="system", content=self._system_prompt
32
- ),
33
- ChatCompletionUserMessageParam(role="user", content=question),
34
- ]
35
-
36
- # Initial response with tool calling
37
- response = await openai_client.chat.completions.create(
38
- model=self._model,
39
- messages=messages,
40
- tools=self.tools,
41
- temperature=0.0,
42
- )
43
-
44
- response_message = response.choices[0].message
45
-
46
- if response_message.tool_calls:
47
- messages.append(
48
- ChatCompletionAssistantMessageParam(
49
- role="assistant",
50
- content=response_message.content,
51
- tool_calls=[
52
- {
53
- "id": tc.id,
54
- "type": "function",
55
- "function": {
56
- "name": tc.function.name,
57
- "arguments": tc.function.arguments,
58
- },
59
- }
60
- for tc in response_message.tool_calls
61
- ],
62
- )
63
- )
64
-
65
- for tool_call in response_message.tool_calls:
66
- if tool_call.function.name == "search_documents":
67
- import json
68
-
69
- args = json.loads(tool_call.function.arguments)
70
- query = args.get("query", question)
71
- limit = int(args.get("limit", 3))
72
-
73
- search_results = await self._client.search(query, limit=limit)
74
-
75
- context_chunks = []
76
- for chunk, score in search_results:
77
- context_chunks.append(
78
- f"Content: {chunk.content}\nScore: {score:.4f}"
79
- )
80
-
81
- context = "\n\n".join(context_chunks)
82
-
83
- messages.append(
84
- ChatCompletionToolMessageParam(
85
- role="tool",
86
- content=context,
87
- tool_call_id=tool_call.id,
88
- )
89
- )
90
-
91
- final_response = await openai_client.chat.completions.create(
92
- model=self._model,
93
- messages=messages,
94
- temperature=0.0,
95
- )
96
- return final_response.choices[0].message.content or ""
97
- else:
98
- return response_message.content or ""
99
-
100
- except ImportError:
101
- pass
@@ -1,7 +0,0 @@
1
- SYSTEM_PROMPT = """
2
- You are a helpful assistant that uses a RAG library to answer the user's prompt.
3
- Your task is to provide a concise and accurate answer based on the provided context.
4
- You should ask the provided tools to find relevant documents and then use the content of those documents to answer the question.
5
- Never make up information, always use the context to answer the question.
6
- If the context does not contain enough information to answer the question, respond with "I cannot answer that based on the provided context."
7
- """
@@ -1,129 +0,0 @@
1
- import asyncio
2
- from pathlib import Path
3
-
4
- from datasets import Dataset, load_dataset
5
- from llm_judge import LLMJudge
6
- from tqdm import tqdm
7
-
8
- from haiku.rag.client import HaikuRAG
9
- from haiku.rag.qa.ollama import QuestionAnswerOllamaAgent
10
-
11
- db_path = Path(__file__).parent / "data" / "benchmark.sqlite"
12
-
13
-
14
- async def populate_db():
15
- if (db_path).exists():
16
- print("Benchmark database already exists. Skipping creation.")
17
- return
18
-
19
- ds: Dataset = load_dataset("ServiceNow/repliqa")["repliqa_3"] # type: ignore
20
- corpus = ds.filter(lambda doc: doc["document_topic"] == "News Stories")
21
-
22
- async with HaikuRAG(db_path) as rag:
23
- for i, doc in enumerate(tqdm(corpus)):
24
- await rag.create_document(
25
- content=doc["document_extracted"], # type: ignore
26
- uri=doc["document_id"], # type: ignore
27
- )
28
-
29
-
30
- async def run_match_benchmark():
31
- ds: Dataset = load_dataset("ServiceNow/repliqa")["repliqa_3"] # type: ignore
32
- corpus = ds.filter(lambda doc: doc["document_topic"] == "News Stories")
33
-
34
- correct_at_1 = 0
35
- correct_at_2 = 0
36
- correct_at_3 = 0
37
- total_queries = 0
38
-
39
- async with HaikuRAG(db_path) as rag:
40
- for i, doc in enumerate(tqdm(corpus)):
41
- doc_id = doc["document_id"] # type: ignore
42
- matches = await rag.search(
43
- query=doc["question"], # type: ignore
44
- limit=3,
45
- )
46
-
47
- total_queries += 1
48
-
49
- # Check position of correct document in results
50
- for position, (chunk, _) in enumerate(matches):
51
- retrieved = await rag.get_document_by_id(chunk.document_id)
52
- if retrieved and retrieved.uri == doc_id:
53
- if position == 0: # First position
54
- correct_at_1 += 1
55
- correct_at_2 += 1
56
- correct_at_3 += 1
57
- elif position == 1: # Second position
58
- correct_at_2 += 1
59
- correct_at_3 += 1
60
- elif position == 2: # Third position
61
- correct_at_3 += 1
62
- break
63
-
64
- # Calculate recall metrics
65
- recall_at_1 = correct_at_1 / total_queries
66
- recall_at_2 = correct_at_2 / total_queries
67
- recall_at_3 = correct_at_3 / total_queries
68
-
69
- print("\n=== Retrieval Benchmark Results ===")
70
- print(f"Total queries: {total_queries}")
71
- print(f"Recall@1: {recall_at_1:.4f}")
72
- print(f"Recall@2: {recall_at_2:.4f}")
73
- print(f"Recall@3: {recall_at_3:.4f}")
74
-
75
- return {"recall@1": recall_at_1, "recall@2": recall_at_2, "recall@3": recall_at_3}
76
-
77
-
78
- async def run_qa_benchmark(k: int | None = None):
79
- """Run QA benchmarking on the corpus."""
80
- ds: Dataset = load_dataset("ServiceNow/repliqa")["repliqa_3"] # type: ignore
81
- corpus = ds.filter(lambda doc: doc["document_topic"] == "News Stories")
82
-
83
- if k is not None:
84
- corpus = corpus.select(range(min(k, len(corpus))))
85
-
86
- judge = LLMJudge()
87
- correct_answers = 0
88
- total_questions = 0
89
-
90
- async with HaikuRAG(db_path) as rag:
91
- qa = QuestionAnswerOllamaAgent(rag)
92
-
93
- for i, doc in enumerate(tqdm(corpus, desc="QA Benchmarking")):
94
- question = doc["question"] # type: ignore
95
- expected_answer = doc["answer"] # type: ignore
96
-
97
- generated_answer = await qa.answer(question)
98
- is_equivalent = await judge.judge_answers(
99
- question, generated_answer, expected_answer
100
- )
101
- print(f"Question: {question}")
102
- print(f"Expected: {expected_answer}")
103
- print(f"Generated: {generated_answer}")
104
- print(f"Equivalent: {is_equivalent}\n")
105
-
106
- if is_equivalent:
107
- correct_answers += 1
108
- total_questions += 1
109
-
110
- accuracy = correct_answers / total_questions if total_questions > 0 else 0
111
-
112
- print("\n=== QA Benchmark Results ===")
113
- print(f"Total questions: {total_questions}")
114
- print(f"Correct answers: {correct_answers}")
115
- print(f"QA Accuracy: {accuracy:.4f} ({accuracy * 100:.2f}%)")
116
-
117
-
118
- async def main():
119
- await populate_db()
120
-
121
- print("Running retrieval benchmarks...")
122
- await run_match_benchmark()
123
-
124
- print("\nRunning QA benchmarks...")
125
- await run_qa_benchmark()
126
-
127
-
128
- if __name__ == "__main__":
129
- asyncio.run(main())
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes