haiku.rag 0.5.5__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of haiku.rag might be problematic.
- haiku/rag/app.py +4 -4
- haiku/rag/cli.py +38 -27
- haiku/rag/client.py +19 -23
- haiku/rag/config.py +6 -2
- haiku/rag/embeddings/__init__.py +3 -9
- haiku/rag/embeddings/openai.py +10 -13
- haiku/rag/logging.py +4 -0
- haiku/rag/mcp.py +12 -9
- haiku/rag/migration.py +316 -0
- haiku/rag/qa/__init__.py +10 -39
- haiku/rag/qa/agent.py +76 -0
- haiku/rag/qa/prompts.py +2 -0
- haiku/rag/reranking/__init__.py +0 -6
- haiku/rag/store/engine.py +173 -141
- haiku/rag/store/models/chunk.py +2 -2
- haiku/rag/store/models/document.py +1 -1
- haiku/rag/store/repositories/__init__.py +6 -2
- haiku/rag/store/repositories/chunk.py +279 -414
- haiku/rag/store/repositories/document.py +171 -205
- haiku/rag/store/repositories/settings.py +115 -49
- haiku/rag/store/upgrades/__init__.py +1 -3
- haiku/rag/utils.py +39 -31
- {haiku_rag-0.5.5.dist-info → haiku_rag-0.7.0.dist-info}/METADATA +22 -22
- haiku_rag-0.7.0.dist-info/RECORD +39 -0
- haiku/rag/qa/anthropic.py +0 -108
- haiku/rag/qa/base.py +0 -89
- haiku/rag/qa/ollama.py +0 -60
- haiku/rag/qa/openai.py +0 -97
- haiku/rag/reranking/ollama.py +0 -84
- haiku/rag/store/repositories/base.py +0 -40
- haiku/rag/store/upgrades/v0_3_4.py +0 -26
- haiku_rag-0.5.5.dist-info/RECORD +0 -44
- {haiku_rag-0.5.5.dist-info → haiku_rag-0.7.0.dist-info}/WHEEL +0 -0
- {haiku_rag-0.5.5.dist-info → haiku_rag-0.7.0.dist-info}/entry_points.txt +0 -0
- {haiku_rag-0.5.5.dist-info → haiku_rag-0.7.0.dist-info}/licenses/LICENSE +0 -0
haiku/rag/utils.py
CHANGED
@@ -1,4 +1,7 @@
+import asyncio
 import sys
+from collections.abc import Callable
+from functools import wraps
 from importlib import metadata
 from io import BytesIO
 from pathlib import Path
@@ -10,6 +13,42 @@ from docling_core.types.io import DocumentStream
 from packaging.version import Version, parse


+def debounce(wait: float) -> Callable:
+    """
+    A decorator to debounce a function, ensuring it is called only after a specified delay
+    and always executes after the last call.
+
+    Args:
+        wait (float): The debounce delay in seconds.
+
+    Returns:
+        Callable: The decorated function.
+    """
+
+    def decorator(func: Callable) -> Callable:
+        last_call = None
+        task = None
+
+        @wraps(func)
+        async def debounced(*args, **kwargs):
+            nonlocal last_call, task
+            last_call = asyncio.get_event_loop().time()
+
+            if task:
+                task.cancel()
+
+            async def call_func():
+                await asyncio.sleep(wait)
+                if asyncio.get_event_loop().time() - last_call >= wait:  # type: ignore
+                    await func(*args, **kwargs)
+
+            task = asyncio.create_task(call_func())
+
+        return debounced
+
+    return decorator
+
+
 def get_default_data_dir() -> Path:
     """Get the user data directory for the current system platform.

@@ -32,37 +71,6 @@ def get_default_data_dir() -> Path:
     return data_path


-def semantic_version_to_int(version: str) -> int:
-    """Convert a semantic version string to an integer.
-
-    Args:
-        version: Semantic version string.
-
-    Returns:
-        Integer representation of semantic version.
-    """
-    major, minor, patch = version.split(".")
-    major = int(major) << 16
-    minor = int(minor) << 8
-    patch = int(patch)
-    return major + minor + patch
-
-
-def int_to_semantic_version(version: int) -> str:
-    """Convert an integer to a semantic version string.
-
-    Args:
-        version: Integer representation of semantic version.
-
-    Returns:
-        Semantic version string.
-    """
-    major = version >> 16
-    minor = (version >> 8) & 255
-    patch = version & 255
-    return f"{major}.{minor}.{patch}"
-
-
 async def is_up_to_date() -> tuple[bool, Version, Version]:
     """Check whether haiku.rag is current.

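The debounce decorator added above coalesces a burst of calls into a single delayed invocation. Below is a minimal usage sketch, not taken from the package: the handler name, the 0.5-second delay, and the call pattern are illustrative assumptions; only the import path (haiku/rag/utils.py) comes from this diff.

# Hypothetical usage of the debounce decorator shipped in haiku/rag/utils.py.
import asyncio

from haiku.rag.utils import debounce


@debounce(0.5)
async def reindex(path: str) -> None:
    # Runs only once, after the burst of calls below has been quiet for 0.5s.
    print(f"reindexing {path}")


async def main() -> None:
    for _ in range(3):
        await reindex("/docs/readme.md")  # each call cancels the pending task
        await asyncio.sleep(0.1)
    await asyncio.sleep(1.0)  # allow the last scheduled task to fire


asyncio.run(main())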
{haiku_rag-0.5.5.dist-info → haiku_rag-0.7.0.dist-info}/METADATA
CHANGED
@@ -1,11 +1,11 @@
 Metadata-Version: 2.4
 Name: haiku.rag
-Version: 0.5.5
-Summary: Retrieval Augmented Generation (RAG) with
+Version: 0.7.0
+Summary: Retrieval Augmented Generation (RAG) with LanceDB
 Author-email: Yiorgis Gozadinos <ggozadinos@gmail.com>
 License: MIT
 License-File: LICENSE
-Keywords: RAG,mcp,ml,
+Keywords: RAG,lancedb,mcp,ml,vector-database
 Classifier: Development Status :: 4 - Beta
 Classifier: Environment :: Console
 Classifier: Intended Audience :: Developers
@@ -17,42 +17,39 @@ Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Typing :: Typed
-Requires-Python: >=3.
-Requires-Dist: docling>=2.
+Requires-Python: >=3.12
+Requires-Dist: docling>=2.49.0
 Requires-Dist: fastmcp>=2.8.1
 Requires-Dist: httpx>=0.28.1
+Requires-Dist: lancedb>=0.24.3
 Requires-Dist: ollama>=0.5.3
+Requires-Dist: pydantic-ai>=0.8.1
 Requires-Dist: pydantic>=2.11.7
 Requires-Dist: python-dotenv>=1.1.0
-Requires-Dist: rich>=14.
-Requires-Dist:
-Requires-Dist:
-Requires-Dist: typer>=0.16.0
+Requires-Dist: rich>=14.1.0
+Requires-Dist: tiktoken>=0.11.0
+Requires-Dist: typer>=0.16.1
 Requires-Dist: watchfiles>=1.1.0
-Provides-Extra: anthropic
-Requires-Dist: anthropic>=0.56.0; extra == 'anthropic'
-Provides-Extra: cohere
-Requires-Dist: cohere>=5.16.1; extra == 'cohere'
 Provides-Extra: mxbai
 Requires-Dist: mxbai-rerank>=0.1.6; extra == 'mxbai'
-Provides-Extra: openai
-Requires-Dist: openai>=1.0.0; extra == 'openai'
 Provides-Extra: voyageai
 Requires-Dist: voyageai>=0.3.2; extra == 'voyageai'
 Description-Content-Type: text/markdown

-# Haiku
+# Haiku RAG

-Retrieval-Augmented Generation (RAG) library on
+Retrieval-Augmented Generation (RAG) library built on LanceDB.

-`haiku.rag` is a Retrieval-Augmented Generation (RAG) library built to work
+`haiku.rag` is a Retrieval-Augmented Generation (RAG) library built to work with LanceDB as a local vector database. It uses LanceDB for storing embeddings and performs semantic (vector) search as well as full-text search combined through native hybrid search with Reciprocal Rank Fusion. Both open-source (Ollama) as well as commercial (OpenAI, VoyageAI) embedding providers are supported.
+
+> **Note**: Starting with version 0.7.0, haiku.rag uses LanceDB instead of SQLite. If you have an existing SQLite database, use `haiku-rag migrate old_database.sqlite` to migrate your data safely.

 ## Features

-- **Local
+- **Local LanceDB**: No external servers required, supports also LanceDB cloud storage, S3, Google Cloud & Azure
 - **Multiple embedding providers**: Ollama, VoyageAI, OpenAI
-- **Multiple QA providers**:
-- **
+- **Multiple QA providers**: Any provider/model supported by Pydantic AI
+- **Native hybrid search**: Vector + full-text search with native LanceDB RRF reranking
 - **Reranking**: Default search result reranking with MixedBread AI or Cohere
 - **Question answering**: Built-in QA agents on your documents
 - **File monitoring**: Auto-index files when run as server
@@ -82,6 +79,9 @@ haiku-rag ask "Who is the author of haiku.rag?" --cite
 # Rebuild database (re-chunk and re-embed all documents)
 haiku-rag rebuild

+# Migrate from SQLite to LanceDB
+haiku-rag migrate old_database.sqlite
+
 # Start server with file monitoring
 export MONITOR_DIRECTORIES="/path/to/docs"
 haiku-rag serve
@@ -92,7 +92,7 @@ haiku-rag serve
 ```python
 from haiku.rag.client import HaikuRAG

-async with HaikuRAG("database.
+async with HaikuRAG("database.lancedb") as client:
     # Add document
     doc = await client.create_document("Your content")

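The README excerpt above stops right after `create_document`. As a rough continuation, the sketch below queries the same client; the `search` call and the `(chunk, score)` result shape are taken from the deleted `haiku/rag/qa/base.py` shown later in this diff, while the query string and the printing are illustrative assumptions.

# Hypothetical continuation of the README example; grounded only in calls
# visible elsewhere in this diff (client.search in the removed qa/base.py).
import asyncio

from haiku.rag.client import HaikuRAG


async def main() -> None:
    async with HaikuRAG("database.lancedb") as client:
        await client.create_document("Your content")
        results = await client.search("your question", limit=3)
        for chunk, score in results:
            print(f"{score:.3f}  {chunk.content[:80]}")


asyncio.run(main())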
haiku_rag-0.7.0.dist-info/RECORD
ADDED
@@ -0,0 +1,39 @@
+haiku/rag/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+haiku/rag/app.py,sha256=GmuZxH7BMutWt8Mdu0RSateRBaKiqXh7Z9tV7cZX6n0,7655
+haiku/rag/chunker.py,sha256=PVe6ysv8UlacUd4Zb3_8RFWIaWDXnzBAy2VDJ4TaUsE,1555
+haiku/rag/cli.py,sha256=UY9Vh5RsIxSCV14eQbNOiwToKmbFAvqTOAnxjieaYBs,6399
+haiku/rag/client.py,sha256=N4zkWjE9Rsw9YgPvNo83xptHUQR2ognfOnjkoV_w6hc,20999
+haiku/rag/config.py,sha256=9Mv0QJ3c6VF1oVRSXJlSsG234dCd_sKnJO-ybMaTpDQ,1690
+haiku/rag/logging.py,sha256=9RjS931Dkp0nSdJaCSCydi33jq4gXmLmIidONvD9qD4,731
+haiku/rag/mcp.py,sha256=bR9Y-Nz-hvjiql20Y0KE0hwNGwyjmPGX8K9d-qmXptY,4683
+haiku/rag/migration.py,sha256=gWxQwiKo0YulRhogYz4K8N98kHN9LQXIx9FeTmT24v4,10915
+haiku/rag/monitor.py,sha256=r386nkhdlsU8UECwIuVwnrSlgMk3vNIuUZGNIzkZuec,2770
+haiku/rag/reader.py,sha256=qkPTMJuQ_o4sK-8zpDl9WFYe_MJ7aL_gUw6rczIpW-g,3274
+haiku/rag/utils.py,sha256=c8F0ECsFSqvQxzxINAOAnvShoOnJPLsOaNE3JEY2JSc,3230
+haiku/rag/embeddings/__init__.py,sha256=n7aHW3BxHlpGxU4ze4YYDOsljzFpEep8dwVE2n45JoE,1218
+haiku/rag/embeddings/base.py,sha256=NTQvuzbZPu0LBo5wAu3qGyJ4xXUaRAt1fjBO0ygWn_Y,465
+haiku/rag/embeddings/ollama.py,sha256=y6-lp0XpbnyIjoOEdtSzMdEVkU5glOwnWQ1FkpUZnpI,370
+haiku/rag/embeddings/openai.py,sha256=iA-DewCOSip8PLU_RhEJHFHBle4DtmCCIGNfGs58Wvk,357
+haiku/rag/embeddings/voyageai.py,sha256=0hiRTIqu-bpl-4OaCtMHvWfPdgbrzhnfZJowSV8pLRA,415
+haiku/rag/qa/__init__.py,sha256=Sl7Kzrg9CuBOcMF01wc1NtQhUNWjJI0MhIHfCWrb8V4,434
+haiku/rag/qa/agent.py,sha256=r6tYKvOW4W1HxBRHH1kmzlzb1bIJcQSuHd6cG9ANqXY,2594
+haiku/rag/qa/prompts.py,sha256=xdT4cyrOrAK9UDgVqyev1wHF49jD57Bh40gx2sH4NPI,3341
+haiku/rag/reranking/__init__.py,sha256=IRXHs4qPu6VbGJQpzSwhgtVWWumURH_vEoVFE-extlo,894
+haiku/rag/reranking/base.py,sha256=LM9yUSSJ414UgBZhFTgxGprlRqzfTe4I1vgjricz2JY,405
+haiku/rag/reranking/cohere.py,sha256=1iTdiaa8vvb6oHVB2qpWzUOVkyfUcimVSZp6Qr4aq4c,1049
+haiku/rag/reranking/mxbai.py,sha256=46sVTsTIkzIX9THgM3u8HaEmgY7evvEyB-N54JTHvK8,867
+haiku/rag/store/__init__.py,sha256=hq0W0DAC7ysqhWSP2M2uHX8cbG6kbr-sWHxhq6qQcY0,103
+haiku/rag/store/engine.py,sha256=XHGo5Xl-dCFdQHrOdMo64xVK5n0k8-LoUl5V-tlA0HI,7131
+haiku/rag/store/models/__init__.py,sha256=s0E72zneGlowvZrFWaNxHYjOAUjgWdLxzdYsnvNRVlY,88
+haiku/rag/store/models/chunk.py,sha256=ZNyTfO6lh3rXWLVYO3TZcitbL4LSUGr42fR6jQQ5iQc,364
+haiku/rag/store/models/document.py,sha256=zSSpt6pyrMJAIXGQvIcqojcqUzwZnhp3WxVokaWxNRc,396
+haiku/rag/store/repositories/__init__.py,sha256=Olv5dLfBQINRV3HrsfUpjzkZ7Qm7goEYyMNykgo_DaY,291
+haiku/rag/store/repositories/chunk.py,sha256=5S77mGh6pWxPHjaXriJGmvbSOhoNM8tLwygE2GXPlbU,13586
+haiku/rag/store/repositories/document.py,sha256=lP8Lo82KTP-qwXFRpYZ46WjeAdAsHwZ5pJcrXdz4g0U,6988
+haiku/rag/store/repositories/settings.py,sha256=dqnAvm-98nQrWpLBbf9QghJw673QD80-iqQhRMP5t0c,5025
+haiku/rag/store/upgrades/__init__.py,sha256=wUiEoSiHTahvuagx93E4FB07v123AhdbOjwUkPusiIg,14
+haiku_rag-0.7.0.dist-info/METADATA,sha256=FGx5ufhc35VUXqXLdbhmUEa-3hi7aaBuDeyZYf8xaaQ,4597
+haiku_rag-0.7.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+haiku_rag-0.7.0.dist-info/entry_points.txt,sha256=G1U3nAkNd5YDYd4v0tuYFbriz0i-JheCsFuT9kIoGCI,48
+haiku_rag-0.7.0.dist-info/licenses/LICENSE,sha256=eXZrWjSk9PwYFNK9yUczl3oPl95Z4V9UXH7bPN46iPo,1065
+haiku_rag-0.7.0.dist-info/RECORD,,
haiku/rag/qa/anthropic.py
DELETED
@@ -1,108 +0,0 @@
-from collections.abc import Sequence
-
-try:
-    from anthropic import AsyncAnthropic  # type: ignore
-    from anthropic.types import (  # type: ignore
-        MessageParam,
-        TextBlock,
-        ToolParam,
-        ToolUseBlock,
-    )
-
-    from haiku.rag.client import HaikuRAG
-    from haiku.rag.qa.base import QuestionAnswerAgentBase
-
-    class QuestionAnswerAnthropicAgent(QuestionAnswerAgentBase):
-        def __init__(
-            self,
-            client: HaikuRAG,
-            model: str = "claude-3-5-haiku-20241022",
-            use_citations: bool = False,
-        ):
-            super().__init__(client, model or self._model, use_citations)
-            self.tools: Sequence[ToolParam] = [
-                ToolParam(
-                    name="search_documents",
-                    description="Search the knowledge base for relevant documents. Returns a JSON array with content, score, and document_uri for each result.",
-                    input_schema={
-                        "type": "object",
-                        "properties": {
-                            "query": {
-                                "type": "string",
-                                "description": "The search query to find relevant documents",
-                            },
-                            "limit": {
-                                "type": "integer",
-                                "description": "Maximum number of results to return",
-                                "default": 3,
-                            },
-                        },
-                        "required": ["query"],
-                    },
-                )
-            ]
-
-        async def answer(self, question: str) -> str:
-            anthropic_client = AsyncAnthropic()
-
-            messages: list[MessageParam] = [{"role": "user", "content": question}]
-
-            max_rounds = 5  # Prevent infinite loops
-
-            for _ in range(max_rounds):
-                response = await anthropic_client.messages.create(
-                    model=self._model,
-                    max_tokens=4096,
-                    system=self._system_prompt,
-                    messages=messages,
-                    tools=self.tools,
-                    temperature=0.0,
-                )
-
-                if response.stop_reason == "tool_use":
-                    messages.append({"role": "assistant", "content": response.content})
-
-                    # Process tool calls
-                    tool_results = []
-                    for content_block in response.content:
-                        if isinstance(content_block, ToolUseBlock):
-                            if content_block.name == "search_documents":
-                                args = content_block.input
-                                query = (
-                                    args.get("query", question)
-                                    if isinstance(args, dict)
-                                    else question
-                                )
-                                limit = (
-                                    int(args.get("limit", 3))
-                                    if isinstance(args, dict)
-                                    else 3
-                                )
-
-                                context = await self._search_and_expand(
-                                    query, limit=limit
-                                )
-
-                                tool_results.append(
-                                    {
-                                        "type": "tool_result",
-                                        "tool_use_id": content_block.id,
-                                        "content": context,
-                                    }
-                                )
-
-                    if tool_results:
-                        messages.append({"role": "user", "content": tool_results})
-                else:
-                    # No tool use, return the response
-                    if response.content:
-                        first_content = response.content[0]
-                        if isinstance(first_content, TextBlock):
-                            return first_content.text
-                    return ""
-
-            # If we've exhausted max rounds, return empty string
-            return ""
-
-except ImportError:
-    pass
haiku/rag/qa/base.py
DELETED
@@ -1,89 +0,0 @@
-import json
-
-from haiku.rag.client import HaikuRAG
-from haiku.rag.qa.prompts import SYSTEM_PROMPT, SYSTEM_PROMPT_WITH_CITATIONS
-
-
-class QuestionAnswerAgentBase:
-    _model: str = ""
-    _system_prompt: str = SYSTEM_PROMPT
-
-    def __init__(self, client: HaikuRAG, model: str = "", use_citations: bool = False):
-        self._model = model
-        self._client = client
-        self._system_prompt = (
-            SYSTEM_PROMPT_WITH_CITATIONS if use_citations else SYSTEM_PROMPT
-        )
-
-    async def answer(self, question: str) -> str:
-        raise NotImplementedError(
-            "QABase is an abstract class. Please implement the answer method in a subclass."
-        )
-
-    async def _search_and_expand(self, query: str, limit: int = 3) -> str:
-        """Search for documents and expand context, then format as JSON"""
-        search_results = await self._client.search(query, limit=limit)
-        expanded_results = await self._client.expand_context(search_results)
-        return self._format_search_results(expanded_results)
-
-    def _format_search_results(self, search_results) -> str:
-        """Format search results as JSON list of {content, score, document_uri}"""
-        formatted_results = []
-        for chunk, score in search_results:
-            formatted_results.append(
-                {
-                    "content": chunk.content,
-                    "score": score,
-                    "document_uri": chunk.document_uri,
-                }
-            )
-        return json.dumps(formatted_results, indent=2)
-
-    tools = [
-        {
-            "type": "function",
-            "function": {
-                "name": "search_documents",
-                "description": "Search the knowledge base for relevant documents. Returns a JSON array of search results.",
-                "parameters": {
-                    "type": "object",
-                    "properties": {
-                        "query": {
-                            "type": "string",
-                            "description": "The search query to find relevant documents",
-                        },
-                        "limit": {
-                            "type": "integer",
-                            "description": "Maximum number of results to return",
-                            "default": 3,
-                        },
-                    },
-                    "required": ["query"],
-                },
-                "returns": {
-                    "type": "string",
-                    "description": "JSON array of search results",
-                    "schema": {
-                        "type": "array",
-                        "items": {
-                            "type": "object",
-                            "properties": {
-                                "content": {
-                                    "type": "string",
-                                    "description": "The document text content",
-                                },
-                                "score": {
-                                    "type": "number",
-                                    "description": "Relevance score (higher is more relevant)",
-                                },
-                                "document_uri": {
-                                    "type": "string",
-                                    "description": "Source URI/path of the document",
-                                },
-                            },
-                        },
-                    },
-                },
-            },
-        }
-    ]
haiku/rag/qa/ollama.py
DELETED
@@ -1,60 +0,0 @@
-from ollama import AsyncClient
-
-from haiku.rag.client import HaikuRAG
-from haiku.rag.config import Config
-from haiku.rag.qa.base import QuestionAnswerAgentBase
-
-OLLAMA_OPTIONS = {"temperature": 0.0, "seed": 42, "num_ctx": 16384}
-
-
-class QuestionAnswerOllamaAgent(QuestionAnswerAgentBase):
-    def __init__(
-        self,
-        client: HaikuRAG,
-        model: str = Config.QA_MODEL,
-        use_citations: bool = False,
-    ):
-        super().__init__(client, model or self._model, use_citations)
-
-    async def answer(self, question: str) -> str:
-        ollama_client = AsyncClient(host=Config.OLLAMA_BASE_URL)
-
-        messages = [
-            {"role": "system", "content": self._system_prompt},
-            {"role": "user", "content": question},
-        ]
-
-        max_rounds = 5  # Prevent infinite loops
-
-        for _ in range(max_rounds):
-            response = await ollama_client.chat(
-                model=self._model,
-                messages=messages,
-                tools=self.tools,
-                options=OLLAMA_OPTIONS,
-                think=False,
-            )
-
-            if response.get("message", {}).get("tool_calls"):
-                messages.append(response["message"])
-
-                for tool_call in response["message"]["tool_calls"]:
-                    if tool_call["function"]["name"] == "search_documents":
-                        args = tool_call["function"]["arguments"]
-                        query = args.get("query", question)
-                        limit = int(args.get("limit", 3))
-
-                        context = await self._search_and_expand(query, limit=limit)
-                        messages.append(
-                            {
-                                "role": "tool",
-                                "content": context,
-                                "tool_call_id": tool_call.get("id", "search_tool"),
-                            }
-                        )
-            else:
-                # No tool calls, return the response
-                return response["message"]["content"]
-
-        # If we've exhausted max rounds, return empty string
-        return ""
haiku/rag/qa/openai.py
DELETED
@@ -1,97 +0,0 @@
-from collections.abc import Sequence
-
-try:
-    from openai import AsyncOpenAI  # type: ignore
-    from openai.types.chat import (  # type: ignore
-        ChatCompletionAssistantMessageParam,
-        ChatCompletionMessageParam,
-        ChatCompletionSystemMessageParam,
-        ChatCompletionToolMessageParam,
-        ChatCompletionUserMessageParam,
-    )
-    from openai.types.chat.chat_completion_tool_param import (  # type: ignore
-        ChatCompletionToolParam,
-    )
-
-    from haiku.rag.client import HaikuRAG
-    from haiku.rag.qa.base import QuestionAnswerAgentBase
-
-    class QuestionAnswerOpenAIAgent(QuestionAnswerAgentBase):
-        def __init__(
-            self,
-            client: HaikuRAG,
-            model: str = "gpt-4o-mini",
-            use_citations: bool = False,
-        ):
-            super().__init__(client, model or self._model, use_citations)
-            self.tools: Sequence[ChatCompletionToolParam] = [
-                ChatCompletionToolParam(tool) for tool in self.tools
-            ]
-
-        async def answer(self, question: str) -> str:
-            openai_client = AsyncOpenAI()
-
-            messages: list[ChatCompletionMessageParam] = [
-                ChatCompletionSystemMessageParam(
-                    role="system", content=self._system_prompt
-                ),
-                ChatCompletionUserMessageParam(role="user", content=question),
-            ]
-
-            max_rounds = 5  # Prevent infinite loops
-
-            for _ in range(max_rounds):
-                response = await openai_client.chat.completions.create(
-                    model=self._model,
-                    messages=messages,
-                    tools=self.tools,
-                    temperature=0.0,
-                )
-
-                response_message = response.choices[0].message
-
-                if response_message.tool_calls:
-                    messages.append(
-                        ChatCompletionAssistantMessageParam(
-                            role="assistant",
-                            content=response_message.content,
-                            tool_calls=[
-                                {
-                                    "id": tc.id,
-                                    "type": "function",
-                                    "function": {
-                                        "name": tc.function.name,
-                                        "arguments": tc.function.arguments,
-                                    },
-                                }
-                                for tc in response_message.tool_calls
-                            ],
-                        )
-                    )
-
-                    for tool_call in response_message.tool_calls:
-                        if tool_call.function.name == "search_documents":
-                            import json
-
-                            args = json.loads(tool_call.function.arguments)
-                            query = args.get("query", question)
-                            limit = int(args.get("limit", 3))
-
-                            context = await self._search_and_expand(query, limit=limit)
-
-                            messages.append(
-                                ChatCompletionToolMessageParam(
-                                    role="tool",
-                                    content=context,
-                                    tool_call_id=tool_call.id,
-                                )
-                            )
-                else:
-                    # No tool calls, return the response
-                    return response_message.content or ""
-
-            # If we've exhausted max rounds, return empty string
-            return ""
-
-except ImportError:
-    pass
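The three provider-specific QA agents above were replaced by a single Pydantic AI agent (haiku/rag/qa/agent.py, added in 0.7.0 but not shown in this diff). The following is only a rough sketch of how a search tool can be wired to a pydantic-ai Agent, not the package's actual agent.py: the model string, deps type, tool body, and result attribute are all assumptions; only the pydantic-ai dependency and the client search call come from this diff.

# Rough sketch only -- NOT haiku/rag/qa/agent.py.
from pydantic_ai import Agent, RunContext

from haiku.rag.client import HaikuRAG

qa_agent = Agent(
    "openai:gpt-4o-mini",  # any provider/model supported by Pydantic AI
    deps_type=HaikuRAG,
    system_prompt="Answer questions using the search_documents tool.",
)


@qa_agent.tool
async def search_documents(
    ctx: RunContext[HaikuRAG], query: str, limit: int = 3
) -> str:
    """Search the knowledge base and return matching chunk contents."""
    results = await ctx.deps.search(query, limit=limit)
    return "\n\n".join(chunk.content for chunk, _score in results)


async def answer(client: HaikuRAG, question: str) -> str:
    result = await qa_agent.run(question, deps=client)
    return result.output  # assumption about the pydantic-ai result attribute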
haiku/rag/reranking/ollama.py
DELETED
@@ -1,84 +0,0 @@
-import json
-
-from ollama import AsyncClient
-from pydantic import BaseModel
-
-from haiku.rag.config import Config
-from haiku.rag.reranking.base import RerankerBase
-from haiku.rag.store.models.chunk import Chunk
-
-OLLAMA_OPTIONS = {"temperature": 0.0, "seed": 42, "num_ctx": 16384}
-
-
-class RerankResult(BaseModel):
-    """Individual rerank result with index and relevance score."""
-
-    index: int
-    relevance_score: float
-
-
-class RerankResponse(BaseModel):
-    """Response from the reranking model containing ranked results."""
-
-    results: list[RerankResult]
-
-
-class OllamaReranker(RerankerBase):
-    def __init__(self, model: str = Config.RERANK_MODEL):
-        self._model = model
-        self._client = AsyncClient(host=Config.OLLAMA_BASE_URL)
-
-    async def rerank(
-        self, query: str, chunks: list[Chunk], top_n: int = 10
-    ) -> list[tuple[Chunk, float]]:
-        if not chunks:
-            return []
-
-        documents = []
-        for i, chunk in enumerate(chunks):
-            documents.append({"index": i, "content": chunk.content})
-
-        # Create the prompt for reranking
-        system_prompt = """You are a document reranking assistant. Given a query and a list of document chunks, you must rank them by relevance to the query.
-
-Return your response as a JSON object with a "results" array. Each result should have:
-- "index": the original index of the document (integer)
-- "relevance_score": a score between 0.0 and 1.0 indicating relevance (float, where 1.0 is most relevant)
-
-Only return the top documents up to the requested limit, ordered by decreasing relevance score."""
-
-        documents_text = ""
-        for doc in documents:
-            documents_text += f"Index {doc['index']}: {doc['content']}\n\n"
-
-        user_prompt = f"""Query: {query}
-
-Documents to rerank:
-{documents_text.strip()}
-
-Please rank these documents by relevance to the query and return the top {top_n} results as JSON."""
-
-        messages = [
-            {"role": "system", "content": system_prompt},
-            {"role": "user", "content": user_prompt},
-        ]
-
-        try:
-            response = await self._client.chat(
-                model=self._model,
-                messages=messages,
-                format=RerankResponse.model_json_schema(),
-                options=OLLAMA_OPTIONS,
-            )
-
-            content = response["message"]["content"]
-
-            parsed_response = RerankResponse.model_validate(json.loads(content))
-            return [
-                (chunks[result.index], result.relevance_score)
-                for result in parsed_response.results[:top_n]
-            ]
-
-        except Exception:
-            # Fallback: return chunks in original order with same score
-            return [(chunks[i], 1.0) for i in range(min(top_n, len(chunks)))]
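The deleted reranker above constrained the model's output with `RerankResponse.model_json_schema()`. A minimal sketch of the JSON shape it asked for and how those Pydantic models parse it is shown below; the payload is a fabricated illustration, not real model output, and the models are copied here only to keep the snippet self-contained.

# Illustrates the response format the removed OllamaReranker expected.
import json

from pydantic import BaseModel


class RerankResult(BaseModel):
    index: int
    relevance_score: float


class RerankResponse(BaseModel):
    results: list[RerankResult]


# Fabricated example payload in the shape the system prompt requested.
payload = '{"results": [{"index": 2, "relevance_score": 0.91}, {"index": 0, "relevance_score": 0.40}]}'
parsed = RerankResponse.model_validate(json.loads(payload))
print([(r.index, r.relevance_score) for r in parsed.results])
# -> [(2, 0.91), (0, 0.4)]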