haiku.rag 0.5.5-py3-none-any.whl → 0.6.0-py3-none-any.whl

This diff shows the content changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.


haiku/rag/embeddings/__init__.py CHANGED
@@ -17,20 +17,14 @@ def get_embedder() -> EmbedderBase:
         except ImportError:
             raise ImportError(
                 "VoyageAI embedder requires the 'voyageai' package. "
-                "Please install haiku.rag with the 'voyageai' extra:"
+                "Please install haiku.rag with the 'voyageai' extra: "
                 "uv pip install haiku.rag[voyageai]"
             )
         return VoyageAIEmbedder(Config.EMBEDDINGS_MODEL, Config.EMBEDDINGS_VECTOR_DIM)
 
     if Config.EMBEDDINGS_PROVIDER == "openai":
-        try:
-            from haiku.rag.embeddings.openai import Embedder as OpenAIEmbedder
-        except ImportError:
-            raise ImportError(
-                "OpenAI embedder requires the 'openai' package. "
-                "Please install haiku.rag with the 'openai' extra:"
-                "uv pip install haiku.rag[openai]"
-            )
+        from haiku.rag.embeddings.openai import Embedder as OpenAIEmbedder
+
         return OpenAIEmbedder(Config.EMBEDDINGS_MODEL, Config.EMBEDDINGS_VECTOR_DIM)
 
     raise ValueError(f"Unsupported embedding provider: {Config.EMBEDDINGS_PROVIDER}")
haiku/rag/embeddings/openai.py CHANGED
@@ -1,16 +1,13 @@
-try:
-    from openai import AsyncOpenAI
+from openai import AsyncOpenAI
 
-    from haiku.rag.embeddings.base import EmbedderBase
+from haiku.rag.embeddings.base import EmbedderBase
 
-    class Embedder(EmbedderBase):
-        async def embed(self, text: str) -> list[float]:
-            client = AsyncOpenAI()
-            response = await client.embeddings.create(
-                model=self._model,
-                input=text,
-            )
-            return response.data[0].embedding
 
-except ImportError:
-    pass
+class Embedder(EmbedderBase):
+    async def embed(self, text: str) -> list[float]:
+        client = AsyncOpenAI()
+        response = await client.embeddings.create(
+            model=self._model,
+            input=text,
+        )
+        return response.data[0].embedding
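
For orientation, a minimal usage sketch of the 0.6.0 embedding path. This is not part of the diff; the model name and vector size are illustrative, and OPENAI_API_KEY is assumed to be set. Note the behavioral change above: a missing openai package now fails at import time rather than raising the guided ImportError, presumably acceptable because the new required pydantic-ai dependency already pulls in the OpenAI client.

import asyncio

from haiku.rag.embeddings import get_embedder


async def main() -> None:
    # Assumes EMBEDDINGS_PROVIDER=openai, EMBEDDINGS_MODEL=text-embedding-3-small,
    # and EMBEDDINGS_VECTOR_DIM=1536 are configured via the environment or .env.
    embedder = get_embedder()
    vector = await embedder.embed("haiku.rag stores embeddings in SQLite")
    print(len(vector))  # should equal EMBEDDINGS_VECTOR_DIM


asyncio.run(main())
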
haiku/rag/qa/__init__.py CHANGED
@@ -1,44 +1,15 @@
 from haiku.rag.client import HaikuRAG
 from haiku.rag.config import Config
-from haiku.rag.qa.base import QuestionAnswerAgentBase
-from haiku.rag.qa.ollama import QuestionAnswerOllamaAgent
+from haiku.rag.qa.agent import QuestionAnswerAgent
 
 
-def get_qa_agent(
-    client: HaikuRAG, model: str = "", use_citations: bool = False
-) -> QuestionAnswerAgentBase:
-    """
-    Factory function to get the appropriate QA agent based on the configuration.
-    """
-    if Config.QA_PROVIDER == "ollama":
-        return QuestionAnswerOllamaAgent(
-            client, model or Config.QA_MODEL, use_citations
-        )
+def get_qa_agent(client: HaikuRAG, use_citations: bool = False) -> QuestionAnswerAgent:
+    provider = Config.QA_PROVIDER
+    model_name = Config.QA_MODEL
 
-    if Config.QA_PROVIDER == "openai":
-        try:
-            from haiku.rag.qa.openai import QuestionAnswerOpenAIAgent
-        except ImportError:
-            raise ImportError(
-                "OpenAI QA agent requires the 'openai' package. "
-                "Please install haiku.rag with the 'openai' extra:"
-                "uv pip install haiku.rag[openai]"
-            )
-        return QuestionAnswerOpenAIAgent(
-            client, model or Config.QA_MODEL, use_citations
-        )
-
-    if Config.QA_PROVIDER == "anthropic":
-        try:
-            from haiku.rag.qa.anthropic import QuestionAnswerAnthropicAgent
-        except ImportError:
-            raise ImportError(
-                "Anthropic QA agent requires the 'anthropic' package. "
-                "Please install haiku.rag with the 'anthropic' extra:"
-                "uv pip install haiku.rag[anthropic]"
-            )
-        return QuestionAnswerAnthropicAgent(
-            client, model or Config.QA_MODEL, use_citations
-        )
-
-    raise ValueError(f"Unsupported QA provider: {Config.QA_PROVIDER}")
+    return QuestionAnswerAgent(
+        client=client,
+        provider=provider,
+        model=model_name,
+        use_citations=use_citations,
+    )
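
A hedged sketch of the new call site, not part of the diff: get_qa_agent no longer takes a model argument, so provider and model come from Config (QA_PROVIDER, QA_MODEL). The database path below is hypothetical.

import asyncio

from haiku.rag.client import HaikuRAG
from haiku.rag.qa import get_qa_agent


async def main() -> None:
    # "./docs.db" is a hypothetical path; HaikuRAG is used here as an async
    # context manager, as in the project README.
    async with HaikuRAG("./docs.db") as client:
        qa = get_qa_agent(client, use_citations=True)
        print(await qa.answer("What database does haiku.rag use?"))


asyncio.run(main())
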
haiku/rag/qa/agent.py ADDED
@@ -0,0 +1,76 @@
+from pydantic import BaseModel, Field
+from pydantic_ai import Agent, RunContext
+from pydantic_ai.models.openai import OpenAIModel
+from pydantic_ai.providers.ollama import OllamaProvider
+
+from haiku.rag.client import HaikuRAG
+from haiku.rag.config import Config
+from haiku.rag.qa.prompts import SYSTEM_PROMPT, SYSTEM_PROMPT_WITH_CITATIONS
+
+
+class SearchResult(BaseModel):
+    content: str = Field(description="The document text content")
+    score: float = Field(description="Relevance score (higher is more relevant)")
+    document_uri: str = Field(description="Source URI/path of the document")
+
+
+class Dependencies(BaseModel):
+    model_config = {"arbitrary_types_allowed": True}
+    client: HaikuRAG
+
+
+class QuestionAnswerAgent:
+    def __init__(
+        self,
+        client: HaikuRAG,
+        provider: str,
+        model: str,
+        use_citations: bool = False,
+        q: float = 0.0,
+    ):
+        self._client = client
+
+        system_prompt = SYSTEM_PROMPT_WITH_CITATIONS if use_citations else SYSTEM_PROMPT
+        model_obj = self._get_model(provider, model)
+
+        self._agent = Agent(
+            model=model_obj,
+            deps_type=Dependencies,
+            system_prompt=system_prompt,
+        )
+
+        @self._agent.tool
+        async def search_documents(
+            ctx: RunContext[Dependencies],
+            query: str,
+            limit: int = 3,
+        ) -> list[SearchResult]:
+            """Search the knowledge base for relevant documents."""
+            search_results = await ctx.deps.client.search(query, limit=limit)
+            expanded_results = await ctx.deps.client.expand_context(search_results)
+
+            return [
+                SearchResult(
+                    content=chunk.content,
+                    score=score,
+                    document_uri=chunk.document_uri or "",
+                )
+                for chunk, score in expanded_results
+            ]
+
+    def _get_model(self, provider: str, model: str):
+        """Get the appropriate model object for the provider."""
+        if provider == "ollama":
+            return OpenAIModel(
+                model_name=model,
+                provider=OllamaProvider(base_url=f"{Config.OLLAMA_BASE_URL}/v1"),
+            )
+        else:
+            # For all other providers, use the provider:model format
+            return f"{provider}:{model}"
+
+    async def answer(self, question: str) -> str:
+        """Answer a question using the RAG system."""
+        deps = Dependencies(client=self._client)
+        result = await self._agent.run(question, deps=deps)
+        return result.output
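
To illustrate the provider routing in _get_model above, a sketch under assumptions (the model names are examples, not package defaults): "ollama" is wrapped in an OpenAIModel pointed at Ollama's OpenAI-compatible /v1 endpoint, while every other provider collapses to a pydantic-ai "provider:model" string.

from haiku.rag.client import HaikuRAG
from haiku.rag.qa.agent import QuestionAnswerAgent


def build_agents(client: HaikuRAG) -> list[QuestionAnswerAgent]:
    return [
        # Routed through OpenAIModel + OllamaProvider at {OLLAMA_BASE_URL}/v1:
        QuestionAnswerAgent(client, provider="ollama", model="qwen3"),
        # Becomes the string "anthropic:claude-3-5-haiku-latest", resolved by pydantic-ai:
        QuestionAnswerAgent(client, provider="anthropic", model="claude-3-5-haiku-latest"),
    ]
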
haiku/rag/qa/prompts.py CHANGED
@@ -18,6 +18,7 @@ Guidelines:
 - Stick to the answer, do not ellaborate or provide context unless explicitly asked for it.
 
 Be concise, and always maintain accuracy over completeness. Prefer short, direct answers that are well-supported by the documents.
+/no_think
 """
 
 SYSTEM_PROMPT_WITH_CITATIONS = """
@@ -55,4 +56,5 @@ Citations:
 - /path/to/document2.pdf: "The manual provides guidance on military procedures and..."
 
 Be concise, and always maintain accuracy over completeness. Prefer short, direct answers that are well-supported by the documents.
+/no_think
 """
haiku/rag/reranking/ollama.py CHANGED
@@ -1,14 +1,12 @@
-import json
-
-from ollama import AsyncClient
 from pydantic import BaseModel
+from pydantic_ai import Agent
+from pydantic_ai.models.openai import OpenAIModel
+from pydantic_ai.providers.ollama import OllamaProvider
 
 from haiku.rag.config import Config
 from haiku.rag.reranking.base import RerankerBase
 from haiku.rag.store.models.chunk import Chunk
 
-OLLAMA_OPTIONS = {"temperature": 0.0, "seed": 42, "num_ctx": 16384}
-
 
 class RerankResult(BaseModel):
     """Individual rerank result with index and relevance score."""
@@ -26,7 +24,28 @@ class RerankResponse(BaseModel):
 class OllamaReranker(RerankerBase):
     def __init__(self, model: str = Config.RERANK_MODEL):
         self._model = model
-        self._client = AsyncClient(host=Config.OLLAMA_BASE_URL)
+
+        # Create the reranking prompt
+        system_prompt = """You are a document reranking assistant. Given a query and a list of document chunks, you must rank them by relevance to the query.
+
+Return your response as a JSON object with a "results" array. Each result should have:
+- "index": the original index of the document (integer)
+- "relevance_score": a score between 0.0 and 1.0 indicating relevance (float, where 1.0 is most relevant)
+
+Only return the top documents up to the requested limit, ordered by decreasing relevance score.
+/no_think
+"""
+
+        model_obj = OpenAIModel(
+            model_name=model,
+            provider=OllamaProvider(base_url=f"{Config.OLLAMA_BASE_URL}/v1"),
+        )
+
+        self._agent = Agent(
+            model=model_obj,
+            output_type=RerankResponse,
+            system_prompt=system_prompt,
+        )
 
     async def rerank(
         self, query: str, chunks: list[Chunk], top_n: int = 10
@@ -38,15 +57,6 @@ class OllamaReranker(RerankerBase):
         for i, chunk in enumerate(chunks):
             documents.append({"index": i, "content": chunk.content})
 
-        # Create the prompt for reranking
-        system_prompt = """You are a document reranking assistant. Given a query and a list of document chunks, you must rank them by relevance to the query.
-
-Return your response as a JSON object with a "results" array. Each result should have:
-- "index": the original index of the document (integer)
-- "relevance_score": a score between 0.0 and 1.0 indicating relevance (float, where 1.0 is most relevant)
-
-Only return the top documents up to the requested limit, ordered by decreasing relevance score."""
-
         documents_text = ""
         for doc in documents:
             documents_text += f"Index {doc['index']}: {doc['content']}\n\n"
@@ -56,27 +66,14 @@ Only return the top documents up to the requested limit, ordered by decreasing r
 Documents to rerank:
 {documents_text.strip()}
 
-Please rank these documents by relevance to the query and return the top {top_n} results as JSON."""
-
-        messages = [
-            {"role": "system", "content": system_prompt},
-            {"role": "user", "content": user_prompt},
-        ]
+Rank these documents by relevance to the query and return the top {top_n} results as JSON."""
 
         try:
-            response = await self._client.chat(
-                model=self._model,
-                messages=messages,
-                format=RerankResponse.model_json_schema(),
-                options=OLLAMA_OPTIONS,
-            )
-
-            content = response["message"]["content"]
+            result = await self._agent.run(user_prompt)
 
-            parsed_response = RerankResponse.model_validate(json.loads(content))
             return [
-                (chunks[result.index], result.relevance_score)
-                for result in parsed_response.results[:top_n]
+                (chunks[result_item.index], result_item.relevance_score)
+                for result_item in result.output.results[:top_n]
             ]
 
        except Exception:
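
For context, a minimal sketch of driving the rewritten reranker. Not part of the diff; the Chunk constructor fields are assumptions about the model, and an Ollama server must be reachable at OLLAMA_BASE_URL.

import asyncio

from haiku.rag.reranking.ollama import OllamaReranker
from haiku.rag.store.models.chunk import Chunk


async def main() -> None:
    # Hypothetical chunks; the real Chunk model may require more fields.
    chunks = [
        Chunk(document_id=1, content="Vectors are stored with sqlite-vec."),
        Chunk(document_id=2, content="An unrelated note about packaging."),
    ]
    reranker = OllamaReranker()  # defaults to Config.RERANK_MODEL
    ranked = await reranker.rerank("How are vectors stored?", chunks, top_n=1)
    for chunk, score in ranked:
        print(f"{score:.2f} {chunk.content}")


asyncio.run(main())
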
haiku_rag-0.5.5.dist-info/METADATA → haiku_rag-0.6.0.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: haiku.rag
-Version: 0.5.5
+Version: 0.6.0
 Summary: Retrieval Augmented Generation (RAG) with SQLite
 Author-email: Yiorgis Gozadinos <ggozadinos@gmail.com>
 License: MIT
@@ -22,6 +22,7 @@ Requires-Dist: docling>=2.15.0
 Requires-Dist: fastmcp>=2.8.1
 Requires-Dist: httpx>=0.28.1
 Requires-Dist: ollama>=0.5.3
+Requires-Dist: pydantic-ai>=0.7.2
 Requires-Dist: pydantic>=2.11.7
 Requires-Dist: python-dotenv>=1.1.0
 Requires-Dist: rich>=14.0.0
@@ -29,14 +30,8 @@ Requires-Dist: sqlite-vec>=0.1.6
 Requires-Dist: tiktoken>=0.9.0
 Requires-Dist: typer>=0.16.0
 Requires-Dist: watchfiles>=1.1.0
-Provides-Extra: anthropic
-Requires-Dist: anthropic>=0.56.0; extra == 'anthropic'
-Provides-Extra: cohere
-Requires-Dist: cohere>=5.16.1; extra == 'cohere'
 Provides-Extra: mxbai
 Requires-Dist: mxbai-rerank>=0.1.6; extra == 'mxbai'
-Provides-Extra: openai
-Requires-Dist: openai>=1.0.0; extra == 'openai'
 Provides-Extra: voyageai
 Requires-Dist: voyageai>=0.3.2; extra == 'voyageai'
 Description-Content-Type: text/markdown
@@ -51,7 +46,7 @@ Retrieval-Augmented Generation (RAG) library on SQLite.
 
 - **Local SQLite**: No external servers required
 - **Multiple embedding providers**: Ollama, VoyageAI, OpenAI
-- **Multiple QA providers**: Ollama, OpenAI, Anthropic
+- **Multiple QA providers**: Any provider/model supported by Pydantic AI
 - **Hybrid search**: Vector + full-text search with Reciprocal Rank Fusion
 - **Reranking**: Default search result reranking with MixedBread AI or Cohere
 - **Question answering**: Built-in QA agents on your documents
haiku_rag-0.5.5.dist-info/RECORD → haiku_rag-0.6.0.dist-info/RECORD RENAMED
@@ -9,22 +9,19 @@ haiku/rag/mcp.py,sha256=tMN6fNX7ZtAER1R6DL1GkC9HZozTC4HzuQs199p7icI,4551
 haiku/rag/monitor.py,sha256=r386nkhdlsU8UECwIuVwnrSlgMk3vNIuUZGNIzkZuec,2770
 haiku/rag/reader.py,sha256=qkPTMJuQ_o4sK-8zpDl9WFYe_MJ7aL_gUw6rczIpW-g,3274
 haiku/rag/utils.py,sha256=g-uNTG60iBLgkeHHuah6eVZEkX3NFLs-LZU1YnzJzLQ,2967
-haiku/rag/embeddings/__init__.py,sha256=yFBlxS0jBiVHl_rWz5kb43t6Ha132U1ZGdlIPfhzPdg,1491
+haiku/rag/embeddings/__init__.py,sha256=n7aHW3BxHlpGxU4ze4YYDOsljzFpEep8dwVE2n45JoE,1218
 haiku/rag/embeddings/base.py,sha256=NTQvuzbZPu0LBo5wAu3qGyJ4xXUaRAt1fjBO0ygWn_Y,465
 haiku/rag/embeddings/ollama.py,sha256=y6-lp0XpbnyIjoOEdtSzMdEVkU5glOwnWQ1FkpUZnpI,370
-haiku/rag/embeddings/openai.py,sha256=i4Ui5hAJkcKqJkH9L3jJo7fuGYHn07td532w-ksg_T8,431
+haiku/rag/embeddings/openai.py,sha256=iA-DewCOSip8PLU_RhEJHFHBle4DtmCCIGNfGs58Wvk,357
 haiku/rag/embeddings/voyageai.py,sha256=0hiRTIqu-bpl-4OaCtMHvWfPdgbrzhnfZJowSV8pLRA,415
-haiku/rag/qa/__init__.py,sha256=vC9S6cvZtPz-UfA_v4DMwI7eam6567BXNrUwHsMo_i8,1633
-haiku/rag/qa/anthropic.py,sha256=o0RVn7lcdYvoCUGXh551jeuoB3ANJSZ7uz2R_h_pZ2w,4321
-haiku/rag/qa/base.py,sha256=dCX14ifJW4QMCNFP_pmss9SYWM9Qm1cSWZrMl6A-2C8,3541
-haiku/rag/qa/ollama.py,sha256=3T9ciKWpCIY7jejvdrsMC_wIvGRWQEWA0AwKjOlX35M,2131
-haiku/rag/qa/openai.py,sha256=4BFc8pzFI-CTDxxKMskMxMKkacvUoRTVWI8kKntl3Jw,3718
-haiku/rag/qa/prompts.py,sha256=WTA66brySfzIkuDZ_hRQQKGx12ngIu9nUDKMNGg2-Bg,3321
+haiku/rag/qa/__init__.py,sha256=Sl7Kzrg9CuBOcMF01wc1NtQhUNWjJI0MhIHfCWrb8V4,434
+haiku/rag/qa/agent.py,sha256=r6tYKvOW4W1HxBRHH1kmzlzb1bIJcQSuHd6cG9ANqXY,2594
+haiku/rag/qa/prompts.py,sha256=xdT4cyrOrAK9UDgVqyev1wHF49jD57Bh40gx2sH4NPI,3341
 haiku/rag/reranking/__init__.py,sha256=fwC3pauteJwh9Ulm2270QvwAdwr4NMr4RUEuolC-wKU,1063
 haiku/rag/reranking/base.py,sha256=LM9yUSSJ414UgBZhFTgxGprlRqzfTe4I1vgjricz2JY,405
 haiku/rag/reranking/cohere.py,sha256=1iTdiaa8vvb6oHVB2qpWzUOVkyfUcimVSZp6Qr4aq4c,1049
 haiku/rag/reranking/mxbai.py,sha256=46sVTsTIkzIX9THgM3u8HaEmgY7evvEyB-N54JTHvK8,867
-haiku/rag/reranking/ollama.py,sha256=tCrLlNNDBCZu7J3to1gvBq-sOvN1flYEA7E3H3Jq0mU,2790
+haiku/rag/reranking/ollama.py,sha256=Q3dJepxFyB9CRCtrZvcwX-Drrpa2-8TMO7YGhxD1Qcs,2610
 haiku/rag/store/__init__.py,sha256=hq0W0DAC7ysqhWSP2M2uHX8cbG6kbr-sWHxhq6qQcY0,103
 haiku/rag/store/engine.py,sha256=cOMBToLilI1Di1qQrFzGLqtRMsuvtiX0Q5RNIEzQy9w,6232
 haiku/rag/store/models/__init__.py,sha256=s0E72zneGlowvZrFWaNxHYjOAUjgWdLxzdYsnvNRVlY,88
@@ -37,8 +34,8 @@ haiku/rag/store/repositories/document.py,sha256=ki8LiDukwU1469Yw51i0rQFvBzUQeYkF
 haiku/rag/store/repositories/settings.py,sha256=qZLXvLsErnCWL0nBQQNfRnatHzCKhtUDLvUK9k-W_fU,2463
 haiku/rag/store/upgrades/__init__.py,sha256=kKS1YWT_P-CYKhKtokOLTIFNKf9jlfjFFr8lyIMeogM,100
 haiku/rag/store/upgrades/v0_3_4.py,sha256=GLogKZdZ40NX1vBHKdOJju7fFzNUCHoEnjSZg17Hm2U,663
-haiku_rag-0.5.5.dist-info/METADATA,sha256=rponlCmspT548_0Z_YbYSp8Q2c1QQlCEXzRMx5sxPfs,4455
-haiku_rag-0.5.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-haiku_rag-0.5.5.dist-info/entry_points.txt,sha256=G1U3nAkNd5YDYd4v0tuYFbriz0i-JheCsFuT9kIoGCI,48
-haiku_rag-0.5.5.dist-info/licenses/LICENSE,sha256=eXZrWjSk9PwYFNK9yUczl3oPl95Z4V9UXH7bPN46iPo,1065
-haiku_rag-0.5.5.dist-info/RECORD,,
+haiku_rag-0.6.0.dist-info/METADATA,sha256=oLxNtf0pFMyLwc9sVsiztYbrpiyVNkg0wsX0TZdUYFw,4283
+haiku_rag-0.6.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+haiku_rag-0.6.0.dist-info/entry_points.txt,sha256=G1U3nAkNd5YDYd4v0tuYFbriz0i-JheCsFuT9kIoGCI,48
+haiku_rag-0.6.0.dist-info/licenses/LICENSE,sha256=eXZrWjSk9PwYFNK9yUczl3oPl95Z4V9UXH7bPN46iPo,1065
+haiku_rag-0.6.0.dist-info/RECORD,,
haiku/rag/qa/anthropic.py DELETED
@@ -1,108 +0,0 @@
-from collections.abc import Sequence
-
-try:
-    from anthropic import AsyncAnthropic  # type: ignore
-    from anthropic.types import (  # type: ignore
-        MessageParam,
-        TextBlock,
-        ToolParam,
-        ToolUseBlock,
-    )
-
-    from haiku.rag.client import HaikuRAG
-    from haiku.rag.qa.base import QuestionAnswerAgentBase
-
-    class QuestionAnswerAnthropicAgent(QuestionAnswerAgentBase):
-        def __init__(
-            self,
-            client: HaikuRAG,
-            model: str = "claude-3-5-haiku-20241022",
-            use_citations: bool = False,
-        ):
-            super().__init__(client, model or self._model, use_citations)
-            self.tools: Sequence[ToolParam] = [
-                ToolParam(
-                    name="search_documents",
-                    description="Search the knowledge base for relevant documents. Returns a JSON array with content, score, and document_uri for each result.",
-                    input_schema={
-                        "type": "object",
-                        "properties": {
-                            "query": {
-                                "type": "string",
-                                "description": "The search query to find relevant documents",
-                            },
-                            "limit": {
-                                "type": "integer",
-                                "description": "Maximum number of results to return",
-                                "default": 3,
-                            },
-                        },
-                        "required": ["query"],
-                    },
-                )
-            ]
-
-        async def answer(self, question: str) -> str:
-            anthropic_client = AsyncAnthropic()
-
-            messages: list[MessageParam] = [{"role": "user", "content": question}]
-
-            max_rounds = 5  # Prevent infinite loops
-
-            for _ in range(max_rounds):
-                response = await anthropic_client.messages.create(
-                    model=self._model,
-                    max_tokens=4096,
-                    system=self._system_prompt,
-                    messages=messages,
-                    tools=self.tools,
-                    temperature=0.0,
-                )
-
-                if response.stop_reason == "tool_use":
-                    messages.append({"role": "assistant", "content": response.content})
-
-                    # Process tool calls
-                    tool_results = []
-                    for content_block in response.content:
-                        if isinstance(content_block, ToolUseBlock):
-                            if content_block.name == "search_documents":
-                                args = content_block.input
-                                query = (
-                                    args.get("query", question)
-                                    if isinstance(args, dict)
-                                    else question
-                                )
-                                limit = (
-                                    int(args.get("limit", 3))
-                                    if isinstance(args, dict)
-                                    else 3
-                                )
-
-                                context = await self._search_and_expand(
-                                    query, limit=limit
-                                )
-
-                                tool_results.append(
-                                    {
-                                        "type": "tool_result",
-                                        "tool_use_id": content_block.id,
-                                        "content": context,
-                                    }
-                                )
-
-                    if tool_results:
-                        messages.append({"role": "user", "content": tool_results})
-                else:
-                    # No tool use, return the response
-                    if response.content:
-                        first_content = response.content[0]
-                        if isinstance(first_content, TextBlock):
-                            return first_content.text
-                    return ""
-
-            # If we've exhausted max rounds, return empty string
-            return ""
-
-except ImportError:
-    pass
haiku/rag/qa/base.py DELETED
@@ -1,89 +0,0 @@
-import json
-
-from haiku.rag.client import HaikuRAG
-from haiku.rag.qa.prompts import SYSTEM_PROMPT, SYSTEM_PROMPT_WITH_CITATIONS
-
-
-class QuestionAnswerAgentBase:
-    _model: str = ""
-    _system_prompt: str = SYSTEM_PROMPT
-
-    def __init__(self, client: HaikuRAG, model: str = "", use_citations: bool = False):
-        self._model = model
-        self._client = client
-        self._system_prompt = (
-            SYSTEM_PROMPT_WITH_CITATIONS if use_citations else SYSTEM_PROMPT
-        )
-
-    async def answer(self, question: str) -> str:
-        raise NotImplementedError(
-            "QABase is an abstract class. Please implement the answer method in a subclass."
-        )
-
-    async def _search_and_expand(self, query: str, limit: int = 3) -> str:
-        """Search for documents and expand context, then format as JSON"""
-        search_results = await self._client.search(query, limit=limit)
-        expanded_results = await self._client.expand_context(search_results)
-        return self._format_search_results(expanded_results)
-
-    def _format_search_results(self, search_results) -> str:
-        """Format search results as JSON list of {content, score, document_uri}"""
-        formatted_results = []
-        for chunk, score in search_results:
-            formatted_results.append(
-                {
-                    "content": chunk.content,
-                    "score": score,
-                    "document_uri": chunk.document_uri,
-                }
-            )
-        return json.dumps(formatted_results, indent=2)
-
-    tools = [
-        {
-            "type": "function",
-            "function": {
-                "name": "search_documents",
-                "description": "Search the knowledge base for relevant documents. Returns a JSON array of search results.",
-                "parameters": {
-                    "type": "object",
-                    "properties": {
-                        "query": {
-                            "type": "string",
-                            "description": "The search query to find relevant documents",
-                        },
-                        "limit": {
-                            "type": "integer",
-                            "description": "Maximum number of results to return",
-                            "default": 3,
-                        },
-                    },
-                    "required": ["query"],
-                },
-                "returns": {
-                    "type": "string",
-                    "description": "JSON array of search results",
-                    "schema": {
-                        "type": "array",
-                        "items": {
-                            "type": "object",
-                            "properties": {
-                                "content": {
-                                    "type": "string",
-                                    "description": "The document text content",
-                                },
-                                "score": {
-                                    "type": "number",
-                                    "description": "Relevance score (higher is more relevant)",
-                                },
-                                "document_uri": {
-                                    "type": "string",
-                                    "description": "Source URI/path of the document",
-                                },
-                            },
-                        },
-                    },
-                },
-            },
-        }
-    ]
haiku/rag/qa/ollama.py DELETED
@@ -1,60 +0,0 @@
-from ollama import AsyncClient
-
-from haiku.rag.client import HaikuRAG
-from haiku.rag.config import Config
-from haiku.rag.qa.base import QuestionAnswerAgentBase
-
-OLLAMA_OPTIONS = {"temperature": 0.0, "seed": 42, "num_ctx": 16384}
-
-
-class QuestionAnswerOllamaAgent(QuestionAnswerAgentBase):
-    def __init__(
-        self,
-        client: HaikuRAG,
-        model: str = Config.QA_MODEL,
-        use_citations: bool = False,
-    ):
-        super().__init__(client, model or self._model, use_citations)
-
-    async def answer(self, question: str) -> str:
-        ollama_client = AsyncClient(host=Config.OLLAMA_BASE_URL)
-
-        messages = [
-            {"role": "system", "content": self._system_prompt},
-            {"role": "user", "content": question},
-        ]
-
-        max_rounds = 5  # Prevent infinite loops
-
-        for _ in range(max_rounds):
-            response = await ollama_client.chat(
-                model=self._model,
-                messages=messages,
-                tools=self.tools,
-                options=OLLAMA_OPTIONS,
-                think=False,
-            )
-
-            if response.get("message", {}).get("tool_calls"):
-                messages.append(response["message"])
-
-                for tool_call in response["message"]["tool_calls"]:
-                    if tool_call["function"]["name"] == "search_documents":
-                        args = tool_call["function"]["arguments"]
-                        query = args.get("query", question)
-                        limit = int(args.get("limit", 3))
-
-                        context = await self._search_and_expand(query, limit=limit)
-                        messages.append(
-                            {
-                                "role": "tool",
-                                "content": context,
-                                "tool_call_id": tool_call.get("id", "search_tool"),
-                            }
-                        )
-            else:
-                # No tool calls, return the response
-                return response["message"]["content"]
-
-        # If we've exhausted max rounds, return empty string
-        return ""
haiku/rag/qa/openai.py DELETED
@@ -1,97 +0,0 @@
-from collections.abc import Sequence
-
-try:
-    from openai import AsyncOpenAI  # type: ignore
-    from openai.types.chat import (  # type: ignore
-        ChatCompletionAssistantMessageParam,
-        ChatCompletionMessageParam,
-        ChatCompletionSystemMessageParam,
-        ChatCompletionToolMessageParam,
-        ChatCompletionUserMessageParam,
-    )
-    from openai.types.chat.chat_completion_tool_param import (  # type: ignore
-        ChatCompletionToolParam,
-    )
-
-    from haiku.rag.client import HaikuRAG
-    from haiku.rag.qa.base import QuestionAnswerAgentBase
-
-    class QuestionAnswerOpenAIAgent(QuestionAnswerAgentBase):
-        def __init__(
-            self,
-            client: HaikuRAG,
-            model: str = "gpt-4o-mini",
-            use_citations: bool = False,
-        ):
-            super().__init__(client, model or self._model, use_citations)
-            self.tools: Sequence[ChatCompletionToolParam] = [
-                ChatCompletionToolParam(tool) for tool in self.tools
-            ]
-
-        async def answer(self, question: str) -> str:
-            openai_client = AsyncOpenAI()
-
-            messages: list[ChatCompletionMessageParam] = [
-                ChatCompletionSystemMessageParam(
-                    role="system", content=self._system_prompt
-                ),
-                ChatCompletionUserMessageParam(role="user", content=question),
-            ]
-
-            max_rounds = 5  # Prevent infinite loops
-
-            for _ in range(max_rounds):
-                response = await openai_client.chat.completions.create(
-                    model=self._model,
-                    messages=messages,
-                    tools=self.tools,
-                    temperature=0.0,
-                )
-
-                response_message = response.choices[0].message
-
-                if response_message.tool_calls:
-                    messages.append(
-                        ChatCompletionAssistantMessageParam(
-                            role="assistant",
-                            content=response_message.content,
-                            tool_calls=[
-                                {
-                                    "id": tc.id,
-                                    "type": "function",
-                                    "function": {
-                                        "name": tc.function.name,
-                                        "arguments": tc.function.arguments,
-                                    },
-                                }
-                                for tc in response_message.tool_calls
-                            ],
-                        )
-                    )
-
-                    for tool_call in response_message.tool_calls:
-                        if tool_call.function.name == "search_documents":
-                            import json
-
-                            args = json.loads(tool_call.function.arguments)
-                            query = args.get("query", question)
-                            limit = int(args.get("limit", 3))
-
-                            context = await self._search_and_expand(query, limit=limit)
-
-                            messages.append(
-                                ChatCompletionToolMessageParam(
-                                    role="tool",
-                                    content=context,
-                                    tool_call_id=tool_call.id,
-                                )
-                            )
-                else:
-                    # No tool calls, return the response
-                    return response_message.content or ""
-
-            # If we've exhausted max rounds, return empty string
-            return ""
-
-except ImportError:
-    pass