haiku.rag 0.5.5__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- haiku/rag/embeddings/__init__.py +3 -9
- haiku/rag/embeddings/openai.py +10 -13
- haiku/rag/qa/__init__.py +10 -39
- haiku/rag/qa/agent.py +76 -0
- haiku/rag/qa/prompts.py +2 -0
- haiku/rag/reranking/ollama.py +29 -32
- {haiku_rag-0.5.5.dist-info → haiku_rag-0.6.0.dist-info}/METADATA +3 -8
- {haiku_rag-0.5.5.dist-info → haiku_rag-0.6.0.dist-info}/RECORD +11 -14
- haiku/rag/qa/anthropic.py +0 -108
- haiku/rag/qa/base.py +0 -89
- haiku/rag/qa/ollama.py +0 -60
- haiku/rag/qa/openai.py +0 -97
- {haiku_rag-0.5.5.dist-info → haiku_rag-0.6.0.dist-info}/WHEEL +0 -0
- {haiku_rag-0.5.5.dist-info → haiku_rag-0.6.0.dist-info}/entry_points.txt +0 -0
- {haiku_rag-0.5.5.dist-info → haiku_rag-0.6.0.dist-info}/licenses/LICENSE +0 -0
haiku/rag/embeddings/__init__.py
CHANGED
@@ -17,20 +17,14 @@ def get_embedder() -> EmbedderBase:
         except ImportError:
             raise ImportError(
                 "VoyageAI embedder requires the 'voyageai' package. "
-                "Please install haiku.rag with the 'voyageai' extra:"
+                "Please install haiku.rag with the 'voyageai' extra: "
                 "uv pip install haiku.rag[voyageai]"
             )
         return VoyageAIEmbedder(Config.EMBEDDINGS_MODEL, Config.EMBEDDINGS_VECTOR_DIM)
 
     if Config.EMBEDDINGS_PROVIDER == "openai":
-
-
-        except ImportError:
-            raise ImportError(
-                "OpenAI embedder requires the 'openai' package. "
-                "Please install haiku.rag with the 'openai' extra:"
-                "uv pip install haiku.rag[openai]"
-            )
+        from haiku.rag.embeddings.openai import Embedder as OpenAIEmbedder
+
         return OpenAIEmbedder(Config.EMBEDDINGS_MODEL, Config.EMBEDDINGS_VECTOR_DIM)
 
     raise ValueError(f"Unsupported embedding provider: {Config.EMBEDDINGS_PROVIDER}")
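The OpenAI embedder is now imported unconditionally, presumably because the openai client is pulled in via the new pydantic-ai requirement (see the METADATA changes below), so the factory path is uniform across providers. A minimal usage sketch, not taken from the package docs, assuming EMBEDDINGS_PROVIDER and EMBEDDINGS_MODEL are configured in the environment that Config reads:

import asyncio

from haiku.rag.embeddings import get_embedder

async def main() -> None:
    # get_embedder() picks the backend from Config.EMBEDDINGS_PROVIDER and
    # raises ValueError for providers it does not know about.
    embedder = get_embedder()
    vector = await embedder.embed("haiku.rag indexes documents into SQLite")
    print(len(vector))

asyncio.run(main())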
haiku/rag/embeddings/openai.py
CHANGED
@@ -1,16 +1,13 @@
-
-from openai import AsyncOpenAI
+from openai import AsyncOpenAI
 
-
+from haiku.rag.embeddings.base import EmbedderBase
 
-class Embedder(EmbedderBase):
-    async def embed(self, text: str) -> list[float]:
-        client = AsyncOpenAI()
-        response = await client.embeddings.create(
-            model=self._model,
-            input=text,
-        )
-        return response.data[0].embedding
 
-
-
+class Embedder(EmbedderBase):
+    async def embed(self, text: str) -> list[float]:
+        client = AsyncOpenAI()
+        response = await client.embeddings.create(
+            model=self._model,
+            input=text,
+        )
+        return response.data[0].embedding
haiku/rag/qa/__init__.py
CHANGED
@@ -1,44 +1,15 @@
 from haiku.rag.client import HaikuRAG
 from haiku.rag.config import Config
-from haiku.rag.qa.
-from haiku.rag.qa.ollama import QuestionAnswerOllamaAgent
+from haiku.rag.qa.agent import QuestionAnswerAgent
 
 
-def get_qa_agent(
-
-
-    """
-    Factory function to get the appropriate QA agent based on the configuration.
-    """
-    if Config.QA_PROVIDER == "ollama":
-        return QuestionAnswerOllamaAgent(
-            client, model or Config.QA_MODEL, use_citations
-        )
+def get_qa_agent(client: HaikuRAG, use_citations: bool = False) -> QuestionAnswerAgent:
+    provider = Config.QA_PROVIDER
+    model_name = Config.QA_MODEL
 
-
-
-
-
-
-
-                "Please install haiku.rag with the 'openai' extra:"
-                "uv pip install haiku.rag[openai]"
-            )
-        return QuestionAnswerOpenAIAgent(
-            client, model or Config.QA_MODEL, use_citations
-        )
-
-    if Config.QA_PROVIDER == "anthropic":
-        try:
-            from haiku.rag.qa.anthropic import QuestionAnswerAnthropicAgent
-        except ImportError:
-            raise ImportError(
-                "Anthropic QA agent requires the 'anthropic' package. "
-                "Please install haiku.rag with the 'anthropic' extra:"
-                "uv pip install haiku.rag[anthropic]"
-            )
-        return QuestionAnswerAnthropicAgent(
-            client, model or Config.QA_MODEL, use_citations
-        )
-
-    raise ValueError(f"Unsupported QA provider: {Config.QA_PROVIDER}")
+    return QuestionAnswerAgent(
+        client=client,
+        provider=provider,
+        model=model_name,
+        use_citations=use_citations,
+    )
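The factory no longer takes a model argument or dispatches per provider; both now come from Config. A hedged sketch of the new call shape (how the HaikuRAG client itself is opened is outside this diff and assumed here):

from haiku.rag.client import HaikuRAG
from haiku.rag.qa import get_qa_agent

async def ask(client: HaikuRAG, question: str) -> str:
    # Provider and model are read from Config.QA_PROVIDER / Config.QA_MODEL;
    # only the citation behaviour is still chosen by the caller.
    qa_agent = get_qa_agent(client, use_citations=True)
    return await qa_agent.answer(question)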
haiku/rag/qa/agent.py
ADDED
@@ -0,0 +1,76 @@
+from pydantic import BaseModel, Field
+from pydantic_ai import Agent, RunContext
+from pydantic_ai.models.openai import OpenAIModel
+from pydantic_ai.providers.ollama import OllamaProvider
+
+from haiku.rag.client import HaikuRAG
+from haiku.rag.config import Config
+from haiku.rag.qa.prompts import SYSTEM_PROMPT, SYSTEM_PROMPT_WITH_CITATIONS
+
+
+class SearchResult(BaseModel):
+    content: str = Field(description="The document text content")
+    score: float = Field(description="Relevance score (higher is more relevant)")
+    document_uri: str = Field(description="Source URI/path of the document")
+
+
+class Dependencies(BaseModel):
+    model_config = {"arbitrary_types_allowed": True}
+    client: HaikuRAG
+
+
+class QuestionAnswerAgent:
+    def __init__(
+        self,
+        client: HaikuRAG,
+        provider: str,
+        model: str,
+        use_citations: bool = False,
+        q: float = 0.0,
+    ):
+        self._client = client
+
+        system_prompt = SYSTEM_PROMPT_WITH_CITATIONS if use_citations else SYSTEM_PROMPT
+        model_obj = self._get_model(provider, model)
+
+        self._agent = Agent(
+            model=model_obj,
+            deps_type=Dependencies,
+            system_prompt=system_prompt,
+        )
+
+        @self._agent.tool
+        async def search_documents(
+            ctx: RunContext[Dependencies],
+            query: str,
+            limit: int = 3,
+        ) -> list[SearchResult]:
+            """Search the knowledge base for relevant documents."""
+            search_results = await ctx.deps.client.search(query, limit=limit)
+            expanded_results = await ctx.deps.client.expand_context(search_results)
+
+            return [
+                SearchResult(
+                    content=chunk.content,
+                    score=score,
+                    document_uri=chunk.document_uri or "",
+                )
+                for chunk, score in expanded_results
+            ]
+
+    def _get_model(self, provider: str, model: str):
+        """Get the appropriate model object for the provider."""
+        if provider == "ollama":
+            return OpenAIModel(
+                model_name=model,
+                provider=OllamaProvider(base_url=f"{Config.OLLAMA_BASE_URL}/v1"),
+            )
+        else:
+            # For all other providers, use the provider:model format
+            return f"{provider}:{model}"
+
+    async def answer(self, question: str) -> str:
+        """Answer a question using the RAG system."""
+        deps = Dependencies(client=self._client)
+        result = await self._agent.run(question, deps=deps)
+        return result.output
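For non-Ollama providers, _get_model returns a plain "provider:model" string that pydantic-ai resolves itself, so this one class covers what the three deleted provider-specific agents did. A sketch of direct construction, assuming the matching pydantic-ai provider dependencies and API credentials are available; the model name is reused from the deleted anthropic agent's default:

from haiku.rag.client import HaikuRAG
from haiku.rag.qa.agent import QuestionAnswerAgent

async def answer_with_claude(client: HaikuRAG, question: str) -> str:
    # Internally resolves to the pydantic-ai model string
    # "anthropic:claude-3-5-haiku-20241022".
    agent = QuestionAnswerAgent(
        client,
        provider="anthropic",
        model="claude-3-5-haiku-20241022",
        use_citations=False,
    )
    return await agent.answer(question)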
haiku/rag/qa/prompts.py
CHANGED
@@ -18,6 +18,7 @@ Guidelines:
 - Stick to the answer, do not ellaborate or provide context unless explicitly asked for it.
 
 Be concise, and always maintain accuracy over completeness. Prefer short, direct answers that are well-supported by the documents.
+/no_think
 """
 
 SYSTEM_PROMPT_WITH_CITATIONS = """
@@ -55,4 +56,5 @@ Citations:
 - /path/to/document2.pdf: "The manual provides guidance on military procedures and..."
 
 Be concise, and always maintain accuracy over completeness. Prefer short, direct answers that are well-supported by the documents.
+/no_think
 """
haiku/rag/reranking/ollama.py
CHANGED
@@ -1,14 +1,12 @@
-import json
-
-from ollama import AsyncClient
 from pydantic import BaseModel
+from pydantic_ai import Agent
+from pydantic_ai.models.openai import OpenAIModel
+from pydantic_ai.providers.ollama import OllamaProvider
 
 from haiku.rag.config import Config
 from haiku.rag.reranking.base import RerankerBase
 from haiku.rag.store.models.chunk import Chunk
 
-OLLAMA_OPTIONS = {"temperature": 0.0, "seed": 42, "num_ctx": 16384}
-
 
 class RerankResult(BaseModel):
     """Individual rerank result with index and relevance score."""
@@ -26,7 +24,28 @@ class RerankResponse(BaseModel):
 class OllamaReranker(RerankerBase):
     def __init__(self, model: str = Config.RERANK_MODEL):
         self._model = model
-
+
+        # Create the reranking prompt
+        system_prompt = """You are a document reranking assistant. Given a query and a list of document chunks, you must rank them by relevance to the query.
+
+Return your response as a JSON object with a "results" array. Each result should have:
+- "index": the original index of the document (integer)
+- "relevance_score": a score between 0.0 and 1.0 indicating relevance (float, where 1.0 is most relevant)
+
+Only return the top documents up to the requested limit, ordered by decreasing relevance score.
+/no_think
+"""
+
+        model_obj = OpenAIModel(
+            model_name=model,
+            provider=OllamaProvider(base_url=f"{Config.OLLAMA_BASE_URL}/v1"),
+        )
+
+        self._agent = Agent(
+            model=model_obj,
+            output_type=RerankResponse,
+            system_prompt=system_prompt,
+        )
 
     async def rerank(
         self, query: str, chunks: list[Chunk], top_n: int = 10
@@ -38,15 +57,6 @@ class OllamaReranker(RerankerBase):
         for i, chunk in enumerate(chunks):
             documents.append({"index": i, "content": chunk.content})
 
-        # Create the prompt for reranking
-        system_prompt = """You are a document reranking assistant. Given a query and a list of document chunks, you must rank them by relevance to the query.
-
-Return your response as a JSON object with a "results" array. Each result should have:
-- "index": the original index of the document (integer)
-- "relevance_score": a score between 0.0 and 1.0 indicating relevance (float, where 1.0 is most relevant)
-
-Only return the top documents up to the requested limit, ordered by decreasing relevance score."""
-
         documents_text = ""
         for doc in documents:
             documents_text += f"Index {doc['index']}: {doc['content']}\n\n"
@@ -56,27 +66,14 @@ Only return the top documents up to the requested limit, ordered by decreasing r
 Documents to rerank:
 {documents_text.strip()}
 
-
-
-        messages = [
-            {"role": "system", "content": system_prompt},
-            {"role": "user", "content": user_prompt},
-        ]
+Rank these documents by relevance to the query and return the top {top_n} results as JSON."""
 
         try:
-
-                model=self._model,
-                messages=messages,
-                format=RerankResponse.model_json_schema(),
-                options=OLLAMA_OPTIONS,
-            )
-
-            content = response["message"]["content"]
+            result = await self._agent.run(user_prompt)
 
-            parsed_response = RerankResponse.model_validate(json.loads(content))
             return [
-                (chunks[
-                for
+                (chunks[result_item.index], result_item.relevance_score)
+                for result_item in result.output.results[:top_n]
             ]
 
         except Exception:
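The reranker's public surface is unchanged: rerank() still takes a query plus Chunk objects and returns (chunk, score) pairs, but the hand-rolled ollama chat call with a JSON schema and manual parsing is replaced by a pydantic-ai Agent with structured RerankResponse output. An illustrative sketch only; where the chunks come from (e.g. a prior search) is assumed:

from haiku.rag.reranking.ollama import OllamaReranker
from haiku.rag.store.models.chunk import Chunk

async def top_five(query: str, chunks: list[Chunk]) -> None:
    reranker = OllamaReranker()  # defaults to Config.RERANK_MODEL
    ranked = await reranker.rerank(query, chunks, top_n=5)
    for chunk, score in ranked:
        print(f"{score:.2f}  {chunk.content[:80]}")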
{haiku_rag-0.5.5.dist-info → haiku_rag-0.6.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: haiku.rag
-Version: 0.5.5
+Version: 0.6.0
 Summary: Retrieval Augmented Generation (RAG) with SQLite
 Author-email: Yiorgis Gozadinos <ggozadinos@gmail.com>
 License: MIT
@@ -22,6 +22,7 @@ Requires-Dist: docling>=2.15.0
 Requires-Dist: fastmcp>=2.8.1
 Requires-Dist: httpx>=0.28.1
 Requires-Dist: ollama>=0.5.3
+Requires-Dist: pydantic-ai>=0.7.2
 Requires-Dist: pydantic>=2.11.7
 Requires-Dist: python-dotenv>=1.1.0
 Requires-Dist: rich>=14.0.0
@@ -29,14 +30,8 @@ Requires-Dist: sqlite-vec>=0.1.6
 Requires-Dist: tiktoken>=0.9.0
 Requires-Dist: typer>=0.16.0
 Requires-Dist: watchfiles>=1.1.0
-Provides-Extra: anthropic
-Requires-Dist: anthropic>=0.56.0; extra == 'anthropic'
-Provides-Extra: cohere
-Requires-Dist: cohere>=5.16.1; extra == 'cohere'
 Provides-Extra: mxbai
 Requires-Dist: mxbai-rerank>=0.1.6; extra == 'mxbai'
-Provides-Extra: openai
-Requires-Dist: openai>=1.0.0; extra == 'openai'
 Provides-Extra: voyageai
 Requires-Dist: voyageai>=0.3.2; extra == 'voyageai'
 Description-Content-Type: text/markdown
@@ -51,7 +46,7 @@ Retrieval-Augmented Generation (RAG) library on SQLite.
 
 - **Local SQLite**: No external servers required
 - **Multiple embedding providers**: Ollama, VoyageAI, OpenAI
-- **Multiple QA providers**:
+- **Multiple QA providers**: Any provider/model supported by Pydantic AI
 - **Hybrid search**: Vector + full-text search with Reciprocal Rank Fusion
 - **Reranking**: Default search result reranking with MixedBread AI or Cohere
 - **Question answering**: Built-in QA agents on your documents
{haiku_rag-0.5.5.dist-info → haiku_rag-0.6.0.dist-info}/RECORD
CHANGED
@@ -9,22 +9,19 @@ haiku/rag/mcp.py,sha256=tMN6fNX7ZtAER1R6DL1GkC9HZozTC4HzuQs199p7icI,4551
 haiku/rag/monitor.py,sha256=r386nkhdlsU8UECwIuVwnrSlgMk3vNIuUZGNIzkZuec,2770
 haiku/rag/reader.py,sha256=qkPTMJuQ_o4sK-8zpDl9WFYe_MJ7aL_gUw6rczIpW-g,3274
 haiku/rag/utils.py,sha256=g-uNTG60iBLgkeHHuah6eVZEkX3NFLs-LZU1YnzJzLQ,2967
-haiku/rag/embeddings/__init__.py,sha256=
+haiku/rag/embeddings/__init__.py,sha256=n7aHW3BxHlpGxU4ze4YYDOsljzFpEep8dwVE2n45JoE,1218
 haiku/rag/embeddings/base.py,sha256=NTQvuzbZPu0LBo5wAu3qGyJ4xXUaRAt1fjBO0ygWn_Y,465
 haiku/rag/embeddings/ollama.py,sha256=y6-lp0XpbnyIjoOEdtSzMdEVkU5glOwnWQ1FkpUZnpI,370
-haiku/rag/embeddings/openai.py,sha256=
+haiku/rag/embeddings/openai.py,sha256=iA-DewCOSip8PLU_RhEJHFHBle4DtmCCIGNfGs58Wvk,357
 haiku/rag/embeddings/voyageai.py,sha256=0hiRTIqu-bpl-4OaCtMHvWfPdgbrzhnfZJowSV8pLRA,415
-haiku/rag/qa/__init__.py,sha256=
-haiku/rag/qa/
-haiku/rag/qa/
-haiku/rag/qa/ollama.py,sha256=3T9ciKWpCIY7jejvdrsMC_wIvGRWQEWA0AwKjOlX35M,2131
-haiku/rag/qa/openai.py,sha256=4BFc8pzFI-CTDxxKMskMxMKkacvUoRTVWI8kKntl3Jw,3718
-haiku/rag/qa/prompts.py,sha256=WTA66brySfzIkuDZ_hRQQKGx12ngIu9nUDKMNGg2-Bg,3321
+haiku/rag/qa/__init__.py,sha256=Sl7Kzrg9CuBOcMF01wc1NtQhUNWjJI0MhIHfCWrb8V4,434
+haiku/rag/qa/agent.py,sha256=r6tYKvOW4W1HxBRHH1kmzlzb1bIJcQSuHd6cG9ANqXY,2594
+haiku/rag/qa/prompts.py,sha256=xdT4cyrOrAK9UDgVqyev1wHF49jD57Bh40gx2sH4NPI,3341
 haiku/rag/reranking/__init__.py,sha256=fwC3pauteJwh9Ulm2270QvwAdwr4NMr4RUEuolC-wKU,1063
 haiku/rag/reranking/base.py,sha256=LM9yUSSJ414UgBZhFTgxGprlRqzfTe4I1vgjricz2JY,405
 haiku/rag/reranking/cohere.py,sha256=1iTdiaa8vvb6oHVB2qpWzUOVkyfUcimVSZp6Qr4aq4c,1049
 haiku/rag/reranking/mxbai.py,sha256=46sVTsTIkzIX9THgM3u8HaEmgY7evvEyB-N54JTHvK8,867
-haiku/rag/reranking/ollama.py,sha256=
+haiku/rag/reranking/ollama.py,sha256=Q3dJepxFyB9CRCtrZvcwX-Drrpa2-8TMO7YGhxD1Qcs,2610
 haiku/rag/store/__init__.py,sha256=hq0W0DAC7ysqhWSP2M2uHX8cbG6kbr-sWHxhq6qQcY0,103
 haiku/rag/store/engine.py,sha256=cOMBToLilI1Di1qQrFzGLqtRMsuvtiX0Q5RNIEzQy9w,6232
 haiku/rag/store/models/__init__.py,sha256=s0E72zneGlowvZrFWaNxHYjOAUjgWdLxzdYsnvNRVlY,88
@@ -37,8 +34,8 @@ haiku/rag/store/repositories/document.py,sha256=ki8LiDukwU1469Yw51i0rQFvBzUQeYkF
 haiku/rag/store/repositories/settings.py,sha256=qZLXvLsErnCWL0nBQQNfRnatHzCKhtUDLvUK9k-W_fU,2463
 haiku/rag/store/upgrades/__init__.py,sha256=kKS1YWT_P-CYKhKtokOLTIFNKf9jlfjFFr8lyIMeogM,100
 haiku/rag/store/upgrades/v0_3_4.py,sha256=GLogKZdZ40NX1vBHKdOJju7fFzNUCHoEnjSZg17Hm2U,663
-haiku_rag-0.
-haiku_rag-0.
-haiku_rag-0.
-haiku_rag-0.
-haiku_rag-0.
+haiku_rag-0.6.0.dist-info/METADATA,sha256=oLxNtf0pFMyLwc9sVsiztYbrpiyVNkg0wsX0TZdUYFw,4283
+haiku_rag-0.6.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+haiku_rag-0.6.0.dist-info/entry_points.txt,sha256=G1U3nAkNd5YDYd4v0tuYFbriz0i-JheCsFuT9kIoGCI,48
+haiku_rag-0.6.0.dist-info/licenses/LICENSE,sha256=eXZrWjSk9PwYFNK9yUczl3oPl95Z4V9UXH7bPN46iPo,1065
+haiku_rag-0.6.0.dist-info/RECORD,,
haiku/rag/qa/anthropic.py
DELETED
@@ -1,108 +0,0 @@
-from collections.abc import Sequence
-
-try:
-    from anthropic import AsyncAnthropic  # type: ignore
-    from anthropic.types import (  # type: ignore
-        MessageParam,
-        TextBlock,
-        ToolParam,
-        ToolUseBlock,
-    )
-
-    from haiku.rag.client import HaikuRAG
-    from haiku.rag.qa.base import QuestionAnswerAgentBase
-
-    class QuestionAnswerAnthropicAgent(QuestionAnswerAgentBase):
-        def __init__(
-            self,
-            client: HaikuRAG,
-            model: str = "claude-3-5-haiku-20241022",
-            use_citations: bool = False,
-        ):
-            super().__init__(client, model or self._model, use_citations)
-            self.tools: Sequence[ToolParam] = [
-                ToolParam(
-                    name="search_documents",
-                    description="Search the knowledge base for relevant documents. Returns a JSON array with content, score, and document_uri for each result.",
-                    input_schema={
-                        "type": "object",
-                        "properties": {
-                            "query": {
-                                "type": "string",
-                                "description": "The search query to find relevant documents",
-                            },
-                            "limit": {
-                                "type": "integer",
-                                "description": "Maximum number of results to return",
-                                "default": 3,
-                            },
-                        },
-                        "required": ["query"],
-                    },
-                )
-            ]
-
-        async def answer(self, question: str) -> str:
-            anthropic_client = AsyncAnthropic()
-
-            messages: list[MessageParam] = [{"role": "user", "content": question}]
-
-            max_rounds = 5  # Prevent infinite loops
-
-            for _ in range(max_rounds):
-                response = await anthropic_client.messages.create(
-                    model=self._model,
-                    max_tokens=4096,
-                    system=self._system_prompt,
-                    messages=messages,
-                    tools=self.tools,
-                    temperature=0.0,
-                )
-
-                if response.stop_reason == "tool_use":
-                    messages.append({"role": "assistant", "content": response.content})
-
-                    # Process tool calls
-                    tool_results = []
-                    for content_block in response.content:
-                        if isinstance(content_block, ToolUseBlock):
-                            if content_block.name == "search_documents":
-                                args = content_block.input
-                                query = (
-                                    args.get("query", question)
-                                    if isinstance(args, dict)
-                                    else question
-                                )
-                                limit = (
-                                    int(args.get("limit", 3))
-                                    if isinstance(args, dict)
-                                    else 3
-                                )
-
-                                context = await self._search_and_expand(
-                                    query, limit=limit
-                                )
-
-                                tool_results.append(
-                                    {
-                                        "type": "tool_result",
-                                        "tool_use_id": content_block.id,
-                                        "content": context,
-                                    }
-                                )
-
-                    if tool_results:
-                        messages.append({"role": "user", "content": tool_results})
-                else:
-                    # No tool use, return the response
-                    if response.content:
-                        first_content = response.content[0]
-                        if isinstance(first_content, TextBlock):
-                            return first_content.text
-                    return ""
-
-            # If we've exhausted max rounds, return empty string
-            return ""
-
-except ImportError:
-    pass
haiku/rag/qa/base.py
DELETED
@@ -1,89 +0,0 @@
-import json
-
-from haiku.rag.client import HaikuRAG
-from haiku.rag.qa.prompts import SYSTEM_PROMPT, SYSTEM_PROMPT_WITH_CITATIONS
-
-
-class QuestionAnswerAgentBase:
-    _model: str = ""
-    _system_prompt: str = SYSTEM_PROMPT
-
-    def __init__(self, client: HaikuRAG, model: str = "", use_citations: bool = False):
-        self._model = model
-        self._client = client
-        self._system_prompt = (
-            SYSTEM_PROMPT_WITH_CITATIONS if use_citations else SYSTEM_PROMPT
-        )
-
-    async def answer(self, question: str) -> str:
-        raise NotImplementedError(
-            "QABase is an abstract class. Please implement the answer method in a subclass."
-        )
-
-    async def _search_and_expand(self, query: str, limit: int = 3) -> str:
-        """Search for documents and expand context, then format as JSON"""
-        search_results = await self._client.search(query, limit=limit)
-        expanded_results = await self._client.expand_context(search_results)
-        return self._format_search_results(expanded_results)
-
-    def _format_search_results(self, search_results) -> str:
-        """Format search results as JSON list of {content, score, document_uri}"""
-        formatted_results = []
-        for chunk, score in search_results:
-            formatted_results.append(
-                {
-                    "content": chunk.content,
-                    "score": score,
-                    "document_uri": chunk.document_uri,
-                }
-            )
-        return json.dumps(formatted_results, indent=2)
-
-    tools = [
-        {
-            "type": "function",
-            "function": {
-                "name": "search_documents",
-                "description": "Search the knowledge base for relevant documents. Returns a JSON array of search results.",
-                "parameters": {
-                    "type": "object",
-                    "properties": {
-                        "query": {
-                            "type": "string",
-                            "description": "The search query to find relevant documents",
-                        },
-                        "limit": {
-                            "type": "integer",
-                            "description": "Maximum number of results to return",
-                            "default": 3,
-                        },
-                    },
-                    "required": ["query"],
-                },
-                "returns": {
-                    "type": "string",
-                    "description": "JSON array of search results",
-                    "schema": {
-                        "type": "array",
-                        "items": {
-                            "type": "object",
-                            "properties": {
-                                "content": {
-                                    "type": "string",
-                                    "description": "The document text content",
-                                },
-                                "score": {
-                                    "type": "number",
-                                    "description": "Relevance score (higher is more relevant)",
-                                },
-                                "document_uri": {
-                                    "type": "string",
-                                    "description": "Source URI/path of the document",
-                                },
-                            },
-                        },
-                    },
-                },
-            },
-        }
-    ]
haiku/rag/qa/ollama.py
DELETED
@@ -1,60 +0,0 @@
-from ollama import AsyncClient
-
-from haiku.rag.client import HaikuRAG
-from haiku.rag.config import Config
-from haiku.rag.qa.base import QuestionAnswerAgentBase
-
-OLLAMA_OPTIONS = {"temperature": 0.0, "seed": 42, "num_ctx": 16384}
-
-
-class QuestionAnswerOllamaAgent(QuestionAnswerAgentBase):
-    def __init__(
-        self,
-        client: HaikuRAG,
-        model: str = Config.QA_MODEL,
-        use_citations: bool = False,
-    ):
-        super().__init__(client, model or self._model, use_citations)
-
-    async def answer(self, question: str) -> str:
-        ollama_client = AsyncClient(host=Config.OLLAMA_BASE_URL)
-
-        messages = [
-            {"role": "system", "content": self._system_prompt},
-            {"role": "user", "content": question},
-        ]
-
-        max_rounds = 5  # Prevent infinite loops
-
-        for _ in range(max_rounds):
-            response = await ollama_client.chat(
-                model=self._model,
-                messages=messages,
-                tools=self.tools,
-                options=OLLAMA_OPTIONS,
-                think=False,
-            )
-
-            if response.get("message", {}).get("tool_calls"):
-                messages.append(response["message"])
-
-                for tool_call in response["message"]["tool_calls"]:
-                    if tool_call["function"]["name"] == "search_documents":
-                        args = tool_call["function"]["arguments"]
-                        query = args.get("query", question)
-                        limit = int(args.get("limit", 3))
-
-                        context = await self._search_and_expand(query, limit=limit)
-                        messages.append(
-                            {
-                                "role": "tool",
-                                "content": context,
-                                "tool_call_id": tool_call.get("id", "search_tool"),
-                            }
-                        )
-            else:
-                # No tool calls, return the response
-                return response["message"]["content"]
-
-        # If we've exhausted max rounds, return empty string
-        return ""
haiku/rag/qa/openai.py
DELETED
@@ -1,97 +0,0 @@
-from collections.abc import Sequence
-
-try:
-    from openai import AsyncOpenAI  # type: ignore
-    from openai.types.chat import (  # type: ignore
-        ChatCompletionAssistantMessageParam,
-        ChatCompletionMessageParam,
-        ChatCompletionSystemMessageParam,
-        ChatCompletionToolMessageParam,
-        ChatCompletionUserMessageParam,
-    )
-    from openai.types.chat.chat_completion_tool_param import (  # type: ignore
-        ChatCompletionToolParam,
-    )
-
-    from haiku.rag.client import HaikuRAG
-    from haiku.rag.qa.base import QuestionAnswerAgentBase
-
-    class QuestionAnswerOpenAIAgent(QuestionAnswerAgentBase):
-        def __init__(
-            self,
-            client: HaikuRAG,
-            model: str = "gpt-4o-mini",
-            use_citations: bool = False,
-        ):
-            super().__init__(client, model or self._model, use_citations)
-            self.tools: Sequence[ChatCompletionToolParam] = [
-                ChatCompletionToolParam(tool) for tool in self.tools
-            ]
-
-        async def answer(self, question: str) -> str:
-            openai_client = AsyncOpenAI()
-
-            messages: list[ChatCompletionMessageParam] = [
-                ChatCompletionSystemMessageParam(
-                    role="system", content=self._system_prompt
-                ),
-                ChatCompletionUserMessageParam(role="user", content=question),
-            ]
-
-            max_rounds = 5  # Prevent infinite loops
-
-            for _ in range(max_rounds):
-                response = await openai_client.chat.completions.create(
-                    model=self._model,
-                    messages=messages,
-                    tools=self.tools,
-                    temperature=0.0,
-                )
-
-                response_message = response.choices[0].message
-
-                if response_message.tool_calls:
-                    messages.append(
-                        ChatCompletionAssistantMessageParam(
-                            role="assistant",
-                            content=response_message.content,
-                            tool_calls=[
-                                {
-                                    "id": tc.id,
-                                    "type": "function",
-                                    "function": {
-                                        "name": tc.function.name,
-                                        "arguments": tc.function.arguments,
-                                    },
-                                }
-                                for tc in response_message.tool_calls
-                            ],
-                        )
-                    )
-
-                    for tool_call in response_message.tool_calls:
-                        if tool_call.function.name == "search_documents":
-                            import json
-
-                            args = json.loads(tool_call.function.arguments)
-                            query = args.get("query", question)
-                            limit = int(args.get("limit", 3))
-
-                            context = await self._search_and_expand(query, limit=limit)
-
-                            messages.append(
-                                ChatCompletionToolMessageParam(
-                                    role="tool",
-                                    content=context,
-                                    tool_call_id=tool_call.id,
-                                )
-                            )
-                else:
-                    # No tool calls, return the response
-                    return response_message.content or ""
-
-            # If we've exhausted max rounds, return empty string
-            return ""
-
-except ImportError:
-    pass
{haiku_rag-0.5.5.dist-info → haiku_rag-0.6.0.dist-info}/WHEEL
File without changes

{haiku_rag-0.5.5.dist-info → haiku_rag-0.6.0.dist-info}/entry_points.txt
File without changes

{haiku_rag-0.5.5.dist-info → haiku_rag-0.6.0.dist-info}/licenses/LICENSE
File without changes