haiku_rag-0.5.5-py3-none-any.whl → haiku_rag-0.7.0-py3-none-any.whl



haiku/rag/utils.py CHANGED
@@ -1,4 +1,7 @@
+ import asyncio
  import sys
+ from collections.abc import Callable
+ from functools import wraps
  from importlib import metadata
  from io import BytesIO
  from pathlib import Path
@@ -10,6 +13,42 @@ from docling_core.types.io import DocumentStream
  from packaging.version import Version, parse


+ def debounce(wait: float) -> Callable:
+     """
+     A decorator to debounce a function, ensuring it is called only after a specified delay
+     and always executes after the last call.
+
+     Args:
+         wait (float): The debounce delay in seconds.
+
+     Returns:
+         Callable: The decorated function.
+     """
+
+     def decorator(func: Callable) -> Callable:
+         last_call = None
+         task = None
+
+         @wraps(func)
+         async def debounced(*args, **kwargs):
+             nonlocal last_call, task
+             last_call = asyncio.get_event_loop().time()
+
+             if task:
+                 task.cancel()
+
+             async def call_func():
+                 await asyncio.sleep(wait)
+                 if asyncio.get_event_loop().time() - last_call >= wait:  # type: ignore
+                     await func(*args, **kwargs)
+
+             task = asyncio.create_task(call_func())
+
+         return debounced
+
+     return decorator
+
+
  def get_default_data_dir() -> Path:
      """Get the user data directory for the current system platform.

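The new `debounce` decorator schedules the wrapped coroutine on the running event loop and cancels any still-pending invocation, so only the last call in a burst executes after `wait` seconds of quiet. A minimal usage sketch (the `save` coroutine and the timings are illustrative, not part of the package):

```python
import asyncio

from haiku.rag.utils import debounce


@debounce(0.5)
async def save() -> None:
    # Body runs only once calls have quieted down for 0.5 s.
    print("saved")


async def main() -> None:
    for _ in range(5):
        await save()  # returns immediately; cancels the previous pending run
        await asyncio.sleep(0.1)
    await asyncio.sleep(1.0)  # let the final debounced call fire


asyncio.run(main())  # prints "saved" exactly once
```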
@@ -32,37 +71,6 @@ def get_default_data_dir() -> Path:
      return data_path


- def semantic_version_to_int(version: str) -> int:
-     """Convert a semantic version string to an integer.
-
-     Args:
-         version: Semantic version string.
-
-     Returns:
-         Integer representation of semantic version.
-     """
-     major, minor, patch = version.split(".")
-     major = int(major) << 16
-     minor = int(minor) << 8
-     patch = int(patch)
-     return major + minor + patch
-
-
- def int_to_semantic_version(version: int) -> str:
-     """Convert an integer to a semantic version string.
-
-     Args:
-         version: Integer representation of semantic version.
-
-     Returns:
-         Semantic version string.
-     """
-     major = version >> 16
-     minor = (version >> 8) & 255
-     patch = version & 255
-     return f"{major}.{minor}.{patch}"
-
-
  async def is_up_to_date() -> tuple[bool, Version, Version]:
      """Check whether haiku.rag is current.

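The removed helpers packed a version into a single integer with 8-bit minor and patch fields, a layout that caps both at 255 and was presumably only needed by the old SQLite settings storage. A worked round trip:

```python
# "0.5.5" -> (0 << 16) + (5 << 8) + 5 == 1285
packed = (0 << 16) + (5 << 8) + 5
assert packed == 1285
# ...and back again: major, minor, patch recovered by shifting and masking.
assert (packed >> 16, (packed >> 8) & 255, packed & 255) == (0, 5, 5)
```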
haiku_rag-0.5.5.dist-info/METADATA → haiku_rag-0.7.0.dist-info/METADATA RENAMED
@@ -1,11 +1,11 @@
  Metadata-Version: 2.4
  Name: haiku.rag
- Version: 0.5.5
- Summary: Retrieval Augmented Generation (RAG) with SQLite
+ Version: 0.7.0
+ Summary: Retrieval Augmented Generation (RAG) with LanceDB
  Author-email: Yiorgis Gozadinos <ggozadinos@gmail.com>
  License: MIT
  License-File: LICENSE
- Keywords: RAG,mcp,ml,sqlite,sqlite-vec
+ Keywords: RAG,lancedb,mcp,ml,vector-database
  Classifier: Development Status :: 4 - Beta
  Classifier: Environment :: Console
  Classifier: Intended Audience :: Developers
@@ -17,42 +17,39 @@ Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
  Classifier: Typing :: Typed
- Requires-Python: >=3.11
- Requires-Dist: docling>=2.15.0
+ Requires-Python: >=3.12
+ Requires-Dist: docling>=2.49.0
  Requires-Dist: fastmcp>=2.8.1
  Requires-Dist: httpx>=0.28.1
+ Requires-Dist: lancedb>=0.24.3
  Requires-Dist: ollama>=0.5.3
+ Requires-Dist: pydantic-ai>=0.8.1
  Requires-Dist: pydantic>=2.11.7
  Requires-Dist: python-dotenv>=1.1.0
- Requires-Dist: rich>=14.0.0
- Requires-Dist: sqlite-vec>=0.1.6
- Requires-Dist: tiktoken>=0.9.0
- Requires-Dist: typer>=0.16.0
+ Requires-Dist: rich>=14.1.0
+ Requires-Dist: tiktoken>=0.11.0
+ Requires-Dist: typer>=0.16.1
  Requires-Dist: watchfiles>=1.1.0
- Provides-Extra: anthropic
- Requires-Dist: anthropic>=0.56.0; extra == 'anthropic'
- Provides-Extra: cohere
- Requires-Dist: cohere>=5.16.1; extra == 'cohere'
  Provides-Extra: mxbai
  Requires-Dist: mxbai-rerank>=0.1.6; extra == 'mxbai'
- Provides-Extra: openai
- Requires-Dist: openai>=1.0.0; extra == 'openai'
  Provides-Extra: voyageai
  Requires-Dist: voyageai>=0.3.2; extra == 'voyageai'
  Description-Content-Type: text/markdown

- # Haiku SQLite RAG
+ # Haiku RAG

- Retrieval-Augmented Generation (RAG) library on SQLite.
+ Retrieval-Augmented Generation (RAG) library built on LanceDB.

- `haiku.rag` is a Retrieval-Augmented Generation (RAG) library built to work on SQLite alone without the need for external vector databases. It uses [sqlite-vec](https://github.com/asg017/sqlite-vec) for storing the embeddings and performs semantic (vector) search as well as full-text search combined through Reciprocal Rank Fusion. Both open-source (Ollama) as well as commercial (OpenAI, VoyageAI) embedding providers are supported.
+ `haiku.rag` is a Retrieval-Augmented Generation (RAG) library built to work with LanceDB as a local vector database. It uses LanceDB for storing embeddings and performs semantic (vector) search as well as full-text search combined through native hybrid search with Reciprocal Rank Fusion. Both open-source (Ollama) as well as commercial (OpenAI, VoyageAI) embedding providers are supported.
+
+ > **Note**: Starting with version 0.7.0, haiku.rag uses LanceDB instead of SQLite. If you have an existing SQLite database, use `haiku-rag migrate old_database.sqlite` to migrate your data safely.

  ## Features

- - **Local SQLite**: No external servers required
+ - **Local LanceDB**: No external servers required; also supports LanceDB Cloud, S3, Google Cloud & Azure storage
  - **Multiple embedding providers**: Ollama, VoyageAI, OpenAI
- - **Multiple QA providers**: Ollama, OpenAI, Anthropic
- - **Hybrid search**: Vector + full-text search with Reciprocal Rank Fusion
+ - **Multiple QA providers**: Any provider/model supported by Pydantic AI
+ - **Native hybrid search**: Vector + full-text search with native LanceDB RRF reranking
  - **Reranking**: Default search result reranking with MixedBread AI or Cohere
  - **Question answering**: Built-in QA agents on your documents
  - **File monitoring**: Auto-index files when run as server
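Reciprocal Rank Fusion merges the vector and full-text rankings by giving each document 1/(k + rank) per list it appears in and summing. LanceDB now performs this natively; purely for illustration, a minimal sketch of the scoring with the conventional k = 60 (the constant here is an assumption, not LanceDB's documented default):

```python
from collections import defaultdict


def rrf(rankings: list[list[str]], k: int = 60) -> list[tuple[str, float]]:
    """Fuse ranked ID lists into a single ranking, best first."""
    scores: dict[str, float] = defaultdict(float)
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    return sorted(scores.items(), key=lambda item: -item[1])


# "b" places well in both the vector and full-text lists, so it wins.
print(rrf([["a", "b", "c"], ["b", "c", "d"]]))
```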
@@ -82,6 +79,9 @@ haiku-rag ask "Who is the author of haiku.rag?" --cite
  # Rebuild database (re-chunk and re-embed all documents)
  haiku-rag rebuild

+ # Migrate from SQLite to LanceDB
+ haiku-rag migrate old_database.sqlite
+
  # Start server with file monitoring
  export MONITOR_DIRECTORIES="/path/to/docs"
  haiku-rag serve
@@ -92,7 +92,7 @@ haiku-rag serve
  ```python
  from haiku.rag.client import HaikuRAG

- async with HaikuRAG("database.db") as client:
+ async with HaikuRAG("database.lancedb") as client:
      # Add document
      doc = await client.create_document("Your content")

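Besides `create_document`, the client exposes a `search` method returning `(chunk, score)` pairs; the signature is inferred from the deleted QA base class later in this diff (`self._client.search(query, limit=limit)`), so treat this as a sketch rather than documented API:

```python
import asyncio

from haiku.rag.client import HaikuRAG


async def main() -> None:
    async with HaikuRAG("database.lancedb") as client:
        await client.create_document("haiku.rag stores embeddings in LanceDB.")
        # Hybrid search returns (chunk, score) pairs.
        for chunk, score in await client.search("where are embeddings stored?", limit=3):
            print(f"{score:.3f}  {chunk.content[:60]}")


asyncio.run(main())
```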
haiku_rag-0.7.0.dist-info/RECORD ADDED
@@ -0,0 +1,39 @@
+ haiku/rag/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ haiku/rag/app.py,sha256=GmuZxH7BMutWt8Mdu0RSateRBaKiqXh7Z9tV7cZX6n0,7655
+ haiku/rag/chunker.py,sha256=PVe6ysv8UlacUd4Zb3_8RFWIaWDXnzBAy2VDJ4TaUsE,1555
+ haiku/rag/cli.py,sha256=UY9Vh5RsIxSCV14eQbNOiwToKmbFAvqTOAnxjieaYBs,6399
+ haiku/rag/client.py,sha256=N4zkWjE9Rsw9YgPvNo83xptHUQR2ognfOnjkoV_w6hc,20999
+ haiku/rag/config.py,sha256=9Mv0QJ3c6VF1oVRSXJlSsG234dCd_sKnJO-ybMaTpDQ,1690
+ haiku/rag/logging.py,sha256=9RjS931Dkp0nSdJaCSCydi33jq4gXmLmIidONvD9qD4,731
+ haiku/rag/mcp.py,sha256=bR9Y-Nz-hvjiql20Y0KE0hwNGwyjmPGX8K9d-qmXptY,4683
+ haiku/rag/migration.py,sha256=gWxQwiKo0YulRhogYz4K8N98kHN9LQXIx9FeTmT24v4,10915
+ haiku/rag/monitor.py,sha256=r386nkhdlsU8UECwIuVwnrSlgMk3vNIuUZGNIzkZuec,2770
+ haiku/rag/reader.py,sha256=qkPTMJuQ_o4sK-8zpDl9WFYe_MJ7aL_gUw6rczIpW-g,3274
+ haiku/rag/utils.py,sha256=c8F0ECsFSqvQxzxINAOAnvShoOnJPLsOaNE3JEY2JSc,3230
+ haiku/rag/embeddings/__init__.py,sha256=n7aHW3BxHlpGxU4ze4YYDOsljzFpEep8dwVE2n45JoE,1218
+ haiku/rag/embeddings/base.py,sha256=NTQvuzbZPu0LBo5wAu3qGyJ4xXUaRAt1fjBO0ygWn_Y,465
+ haiku/rag/embeddings/ollama.py,sha256=y6-lp0XpbnyIjoOEdtSzMdEVkU5glOwnWQ1FkpUZnpI,370
+ haiku/rag/embeddings/openai.py,sha256=iA-DewCOSip8PLU_RhEJHFHBle4DtmCCIGNfGs58Wvk,357
+ haiku/rag/embeddings/voyageai.py,sha256=0hiRTIqu-bpl-4OaCtMHvWfPdgbrzhnfZJowSV8pLRA,415
+ haiku/rag/qa/__init__.py,sha256=Sl7Kzrg9CuBOcMF01wc1NtQhUNWjJI0MhIHfCWrb8V4,434
+ haiku/rag/qa/agent.py,sha256=r6tYKvOW4W1HxBRHH1kmzlzb1bIJcQSuHd6cG9ANqXY,2594
+ haiku/rag/qa/prompts.py,sha256=xdT4cyrOrAK9UDgVqyev1wHF49jD57Bh40gx2sH4NPI,3341
+ haiku/rag/reranking/__init__.py,sha256=IRXHs4qPu6VbGJQpzSwhgtVWWumURH_vEoVFE-extlo,894
+ haiku/rag/reranking/base.py,sha256=LM9yUSSJ414UgBZhFTgxGprlRqzfTe4I1vgjricz2JY,405
+ haiku/rag/reranking/cohere.py,sha256=1iTdiaa8vvb6oHVB2qpWzUOVkyfUcimVSZp6Qr4aq4c,1049
+ haiku/rag/reranking/mxbai.py,sha256=46sVTsTIkzIX9THgM3u8HaEmgY7evvEyB-N54JTHvK8,867
+ haiku/rag/store/__init__.py,sha256=hq0W0DAC7ysqhWSP2M2uHX8cbG6kbr-sWHxhq6qQcY0,103
+ haiku/rag/store/engine.py,sha256=XHGo5Xl-dCFdQHrOdMo64xVK5n0k8-LoUl5V-tlA0HI,7131
+ haiku/rag/store/models/__init__.py,sha256=s0E72zneGlowvZrFWaNxHYjOAUjgWdLxzdYsnvNRVlY,88
+ haiku/rag/store/models/chunk.py,sha256=ZNyTfO6lh3rXWLVYO3TZcitbL4LSUGr42fR6jQQ5iQc,364
+ haiku/rag/store/models/document.py,sha256=zSSpt6pyrMJAIXGQvIcqojcqUzwZnhp3WxVokaWxNRc,396
+ haiku/rag/store/repositories/__init__.py,sha256=Olv5dLfBQINRV3HrsfUpjzkZ7Qm7goEYyMNykgo_DaY,291
+ haiku/rag/store/repositories/chunk.py,sha256=5S77mGh6pWxPHjaXriJGmvbSOhoNM8tLwygE2GXPlbU,13586
+ haiku/rag/store/repositories/document.py,sha256=lP8Lo82KTP-qwXFRpYZ46WjeAdAsHwZ5pJcrXdz4g0U,6988
+ haiku/rag/store/repositories/settings.py,sha256=dqnAvm-98nQrWpLBbf9QghJw673QD80-iqQhRMP5t0c,5025
+ haiku/rag/store/upgrades/__init__.py,sha256=wUiEoSiHTahvuagx93E4FB07v123AhdbOjwUkPusiIg,14
+ haiku_rag-0.7.0.dist-info/METADATA,sha256=FGx5ufhc35VUXqXLdbhmUEa-3hi7aaBuDeyZYf8xaaQ,4597
+ haiku_rag-0.7.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ haiku_rag-0.7.0.dist-info/entry_points.txt,sha256=G1U3nAkNd5YDYd4v0tuYFbriz0i-JheCsFuT9kIoGCI,48
+ haiku_rag-0.7.0.dist-info/licenses/LICENSE,sha256=eXZrWjSk9PwYFNK9yUczl3oPl95Z4V9UXH7bPN46iPo,1065
+ haiku_rag-0.7.0.dist-info/RECORD,,
haiku/rag/qa/anthropic.py DELETED
@@ -1,108 +0,0 @@
- from collections.abc import Sequence
-
- try:
-     from anthropic import AsyncAnthropic  # type: ignore
-     from anthropic.types import (  # type: ignore
-         MessageParam,
-         TextBlock,
-         ToolParam,
-         ToolUseBlock,
-     )
-
-     from haiku.rag.client import HaikuRAG
-     from haiku.rag.qa.base import QuestionAnswerAgentBase
-
-     class QuestionAnswerAnthropicAgent(QuestionAnswerAgentBase):
-         def __init__(
-             self,
-             client: HaikuRAG,
-             model: str = "claude-3-5-haiku-20241022",
-             use_citations: bool = False,
-         ):
-             super().__init__(client, model or self._model, use_citations)
-             self.tools: Sequence[ToolParam] = [
-                 ToolParam(
-                     name="search_documents",
-                     description="Search the knowledge base for relevant documents. Returns a JSON array with content, score, and document_uri for each result.",
-                     input_schema={
-                         "type": "object",
-                         "properties": {
-                             "query": {
-                                 "type": "string",
-                                 "description": "The search query to find relevant documents",
-                             },
-                             "limit": {
-                                 "type": "integer",
-                                 "description": "Maximum number of results to return",
-                                 "default": 3,
-                             },
-                         },
-                         "required": ["query"],
-                     },
-                 )
-             ]
-
-         async def answer(self, question: str) -> str:
-             anthropic_client = AsyncAnthropic()
-
-             messages: list[MessageParam] = [{"role": "user", "content": question}]
-
-             max_rounds = 5  # Prevent infinite loops
-
-             for _ in range(max_rounds):
-                 response = await anthropic_client.messages.create(
-                     model=self._model,
-                     max_tokens=4096,
-                     system=self._system_prompt,
-                     messages=messages,
-                     tools=self.tools,
-                     temperature=0.0,
-                 )
-
-                 if response.stop_reason == "tool_use":
-                     messages.append({"role": "assistant", "content": response.content})
-
-                     # Process tool calls
-                     tool_results = []
-                     for content_block in response.content:
-                         if isinstance(content_block, ToolUseBlock):
-                             if content_block.name == "search_documents":
-                                 args = content_block.input
-                                 query = (
-                                     args.get("query", question)
-                                     if isinstance(args, dict)
-                                     else question
-                                 )
-                                 limit = (
-                                     int(args.get("limit", 3))
-                                     if isinstance(args, dict)
-                                     else 3
-                                 )
-
-                                 context = await self._search_and_expand(
-                                     query, limit=limit
-                                 )
-
-                                 tool_results.append(
-                                     {
-                                         "type": "tool_result",
-                                         "tool_use_id": content_block.id,
-                                         "content": context,
-                                     }
-                                 )
-
-                     if tool_results:
-                         messages.append({"role": "user", "content": tool_results})
-                 else:
-                     # No tool use, return the response
-                     if response.content:
-                         first_content = response.content[0]
-                         if isinstance(first_content, TextBlock):
-                             return first_content.text
-                     return ""
-
-             # If we've exhausted max rounds, return empty string
-             return ""
-
- except ImportError:
-     pass
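All three provider-specific QA agents below were deleted in favour of a single agent built on the new `pydantic-ai` dependency (the replacement lives in `haiku/rag/qa/agent.py`, which this diff does not show). Purely as a hypothetical sketch of that pattern — the model string, prompt, and tool body are assumptions, not the package's actual code:

```python
from pydantic_ai import Agent, RunContext

from haiku.rag.client import HaikuRAG

# Hypothetical sketch; haiku/rag/qa/agent.py's real wiring is not in this diff.
agent = Agent(
    "openai:gpt-4o-mini",
    deps_type=HaikuRAG,
    system_prompt="Answer questions using the search_documents tool.",
)


@agent.tool
async def search_documents(
    ctx: RunContext[HaikuRAG], query: str, limit: int = 3
) -> str:
    """Search the knowledge base and return the matching chunk contents."""
    results = await ctx.deps.search(query, limit=limit)
    return "\n\n".join(chunk.content for chunk, _score in results)
```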
haiku/rag/qa/base.py DELETED
@@ -1,89 +0,0 @@
- import json
-
- from haiku.rag.client import HaikuRAG
- from haiku.rag.qa.prompts import SYSTEM_PROMPT, SYSTEM_PROMPT_WITH_CITATIONS
-
-
- class QuestionAnswerAgentBase:
-     _model: str = ""
-     _system_prompt: str = SYSTEM_PROMPT
-
-     def __init__(self, client: HaikuRAG, model: str = "", use_citations: bool = False):
-         self._model = model
-         self._client = client
-         self._system_prompt = (
-             SYSTEM_PROMPT_WITH_CITATIONS if use_citations else SYSTEM_PROMPT
-         )
-
-     async def answer(self, question: str) -> str:
-         raise NotImplementedError(
-             "QABase is an abstract class. Please implement the answer method in a subclass."
-         )
-
-     async def _search_and_expand(self, query: str, limit: int = 3) -> str:
-         """Search for documents and expand context, then format as JSON"""
-         search_results = await self._client.search(query, limit=limit)
-         expanded_results = await self._client.expand_context(search_results)
-         return self._format_search_results(expanded_results)
-
-     def _format_search_results(self, search_results) -> str:
-         """Format search results as JSON list of {content, score, document_uri}"""
-         formatted_results = []
-         for chunk, score in search_results:
-             formatted_results.append(
-                 {
-                     "content": chunk.content,
-                     "score": score,
-                     "document_uri": chunk.document_uri,
-                 }
-             )
-         return json.dumps(formatted_results, indent=2)
-
-     tools = [
-         {
-             "type": "function",
-             "function": {
-                 "name": "search_documents",
-                 "description": "Search the knowledge base for relevant documents. Returns a JSON array of search results.",
-                 "parameters": {
-                     "type": "object",
-                     "properties": {
-                         "query": {
-                             "type": "string",
-                             "description": "The search query to find relevant documents",
-                         },
-                         "limit": {
-                             "type": "integer",
-                             "description": "Maximum number of results to return",
-                             "default": 3,
-                         },
-                     },
-                     "required": ["query"],
-                 },
-                 "returns": {
-                     "type": "string",
-                     "description": "JSON array of search results",
-                     "schema": {
-                         "type": "array",
-                         "items": {
-                             "type": "object",
-                             "properties": {
-                                 "content": {
-                                     "type": "string",
-                                     "description": "The document text content",
-                                 },
-                                 "score": {
-                                     "type": "number",
-                                     "description": "Relevance score (higher is more relevant)",
-                                 },
-                                 "document_uri": {
-                                     "type": "string",
-                                     "description": "Source URI/path of the document",
-                                 },
-                             },
-                         },
-                     },
-                 },
-             },
-         }
-     ]
haiku/rag/qa/ollama.py DELETED
@@ -1,60 +0,0 @@
- from ollama import AsyncClient
-
- from haiku.rag.client import HaikuRAG
- from haiku.rag.config import Config
- from haiku.rag.qa.base import QuestionAnswerAgentBase
-
- OLLAMA_OPTIONS = {"temperature": 0.0, "seed": 42, "num_ctx": 16384}
-
-
- class QuestionAnswerOllamaAgent(QuestionAnswerAgentBase):
-     def __init__(
-         self,
-         client: HaikuRAG,
-         model: str = Config.QA_MODEL,
-         use_citations: bool = False,
-     ):
-         super().__init__(client, model or self._model, use_citations)
-
-     async def answer(self, question: str) -> str:
-         ollama_client = AsyncClient(host=Config.OLLAMA_BASE_URL)
-
-         messages = [
-             {"role": "system", "content": self._system_prompt},
-             {"role": "user", "content": question},
-         ]
-
-         max_rounds = 5  # Prevent infinite loops
-
-         for _ in range(max_rounds):
-             response = await ollama_client.chat(
-                 model=self._model,
-                 messages=messages,
-                 tools=self.tools,
-                 options=OLLAMA_OPTIONS,
-                 think=False,
-             )
-
-             if response.get("message", {}).get("tool_calls"):
-                 messages.append(response["message"])
-
-                 for tool_call in response["message"]["tool_calls"]:
-                     if tool_call["function"]["name"] == "search_documents":
-                         args = tool_call["function"]["arguments"]
-                         query = args.get("query", question)
-                         limit = int(args.get("limit", 3))
-
-                         context = await self._search_and_expand(query, limit=limit)
-                         messages.append(
-                             {
-                                 "role": "tool",
-                                 "content": context,
-                                 "tool_call_id": tool_call.get("id", "search_tool"),
-                             }
-                         )
-             else:
-                 # No tool calls, return the response
-                 return response["message"]["content"]
-
-         # If we've exhausted max rounds, return empty string
-         return ""
haiku/rag/qa/openai.py DELETED
@@ -1,97 +0,0 @@
- from collections.abc import Sequence
-
- try:
-     from openai import AsyncOpenAI  # type: ignore
-     from openai.types.chat import (  # type: ignore
-         ChatCompletionAssistantMessageParam,
-         ChatCompletionMessageParam,
-         ChatCompletionSystemMessageParam,
-         ChatCompletionToolMessageParam,
-         ChatCompletionUserMessageParam,
-     )
-     from openai.types.chat.chat_completion_tool_param import (  # type: ignore
-         ChatCompletionToolParam,
-     )
-
-     from haiku.rag.client import HaikuRAG
-     from haiku.rag.qa.base import QuestionAnswerAgentBase
-
-     class QuestionAnswerOpenAIAgent(QuestionAnswerAgentBase):
-         def __init__(
-             self,
-             client: HaikuRAG,
-             model: str = "gpt-4o-mini",
-             use_citations: bool = False,
-         ):
-             super().__init__(client, model or self._model, use_citations)
-             self.tools: Sequence[ChatCompletionToolParam] = [
-                 ChatCompletionToolParam(tool) for tool in self.tools
-             ]
-
-         async def answer(self, question: str) -> str:
-             openai_client = AsyncOpenAI()
-
-             messages: list[ChatCompletionMessageParam] = [
-                 ChatCompletionSystemMessageParam(
-                     role="system", content=self._system_prompt
-                 ),
-                 ChatCompletionUserMessageParam(role="user", content=question),
-             ]
-
-             max_rounds = 5  # Prevent infinite loops
-
-             for _ in range(max_rounds):
-                 response = await openai_client.chat.completions.create(
-                     model=self._model,
-                     messages=messages,
-                     tools=self.tools,
-                     temperature=0.0,
-                 )
-
-                 response_message = response.choices[0].message
-
-                 if response_message.tool_calls:
-                     messages.append(
-                         ChatCompletionAssistantMessageParam(
-                             role="assistant",
-                             content=response_message.content,
-                             tool_calls=[
-                                 {
-                                     "id": tc.id,
-                                     "type": "function",
-                                     "function": {
-                                         "name": tc.function.name,
-                                         "arguments": tc.function.arguments,
-                                     },
-                                 }
-                                 for tc in response_message.tool_calls
-                             ],
-                         )
-                     )
-
-                     for tool_call in response_message.tool_calls:
-                         if tool_call.function.name == "search_documents":
-                             import json
-
-                             args = json.loads(tool_call.function.arguments)
-                             query = args.get("query", question)
-                             limit = int(args.get("limit", 3))
-
-                             context = await self._search_and_expand(query, limit=limit)
-
-                             messages.append(
-                                 ChatCompletionToolMessageParam(
-                                     role="tool",
-                                     content=context,
-                                     tool_call_id=tool_call.id,
-                                 )
-                             )
-                 else:
-                     # No tool calls, return the response
-                     return response_message.content or ""
-
-             # If we've exhausted max rounds, return empty string
-             return ""
-
- except ImportError:
-     pass
haiku/rag/reranking/ollama.py DELETED
@@ -1,84 +0,0 @@
- import json
-
- from ollama import AsyncClient
- from pydantic import BaseModel
-
- from haiku.rag.config import Config
- from haiku.rag.reranking.base import RerankerBase
- from haiku.rag.store.models.chunk import Chunk
-
- OLLAMA_OPTIONS = {"temperature": 0.0, "seed": 42, "num_ctx": 16384}
-
-
- class RerankResult(BaseModel):
-     """Individual rerank result with index and relevance score."""
-
-     index: int
-     relevance_score: float
-
-
- class RerankResponse(BaseModel):
-     """Response from the reranking model containing ranked results."""
-
-     results: list[RerankResult]
-
-
- class OllamaReranker(RerankerBase):
-     def __init__(self, model: str = Config.RERANK_MODEL):
-         self._model = model
-         self._client = AsyncClient(host=Config.OLLAMA_BASE_URL)
-
-     async def rerank(
-         self, query: str, chunks: list[Chunk], top_n: int = 10
-     ) -> list[tuple[Chunk, float]]:
-         if not chunks:
-             return []
-
-         documents = []
-         for i, chunk in enumerate(chunks):
-             documents.append({"index": i, "content": chunk.content})
-
-         # Create the prompt for reranking
-         system_prompt = """You are a document reranking assistant. Given a query and a list of document chunks, you must rank them by relevance to the query.
-
- Return your response as a JSON object with a "results" array. Each result should have:
- - "index": the original index of the document (integer)
- - "relevance_score": a score between 0.0 and 1.0 indicating relevance (float, where 1.0 is most relevant)
-
- Only return the top documents up to the requested limit, ordered by decreasing relevance score."""
-
-         documents_text = ""
-         for doc in documents:
-             documents_text += f"Index {doc['index']}: {doc['content']}\n\n"
-
-         user_prompt = f"""Query: {query}
-
- Documents to rerank:
- {documents_text.strip()}
-
- Please rank these documents by relevance to the query and return the top {top_n} results as JSON."""
-
-         messages = [
-             {"role": "system", "content": system_prompt},
-             {"role": "user", "content": user_prompt},
-         ]
-
-         try:
-             response = await self._client.chat(
-                 model=self._model,
-                 messages=messages,
-                 format=RerankResponse.model_json_schema(),
-                 options=OLLAMA_OPTIONS,
-             )
-
-             content = response["message"]["content"]
-
-             parsed_response = RerankResponse.model_validate(json.loads(content))
-             return [
-                 (chunks[result.index], result.relevance_score)
-                 for result in parsed_response.results[:top_n]
-             ]
-
-         except Exception:
-             # Fallback: return chunks in original order with same score
-             return [(chunks[i], 1.0) for i in range(min(top_n, len(chunks)))]
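For reference, the deleted reranker implemented the `RerankerBase.rerank` interface: score chunks with a structured-output LLM call, return the top `top_n` as `(Chunk, float)` pairs, and fall back to the original order (score 1.0) on any failure — a deliberate fail-open choice so search still returns results when the model misbehaves. A sketch of calling that interface (the reranker and chunks here are placeholders supplied by the caller):

```python
from haiku.rag.store.models.chunk import Chunk


async def show_reranked(reranker, chunks: list[Chunk]) -> None:
    # rerank() returns the top_n chunks paired with relevance scores in [0, 1].
    for chunk, relevance in await reranker.rerank(
        "what is haiku.rag?", chunks, top_n=5
    ):
        print(f"{relevance:.2f}  {chunk.content[:60]}")
```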