haiku.rag-slim 0.16.0__py3-none-any.whl → 0.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of haiku.rag-slim might be problematic.
- haiku/rag/app.py +430 -72
- haiku/rag/chunkers/__init__.py +31 -0
- haiku/rag/chunkers/base.py +31 -0
- haiku/rag/chunkers/docling_local.py +164 -0
- haiku/rag/chunkers/docling_serve.py +179 -0
- haiku/rag/cli.py +207 -24
- haiku/rag/cli_chat.py +489 -0
- haiku/rag/client.py +1251 -266
- haiku/rag/config/__init__.py +16 -10
- haiku/rag/config/loader.py +5 -44
- haiku/rag/config/models.py +126 -17
- haiku/rag/converters/__init__.py +31 -0
- haiku/rag/converters/base.py +63 -0
- haiku/rag/converters/docling_local.py +193 -0
- haiku/rag/converters/docling_serve.py +229 -0
- haiku/rag/converters/text_utils.py +237 -0
- haiku/rag/embeddings/__init__.py +123 -24
- haiku/rag/embeddings/voyageai.py +175 -20
- haiku/rag/graph/__init__.py +0 -11
- haiku/rag/graph/agui/__init__.py +8 -2
- haiku/rag/graph/agui/cli_renderer.py +1 -1
- haiku/rag/graph/agui/emitter.py +219 -31
- haiku/rag/graph/agui/server.py +20 -62
- haiku/rag/graph/agui/stream.py +1 -2
- haiku/rag/graph/research/__init__.py +5 -2
- haiku/rag/graph/research/dependencies.py +12 -126
- haiku/rag/graph/research/graph.py +390 -135
- haiku/rag/graph/research/models.py +91 -112
- haiku/rag/graph/research/prompts.py +99 -91
- haiku/rag/graph/research/state.py +35 -27
- haiku/rag/inspector/__init__.py +8 -0
- haiku/rag/inspector/app.py +259 -0
- haiku/rag/inspector/widgets/__init__.py +6 -0
- haiku/rag/inspector/widgets/chunk_list.py +100 -0
- haiku/rag/inspector/widgets/context_modal.py +89 -0
- haiku/rag/inspector/widgets/detail_view.py +130 -0
- haiku/rag/inspector/widgets/document_list.py +75 -0
- haiku/rag/inspector/widgets/info_modal.py +209 -0
- haiku/rag/inspector/widgets/search_modal.py +183 -0
- haiku/rag/inspector/widgets/visual_modal.py +126 -0
- haiku/rag/mcp.py +106 -102
- haiku/rag/monitor.py +33 -9
- haiku/rag/providers/__init__.py +5 -0
- haiku/rag/providers/docling_serve.py +108 -0
- haiku/rag/qa/__init__.py +12 -10
- haiku/rag/qa/agent.py +43 -61
- haiku/rag/qa/prompts.py +35 -57
- haiku/rag/reranking/__init__.py +9 -6
- haiku/rag/reranking/base.py +1 -1
- haiku/rag/reranking/cohere.py +5 -4
- haiku/rag/reranking/mxbai.py +5 -2
- haiku/rag/reranking/vllm.py +3 -4
- haiku/rag/reranking/zeroentropy.py +6 -5
- haiku/rag/store/__init__.py +2 -1
- haiku/rag/store/engine.py +242 -42
- haiku/rag/store/exceptions.py +4 -0
- haiku/rag/store/models/__init__.py +8 -2
- haiku/rag/store/models/chunk.py +190 -0
- haiku/rag/store/models/document.py +46 -0
- haiku/rag/store/repositories/chunk.py +141 -121
- haiku/rag/store/repositories/document.py +25 -84
- haiku/rag/store/repositories/settings.py +11 -14
- haiku/rag/store/upgrades/__init__.py +19 -3
- haiku/rag/store/upgrades/v0_10_1.py +1 -1
- haiku/rag/store/upgrades/v0_19_6.py +65 -0
- haiku/rag/store/upgrades/v0_20_0.py +68 -0
- haiku/rag/store/upgrades/v0_23_1.py +100 -0
- haiku/rag/store/upgrades/v0_9_3.py +3 -3
- haiku/rag/utils.py +371 -146
- {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/METADATA +15 -12
- haiku_rag_slim-0.24.0.dist-info/RECORD +78 -0
- {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/WHEEL +1 -1
- haiku/rag/chunker.py +0 -65
- haiku/rag/embeddings/base.py +0 -25
- haiku/rag/embeddings/ollama.py +0 -28
- haiku/rag/embeddings/openai.py +0 -26
- haiku/rag/embeddings/vllm.py +0 -29
- haiku/rag/graph/agui/events.py +0 -254
- haiku/rag/graph/common/__init__.py +0 -5
- haiku/rag/graph/common/models.py +0 -42
- haiku/rag/graph/common/nodes.py +0 -265
- haiku/rag/graph/common/prompts.py +0 -46
- haiku/rag/graph/common/utils.py +0 -44
- haiku/rag/graph/deep_qa/__init__.py +0 -1
- haiku/rag/graph/deep_qa/dependencies.py +0 -27
- haiku/rag/graph/deep_qa/graph.py +0 -243
- haiku/rag/graph/deep_qa/models.py +0 -20
- haiku/rag/graph/deep_qa/prompts.py +0 -59
- haiku/rag/graph/deep_qa/state.py +0 -56
- haiku/rag/graph/research/common.py +0 -87
- haiku/rag/reader.py +0 -135
- haiku_rag_slim-0.16.0.dist-info/RECORD +0 -71
- {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/entry_points.txt +0 -0
- {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/licenses/LICENSE +0 -0
haiku/rag/mcp.py
CHANGED
@@ -7,12 +7,8 @@ from pydantic import BaseModel
 from haiku.rag.client import HaikuRAG
 from haiku.rag.config import AppConfig, Config
 from haiku.rag.graph.research.models import ResearchReport
-
-
-class SearchResult(BaseModel):
-    document_id: str
-    content: str
-    score: float
+from haiku.rag.store.models import SearchResult
+from haiku.rag.utils import format_citations
 
 
 class DocumentResult(BaseModel):
@@ -25,84 +21,92 @@ class DocumentResult(BaseModel):
     updated_at: str
 
 
-def create_mcp_server(db_path: Path, config: AppConfig = Config) -> FastMCP:
-    """Create an MCP server with the specified database path."""
-    mcp = FastMCP("haiku-rag")
+def create_mcp_server(
+    db_path: Path, config: AppConfig = Config, read_only: bool = False
+) -> FastMCP:
+    """Create an MCP server with the specified database path.
 
-    @mcp.tool()
-    async def add_document_from_file(
-        file_path: str,
-        metadata: dict[str, Any] | None = None,
-        title: str | None = None,
-    ) -> str | None:
-        """Add a document to the RAG system from a file path."""
-        try:
-            async with HaikuRAG(db_path, config=config) as rag:
-                result = await rag.create_document_from_source(
-                    Path(file_path), title=title, metadata=metadata or {}
-                )
-                # Handle both single document and list of documents (directories)
-                if isinstance(result, list):
-                    return result[0].id if result else None
-                return result.id
-        except Exception:
-            return None
-
-    @mcp.tool()
-    async def add_document_from_url(
-        url: str, metadata: dict[str, Any] | None = None, title: str | None = None
-    ) -> str | None:
-        """Add a document to the RAG system from a URL."""
-        try:
-            async with HaikuRAG(db_path, config=config) as rag:
-                result = await rag.create_document_from_source(
-                    url, title=title, metadata=metadata or {}
-                )
-                # Handle both single document and list of documents
-                if isinstance(result, list):
-                    return result[0].id if result else None
-                return result.id
-        except Exception:
-            return None
-
-    @mcp.tool()
-    async def add_document_from_text(
-        content: str,
-        uri: str | None = None,
-        metadata: dict[str, Any] | None = None,
-        title: str | None = None,
-    ) -> str | None:
-        """Add a document to the RAG system from text content."""
-        try:
-            async with HaikuRAG(db_path, config=config) as rag:
-                document = await rag.create_document(
-                    content, uri, title=title, metadata=metadata or {}
-                )
-                return document.id
-        except Exception:
-            return None
+    Args:
+        db_path: Path to the database file.
+        config: Configuration to use.
+        read_only: If True, write tools (add_document_*, delete_document) are not registered.
+    """
+    mcp = FastMCP("haiku-rag")
 
+    # Write tools - only registered when not in read-only mode
+    if not read_only:
+
+        @mcp.tool()
+        async def add_document_from_file(
+            file_path: str,
+            metadata: dict[str, Any] | None = None,
+            title: str | None = None,
+        ) -> str | None:
+            """Add a document to the RAG system from a file path."""
+            try:
+                async with HaikuRAG(db_path, config=config) as rag:
+                    result = await rag.create_document_from_source(
+                        Path(file_path), title=title, metadata=metadata or {}
+                    )
+                    # Handle both single document and list of documents (directories)
+                    if isinstance(result, list):
+                        return result[0].id if result else None
+                    return result.id
+            except Exception:
+                return None
+
+        @mcp.tool()
+        async def add_document_from_url(
+            url: str, metadata: dict[str, Any] | None = None, title: str | None = None
+        ) -> str | None:
+            """Add a document to the RAG system from a URL."""
+            try:
+                async with HaikuRAG(db_path, config=config) as rag:
+                    result = await rag.create_document_from_source(
+                        url, title=title, metadata=metadata or {}
+                    )
+                    # Handle both single document and list of documents
+                    if isinstance(result, list):
+                        return result[0].id if result else None
+                    return result.id
+            except Exception:
+                return None
+
+        @mcp.tool()
+        async def add_document_from_text(
+            content: str,
+            uri: str | None = None,
+            metadata: dict[str, Any] | None = None,
+            title: str | None = None,
+        ) -> str | None:
+            """Add a document to the RAG system from text content."""
+            try:
+                async with HaikuRAG(db_path, config=config) as rag:
+                    document = await rag.create_document(
+                        content, uri, title=title, metadata=metadata or {}
+                    )
+                    return document.id
+            except Exception:
+                return None
+
+        @mcp.tool()
+        async def delete_document(document_id: str) -> bool:
+            """Delete a document by its ID."""
+            try:
+                async with HaikuRAG(db_path, config=config) as rag:
+                    return await rag.delete_document(document_id)
+            except Exception:
+                return False
+
+    # Read tools - always registered
     @mcp.tool()
-    async def search_documents(…
+    async def search_documents(
+        query: str, limit: int | None = None
+    ) -> list[SearchResult]:
         """Search the RAG system for documents using hybrid search (vector similarity + full-text search)."""
         try:
-            async with HaikuRAG(db_path, config=config) as rag:
-                results = await rag.search(query, limit=limit)
-
-                search_results = []
-                for chunk, score in results:
-                    assert chunk.document_id is not None, (
-                        "Chunk document_id should not be None in search results"
-                    )
-                    search_results.append(
-                        SearchResult(
-                            document_id=chunk.document_id,
-                            content=chunk.content,
-                            score=score,
-                        )
-                    )
-
-                return search_results
+            async with HaikuRAG(db_path, config=config, read_only=read_only) as rag:
+                return await rag.search(query, limit=limit)
         except Exception:
             return []
 
@@ -110,7 +114,7 @@ def create_mcp_server(db_path: Path, config: AppConfig = Config) -> FastMCP:
     async def get_document(document_id: str) -> DocumentResult | None:
         """Get a document by its ID."""
         try:
-            async with HaikuRAG(db_path, config=config) as rag:
+            async with HaikuRAG(db_path, config=config, read_only=read_only) as rag:
                 document = await rag.get_document_by_id(document_id)
 
                 if document is None:
@@ -145,7 +149,7 @@ def create_mcp_server(db_path: Path, config: AppConfig = Config) -> FastMCP:
         List of DocumentResult instances matching the criteria.
         """
         try:
-            async with HaikuRAG(db_path, config=config) as rag:
+            async with HaikuRAG(db_path, config=config, read_only=read_only) as rag:
                 documents = await rag.list_documents(limit, offset, filter)
 
                 return [
@@ -163,15 +167,6 @@ def create_mcp_server(db_path: Path, config: AppConfig = Config) -> FastMCP:
         except Exception:
             return []
 
-    @mcp.tool()
-    async def delete_document(document_id: str) -> bool:
-        """Delete a document by its ID."""
-        try:
-            async with HaikuRAG(db_path, config=config) as rag:
-                return await rag.delete_document(document_id)
-        except Exception:
-            return False
-
     @mcp.tool()
     async def ask_question(
         question: str,
@@ -189,23 +184,32 @@ def create_mcp_server(db_path: Path, config: AppConfig = Config) -> FastMCP:
         The answer as a string.
         """
         try:
-            async with HaikuRAG(db_path, config=config) as rag:
+            async with HaikuRAG(db_path, config=config, read_only=read_only) as rag:
                 if deep:
-                    from haiku.rag.graph.…
-                    from haiku.rag.graph.…
-                    from haiku.rag.graph.…
+                    from haiku.rag.graph.research.dependencies import ResearchContext
+                    from haiku.rag.graph.research.graph import build_research_graph
+                    from haiku.rag.graph.research.state import (
+                        ResearchDeps,
+                        ResearchState,
+                    )
 
-                    graph = …
-                    context = …
-                    …
+                    graph = build_research_graph(config=config)
+                    context = ResearchContext(original_question=question)
+                    state = ResearchState.from_config(
+                        context=context,
+                        config=config,
+                        max_iterations=2,
+                        confidence_threshold=0.0,
                     )
-                    …
-                    deps = DeepQADeps(client=rag)
+                    deps = ResearchDeps(client=rag)
 
                     result = await graph.run(state=state, deps=deps)
-                    answer = result.…
+                    answer = result.executive_summary
+                    citations = []
                 else:
-                    answer = await rag.ask(question)
+                    answer, citations = await rag.ask(question)
+                    if cite and citations:
+                        answer += "\n\n" + format_citations(citations)
                 return answer
         except Exception as e:
             return f"Error answering question: {e!s}"
@@ -230,7 +234,7 @@ def create_mcp_server(db_path: Path, config: AppConfig = Config) -> FastMCP:
         from haiku.rag.graph.research.graph import build_research_graph
         from haiku.rag.graph.research.state import ResearchDeps, ResearchState
 
-        async with HaikuRAG(db_path, config=config) as rag:
+        async with HaikuRAG(db_path, config=config, read_only=read_only) as rag:
             graph = build_research_graph(config=config)
             context = ResearchContext(original_question=question)
             state = ResearchState.from_config(context=context, config=config)
haiku/rag/monitor.py
CHANGED
@@ -23,11 +23,19 @@ class FileFilter(DefaultFilter):
         *,
         ignore_patterns: list[str] | None = None,
         include_patterns: list[str] | None = None,
+        supported_extensions: list[str] | None = None,
     ) -> None:
-        # Lazy import to avoid loading docling
-        from haiku.rag.reader import FileReader
+        if supported_extensions is None:
+            # Default to docling-local extensions if not provided
+            from haiku.rag.converters.docling_local import DoclingLocalConverter
+            from haiku.rag.converters.text_utils import TextFileHandler
+
+            supported_extensions = (
+                DoclingLocalConverter.docling_extensions
+                + TextFileHandler.text_extensions
+            )
 
-        self.extensions = tuple(FileReader.extensions)
+        self.extensions = tuple(supported_extensions)
         self.ignore_spec = (
             pathspec.PathSpec.from_lines(GitWildMatchPattern, ignore_patterns)
             if ignore_patterns
@@ -72,16 +80,33 @@ class FileWatcher:
         client: HaikuRAG,
         config: AppConfig = Config,
     ):
+        from haiku.rag.converters import get_converter
+
         self.paths = config.monitor.directories
         self.client = client
         self.ignore_patterns = config.monitor.ignore_patterns or None
         self.include_patterns = config.monitor.include_patterns or None
         self.delete_orphans = config.monitor.delete_orphans
+        self.supported_extensions = get_converter(config).supported_extensions
 
     async def observe(self):
+        if not self.paths:
+            logger.warning("No directories configured for monitoring")
+            return
+
+        # Validate all paths exist before attempting to watch
+        missing_paths = [p for p in self.paths if not Path(p).exists()]
+        if missing_paths:
+            raise FileNotFoundError(
+                f"Monitor directories do not exist: {missing_paths}. "
+                "Check your haiku.rag.yaml configuration."
+            )
+
         logger.info(f"Watching files in {self.paths}")
         filter = FileFilter(
-            ignore_patterns=self.ignore_patterns, include_patterns=self.include_patterns
+            ignore_patterns=self.ignore_patterns,
+            include_patterns=self.include_patterns,
+            supported_extensions=self.supported_extensions,
         )
         await self.refresh()
 
@@ -96,9 +121,6 @@ class FileWatcher:
             await self._delete_document(Path(path))
 
     async def refresh(self):
-        # Lazy import to avoid loading docling
-        from haiku.rag.reader import FileReader
-
        # Delete orphaned documents in background if enabled
         if self.delete_orphans:
             logger.info("Starting orphan cleanup in background")
@@ -106,12 +128,14 @@ class FileWatcher:
 
         # Create filter to apply same logic as observe()
         filter = FileFilter(
-            ignore_patterns=self.ignore_patterns, include_patterns=self.include_patterns
+            ignore_patterns=self.ignore_patterns,
+            include_patterns=self.include_patterns,
+            supported_extensions=self.supported_extensions,
         )
 
         for path in self.paths:
             for f in Path(path).rglob("**/*"):
-                if f.is_file() and f.suffix in FileReader.extensions:
+                if f.is_file() and f.suffix in self.supported_extensions:
                     # Apply pattern filters
                     if filter(Change.added, str(f)):
                         await self._upsert_document(f)
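The watcher now derives its extension list from the configured converter instead of a hard-coded reader. A sketch of constructing the filter directly; only the keyword names come from the diff, while the patterns and extension list are hypothetical:

```python
from watchfiles import Change

from haiku.rag.monitor import FileFilter

# Passing supported_extensions skips the docling-local default that
# FileFilter falls back to when the argument is None.
md_only = FileFilter(
    ignore_patterns=["drafts/**"],         # hypothetical ignore pattern
    include_patterns=None,
    supported_extensions=[".md", ".txt"],  # hypothetical extension list
)

# FileFilter is a watchfiles filter, callable as (change, path),
# exactly as refresh() uses it above.
print(md_only(Change.added, "notes/todo.md"))
```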
haiku/rag/providers/docling_serve.py
ADDED
@@ -0,0 +1,108 @@
+"""Shared client for docling-serve async API."""
+
+import asyncio
+from typing import Any
+
+import httpx
+
+
+class DoclingServeClient:
+    """Client for docling-serve async workflow.
+
+    Handles the submit → poll → fetch pattern used by both conversion and chunking.
+    """
+
+    def __init__(self, base_url: str, api_key: str | None = None, timeout: float = 300):
+        self.base_url = base_url.rstrip("/")
+        self.api_key = api_key
+        self.timeout = timeout
+
+    def _get_headers(self) -> dict[str, str]:
+        """Get headers for API requests."""
+        headers: dict[str, str] = {}
+        if self.api_key:
+            headers["X-Api-Key"] = self.api_key
+        return headers
+
+    async def submit_and_poll(
+        self,
+        endpoint: str,
+        files: dict[str, Any],
+        data: dict[str, Any],
+        name: str = "document",
+    ) -> dict[str, Any]:
+        """Submit a task and poll until completion.
+
+        Args:
+            endpoint: The async endpoint path (e.g., "/v1/convert/file/async")
+            files: Files to upload
+            data: Form data parameters
+            name: Name for error messages
+
+        Returns:
+            The result dictionary from the completed task
+
+        Raises:
+            ValueError: If the task fails or service is unavailable
+        """
+        headers = self._get_headers()
+
+        try:
+            async with httpx.AsyncClient(timeout=self.timeout) as client:
+                # Submit async task
+                submit_url = f"{self.base_url}{endpoint}"
+                response = await client.post(
+                    submit_url,
+                    files=files,
+                    data=data,
+                    headers=headers,
+                )
+                response.raise_for_status()
+                submit_result = response.json()
+                task_id = submit_result.get("task_id")
+
+                if not task_id:
+                    raise ValueError("docling-serve did not return a task_id")
+
+                # Poll for completion
+                poll_url = f"{self.base_url}/v1/status/poll/{task_id}"
+                while True:
+                    poll_response = await client.get(poll_url, headers=headers)
+                    poll_response.raise_for_status()
+                    poll_result = poll_response.json()
+                    status = poll_result.get("task_status")
+
+                    if status == "success":
+                        break
+                    elif status in ("failure", "error"):
+                        raise ValueError(
+                            f"docling-serve task failed for {name}: {poll_result}"
+                        )
+
+                    await asyncio.sleep(1)
+
+                # Fetch result
+                result_url = f"{self.base_url}/v1/result/{task_id}"
+                result_response = await client.get(result_url, headers=headers)
+                result_response.raise_for_status()
+                return result_response.json()
+
+        except httpx.ConnectError as e:
+            raise ValueError(
+                f"Could not connect to docling-serve at {self.base_url}. "
+                f"Ensure the service is running and accessible. Error: {e}"
+            )
+        except httpx.TimeoutException as e:
+            raise ValueError(
+                f"Request to docling-serve timed out after {self.timeout}s. Error: {e}"
+            )
+        except httpx.HTTPStatusError as e:
+            if e.response.status_code == 401:
+                raise ValueError(
+                    "Authentication failed. Check your API key configuration."
+                )
+            raise ValueError(f"HTTP error from docling-serve: {e}")
+        except ValueError:
+            raise
+        except Exception as e:
+            raise ValueError(f"Failed to process via docling-serve: {e}")
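A usage sketch for the new client. The endpoint string appears in the docstring above; the service URL and the multipart field layout are assumptions:

```python
import asyncio

from haiku.rag.providers.docling_serve import DoclingServeClient


async def convert_pdf(path: str) -> dict:
    client = DoclingServeClient("http://localhost:5001", timeout=120)
    with open(path, "rb") as fh:
        # submit_and_poll() POSTs the upload, polls /v1/status/poll/{task_id}
        # once per second, then fetches /v1/result/{task_id}.
        return await client.submit_and_poll(
            "/v1/convert/file/async",     # endpoint named in the docstring
            files={"files": (path, fh)},  # field name is an assumption
            data={},                      # converter options would go here
            name=path,
        )


if __name__ == "__main__":
    print(asyncio.run(convert_pdf("paper.pdf")))
```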
haiku/rag/qa/__init__.py
CHANGED
@@ -1,33 +1,35 @@
 from haiku.rag.client import HaikuRAG
 from haiku.rag.config import AppConfig, Config
 from haiku.rag.qa.agent import QuestionAnswerAgent
+from haiku.rag.qa.prompts import QA_SYSTEM_PROMPT
+from haiku.rag.utils import build_prompt
 
 
 def get_qa_agent(
     client: HaikuRAG,
     config: AppConfig = Config,
-    use_citations: bool = False,
     system_prompt: str | None = None,
 ) -> QuestionAnswerAgent:
-    """
-    Factory function to get a QA agent based on the configuration.
+    """Factory function to get a QA agent based on the configuration.
 
     Args:
         client: HaikuRAG client instance.
         config: Configuration to use. Defaults to global Config.
-        use_citations: …
-        system_prompt: Optional custom system prompt.
+        system_prompt: Optional custom system prompt (overrides config).
 
     Returns:
         A configured QuestionAnswerAgent instance.
     """
-    …
-    …
+    # Determine the base prompt: explicit > config > default
+    if system_prompt is None:
+        system_prompt = config.prompts.qa or QA_SYSTEM_PROMPT
+
+    # Prepend system_context if configured
+    system_prompt = build_prompt(system_prompt, config)
 
     return QuestionAnswerAgent(
         client=client,
-        …
-        …
-        use_citations=use_citations,
+        model_config=config.qa.model,
+        config=config,
         system_prompt=system_prompt,
     )
haiku/rag/qa/agent.py
CHANGED
@@ -1,49 +1,38 @@
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
 from pydantic_ai import Agent, RunContext
-from pydantic_ai.models.openai import OpenAIChatModel
-from pydantic_ai.providers.ollama import OllamaProvider
-from pydantic_ai.providers.openai import OpenAIProvider
+from pydantic_ai.output import ToolOutput
 
 from haiku.rag.client import HaikuRAG
-from haiku.rag.config import Config
-from haiku.rag.qa.prompts import QA_SYSTEM_PROMPT, QA_SYSTEM_PROMPT_WITH_CITATIONS
-
-
-class SearchResult(BaseModel):
-    content: str = Field(description="The document text content")
-    score: float = Field(description="Relevance score (higher is more relevant)")
-    document_uri: str = Field(
-        description="Source title (if available) or URI/path of the document"
-    )
+from haiku.rag.config.models import AppConfig, ModelConfig
+from haiku.rag.graph.research.models import Citation, RawSearchAnswer, resolve_citations
+from haiku.rag.qa.prompts import QA_SYSTEM_PROMPT
+from haiku.rag.store.models import SearchResult
+from haiku.rag.utils import get_model
 
 
 class Dependencies(BaseModel):
     model_config = {"arbitrary_types_allowed": True}
     client: HaikuRAG
+    search_results: list[SearchResult] = []
+    search_filter: str | None = None
 
 
 class QuestionAnswerAgent:
     def __init__(
         self,
         client: HaikuRAG,
-        provider: str,
-        model: str,
-        use_citations: bool = False,
-        q: float = 0.0,
+        model_config: ModelConfig,
+        config: AppConfig | None = None,
         system_prompt: str | None = None,
     ):
         self._client = client
-
-        if system_prompt is None:
-            system_prompt = (
-                QA_SYSTEM_PROMPT_WITH_CITATIONS if use_citations else QA_SYSTEM_PROMPT
-            )
-        model_obj = self._get_model(provider, model)
+        model_obj = get_model(model_config, config)
 
         self._agent = Agent(
             model=model_obj,
             deps_type=Dependencies,
-            …
+            output_type=ToolOutput(RawSearchAnswer, max_retries=3),
+            instructions=system_prompt or QA_SYSTEM_PROMPT,
             retries=3,
         )
 
@@ -51,43 +40,36 @@ class QuestionAnswerAgent:
         async def search_documents(
             ctx: RunContext[Dependencies],
             query: str,
-            limit: int = …,
-        ) -> list[SearchResult]:
-            """Search the knowledge base for relevant documents."""
-            search_results = await ctx.deps.client.search(query, limit=limit)
-            expanded_results = await ctx.deps.client.expand_context(search_results)
+            limit: int | None = None,
+        ) -> str:
+            """Search the knowledge base for relevant documents.
 
-            return [
-                SearchResult(
-                    content=chunk.content,
-                    score=score,
-                    document_uri=…,
-                )
-                for chunk, score in expanded_results
-            ]
-
-    def _get_model(self, provider: str, model: str):
-        """Get the appropriate model object for the provider."""
-        if provider == "ollama":
-            return OpenAIChatModel(
-                model_name=model,
-                provider=OllamaProvider(
-                    base_url=f"{Config.providers.ollama.base_url}/v1"
-                ),
-            )
-        elif provider == "vllm":
-            return OpenAIChatModel(
-                model_name=model,
-                provider=OpenAIProvider(
-                    base_url=f"{Config.providers.vllm.qa_base_url}/v1", api_key="none"
-                ),
+            Returns results with chunk IDs and relevance scores.
+            Reference results by their chunk_id in cited_chunks.
+            """
+            results = await ctx.deps.client.search(
+                query, limit=limit, filter=ctx.deps.search_filter
             )
-
-        # …
-        …
+            results = await ctx.deps.client.expand_context(results)
+            # Store results for citation resolution
+            ctx.deps.search_results = results
+            # Format with metadata for agent context
+            parts = [r.format_for_agent() for r in results]
+            return "\n\n".join(parts) if parts else "No results found."
+
+    async def answer(
+        self, question: str, filter: str | None = None
+    ) -> tuple[str, list[Citation]]:
+        """Answer a question using the RAG system.
+
+        Args:
+            question: The question to answer
+            filter: SQL WHERE clause to filter documents
 
-    async def answer(self, question: str) -> str:
-        """Answer a question using the RAG system."""
-        deps = Dependencies(client=self._client)
+        Returns:
+            Tuple of (answer text, list of resolved citations)
+        """
+        deps = Dependencies(client=self._client, search_filter=filter)
         result = await self._agent.run(question, deps=deps)
-        …
+        citations = resolve_citations(result.output.cited_chunks, deps.search_results)
+        return result.output.answer, citations
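End to end, the reworked agent returns structured citations next to the answer text instead of baking them into the prompt. A sketch with a placeholder database path:

```python
import asyncio

from haiku.rag.client import HaikuRAG
from haiku.rag.config import Config
from haiku.rag.qa import get_qa_agent
from haiku.rag.utils import format_citations


async def main() -> None:
    async with HaikuRAG("rag.db", config=Config) as rag:  # placeholder path
        agent = get_qa_agent(rag, config=Config)
        # The search tool stashes SearchResult objects in Dependencies so
        # cited_chunks can be resolved into Citation objects afterwards.
        answer, citations = await agent.answer(
            "What does the knowledge base say about chunking?",
            filter=None,  # optional SQL WHERE clause over documents
        )
        print(answer)
        if citations:
            print(format_citations(citations))


asyncio.run(main())
```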