haiku.rag-slim 0.16.0__py3-none-any.whl → 0.24.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in the supported public registries. It is provided for informational purposes only.

This version of haiku.rag-slim has been flagged as potentially problematic.

Files changed (94)
  1. haiku/rag/app.py +430 -72
  2. haiku/rag/chunkers/__init__.py +31 -0
  3. haiku/rag/chunkers/base.py +31 -0
  4. haiku/rag/chunkers/docling_local.py +164 -0
  5. haiku/rag/chunkers/docling_serve.py +179 -0
  6. haiku/rag/cli.py +207 -24
  7. haiku/rag/cli_chat.py +489 -0
  8. haiku/rag/client.py +1251 -266
  9. haiku/rag/config/__init__.py +16 -10
  10. haiku/rag/config/loader.py +5 -44
  11. haiku/rag/config/models.py +126 -17
  12. haiku/rag/converters/__init__.py +31 -0
  13. haiku/rag/converters/base.py +63 -0
  14. haiku/rag/converters/docling_local.py +193 -0
  15. haiku/rag/converters/docling_serve.py +229 -0
  16. haiku/rag/converters/text_utils.py +237 -0
  17. haiku/rag/embeddings/__init__.py +123 -24
  18. haiku/rag/embeddings/voyageai.py +175 -20
  19. haiku/rag/graph/__init__.py +0 -11
  20. haiku/rag/graph/agui/__init__.py +8 -2
  21. haiku/rag/graph/agui/cli_renderer.py +1 -1
  22. haiku/rag/graph/agui/emitter.py +219 -31
  23. haiku/rag/graph/agui/server.py +20 -62
  24. haiku/rag/graph/agui/stream.py +1 -2
  25. haiku/rag/graph/research/__init__.py +5 -2
  26. haiku/rag/graph/research/dependencies.py +12 -126
  27. haiku/rag/graph/research/graph.py +390 -135
  28. haiku/rag/graph/research/models.py +91 -112
  29. haiku/rag/graph/research/prompts.py +99 -91
  30. haiku/rag/graph/research/state.py +35 -27
  31. haiku/rag/inspector/__init__.py +8 -0
  32. haiku/rag/inspector/app.py +259 -0
  33. haiku/rag/inspector/widgets/__init__.py +6 -0
  34. haiku/rag/inspector/widgets/chunk_list.py +100 -0
  35. haiku/rag/inspector/widgets/context_modal.py +89 -0
  36. haiku/rag/inspector/widgets/detail_view.py +130 -0
  37. haiku/rag/inspector/widgets/document_list.py +75 -0
  38. haiku/rag/inspector/widgets/info_modal.py +209 -0
  39. haiku/rag/inspector/widgets/search_modal.py +183 -0
  40. haiku/rag/inspector/widgets/visual_modal.py +126 -0
  41. haiku/rag/mcp.py +106 -102
  42. haiku/rag/monitor.py +33 -9
  43. haiku/rag/providers/__init__.py +5 -0
  44. haiku/rag/providers/docling_serve.py +108 -0
  45. haiku/rag/qa/__init__.py +12 -10
  46. haiku/rag/qa/agent.py +43 -61
  47. haiku/rag/qa/prompts.py +35 -57
  48. haiku/rag/reranking/__init__.py +9 -6
  49. haiku/rag/reranking/base.py +1 -1
  50. haiku/rag/reranking/cohere.py +5 -4
  51. haiku/rag/reranking/mxbai.py +5 -2
  52. haiku/rag/reranking/vllm.py +3 -4
  53. haiku/rag/reranking/zeroentropy.py +6 -5
  54. haiku/rag/store/__init__.py +2 -1
  55. haiku/rag/store/engine.py +242 -42
  56. haiku/rag/store/exceptions.py +4 -0
  57. haiku/rag/store/models/__init__.py +8 -2
  58. haiku/rag/store/models/chunk.py +190 -0
  59. haiku/rag/store/models/document.py +46 -0
  60. haiku/rag/store/repositories/chunk.py +141 -121
  61. haiku/rag/store/repositories/document.py +25 -84
  62. haiku/rag/store/repositories/settings.py +11 -14
  63. haiku/rag/store/upgrades/__init__.py +19 -3
  64. haiku/rag/store/upgrades/v0_10_1.py +1 -1
  65. haiku/rag/store/upgrades/v0_19_6.py +65 -0
  66. haiku/rag/store/upgrades/v0_20_0.py +68 -0
  67. haiku/rag/store/upgrades/v0_23_1.py +100 -0
  68. haiku/rag/store/upgrades/v0_9_3.py +3 -3
  69. haiku/rag/utils.py +371 -146
  70. {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/METADATA +15 -12
  71. haiku_rag_slim-0.24.0.dist-info/RECORD +78 -0
  72. {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/WHEEL +1 -1
  73. haiku/rag/chunker.py +0 -65
  74. haiku/rag/embeddings/base.py +0 -25
  75. haiku/rag/embeddings/ollama.py +0 -28
  76. haiku/rag/embeddings/openai.py +0 -26
  77. haiku/rag/embeddings/vllm.py +0 -29
  78. haiku/rag/graph/agui/events.py +0 -254
  79. haiku/rag/graph/common/__init__.py +0 -5
  80. haiku/rag/graph/common/models.py +0 -42
  81. haiku/rag/graph/common/nodes.py +0 -265
  82. haiku/rag/graph/common/prompts.py +0 -46
  83. haiku/rag/graph/common/utils.py +0 -44
  84. haiku/rag/graph/deep_qa/__init__.py +0 -1
  85. haiku/rag/graph/deep_qa/dependencies.py +0 -27
  86. haiku/rag/graph/deep_qa/graph.py +0 -243
  87. haiku/rag/graph/deep_qa/models.py +0 -20
  88. haiku/rag/graph/deep_qa/prompts.py +0 -59
  89. haiku/rag/graph/deep_qa/state.py +0 -56
  90. haiku/rag/graph/research/common.py +0 -87
  91. haiku/rag/reader.py +0 -135
  92. haiku_rag_slim-0.16.0.dist-info/RECORD +0 -71
  93. {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/entry_points.txt +0 -0
  94. {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/licenses/LICENSE +0 -0
haiku/rag/mcp.py CHANGED
@@ -7,12 +7,8 @@ from pydantic import BaseModel
 from haiku.rag.client import HaikuRAG
 from haiku.rag.config import AppConfig, Config
 from haiku.rag.graph.research.models import ResearchReport
-
-
-class SearchResult(BaseModel):
-    document_id: str
-    content: str
-    score: float
+from haiku.rag.store.models import SearchResult
+from haiku.rag.utils import format_citations
 
 
 class DocumentResult(BaseModel):
@@ -25,84 +21,92 @@ class DocumentResult(BaseModel):
     updated_at: str
 
 
-def create_mcp_server(db_path: Path, config: AppConfig = Config) -> FastMCP:
-    """Create an MCP server with the specified database path."""
-    mcp = FastMCP("haiku-rag")
+def create_mcp_server(
+    db_path: Path, config: AppConfig = Config, read_only: bool = False
+) -> FastMCP:
+    """Create an MCP server with the specified database path.
 
-    @mcp.tool()
-    async def add_document_from_file(
-        file_path: str,
-        metadata: dict[str, Any] | None = None,
-        title: str | None = None,
-    ) -> str | None:
-        """Add a document to the RAG system from a file path."""
-        try:
-            async with HaikuRAG(db_path, config=config) as rag:
-                result = await rag.create_document_from_source(
-                    Path(file_path), title=title, metadata=metadata or {}
-                )
-                # Handle both single document and list of documents (directories)
-                if isinstance(result, list):
-                    return result[0].id if result else None
-                return result.id
-        except Exception:
-            return None
-
-    @mcp.tool()
-    async def add_document_from_url(
-        url: str, metadata: dict[str, Any] | None = None, title: str | None = None
-    ) -> str | None:
-        """Add a document to the RAG system from a URL."""
-        try:
-            async with HaikuRAG(db_path, config=config) as rag:
-                result = await rag.create_document_from_source(
-                    url, title=title, metadata=metadata or {}
-                )
-                # Handle both single document and list of documents
-                if isinstance(result, list):
-                    return result[0].id if result else None
-                return result.id
-        except Exception:
-            return None
-
-    @mcp.tool()
-    async def add_document_from_text(
-        content: str,
-        uri: str | None = None,
-        metadata: dict[str, Any] | None = None,
-        title: str | None = None,
-    ) -> str | None:
-        """Add a document to the RAG system from text content."""
-        try:
-            async with HaikuRAG(db_path, config=config) as rag:
-                document = await rag.create_document(
-                    content, uri, title=title, metadata=metadata or {}
-                )
-                return document.id
-        except Exception:
-            return None
+    Args:
+        db_path: Path to the database file.
+        config: Configuration to use.
+        read_only: If True, write tools (add_document_*, delete_document) are not registered.
+    """
+    mcp = FastMCP("haiku-rag")
 
+    # Write tools - only registered when not in read-only mode
+    if not read_only:
+
+        @mcp.tool()
+        async def add_document_from_file(
+            file_path: str,
+            metadata: dict[str, Any] | None = None,
+            title: str | None = None,
+        ) -> str | None:
+            """Add a document to the RAG system from a file path."""
+            try:
+                async with HaikuRAG(db_path, config=config) as rag:
+                    result = await rag.create_document_from_source(
+                        Path(file_path), title=title, metadata=metadata or {}
+                    )
+                    # Handle both single document and list of documents (directories)
+                    if isinstance(result, list):
+                        return result[0].id if result else None
+                    return result.id
+            except Exception:
+                return None
+
+        @mcp.tool()
+        async def add_document_from_url(
+            url: str, metadata: dict[str, Any] | None = None, title: str | None = None
+        ) -> str | None:
+            """Add a document to the RAG system from a URL."""
+            try:
+                async with HaikuRAG(db_path, config=config) as rag:
+                    result = await rag.create_document_from_source(
+                        url, title=title, metadata=metadata or {}
+                    )
+                    # Handle both single document and list of documents
+                    if isinstance(result, list):
+                        return result[0].id if result else None
+                    return result.id
+            except Exception:
+                return None
+
+        @mcp.tool()
+        async def add_document_from_text(
+            content: str,
+            uri: str | None = None,
+            metadata: dict[str, Any] | None = None,
+            title: str | None = None,
+        ) -> str | None:
+            """Add a document to the RAG system from text content."""
+            try:
+                async with HaikuRAG(db_path, config=config) as rag:
+                    document = await rag.create_document(
+                        content, uri, title=title, metadata=metadata or {}
+                    )
+                    return document.id
+            except Exception:
+                return None
+
+        @mcp.tool()
+        async def delete_document(document_id: str) -> bool:
+            """Delete a document by its ID."""
+            try:
+                async with HaikuRAG(db_path, config=config) as rag:
+                    return await rag.delete_document(document_id)
+            except Exception:
+                return False
+
+    # Read tools - always registered
     @mcp.tool()
-    async def search_documents(query: str, limit: int = 5) -> list[SearchResult]:
+    async def search_documents(
+        query: str, limit: int | None = None
+    ) -> list[SearchResult]:
         """Search the RAG system for documents using hybrid search (vector similarity + full-text search)."""
         try:
-            async with HaikuRAG(db_path, config=config) as rag:
-                results = await rag.search(query, limit)
-
-                search_results = []
-                for chunk, score in results:
-                    assert chunk.document_id is not None, (
-                        "Chunk document_id should not be None in search results"
-                    )
-                    search_results.append(
-                        SearchResult(
-                            document_id=chunk.document_id,
-                            content=chunk.content,
-                            score=score,
-                        )
-                    )
-
-                return search_results
+            async with HaikuRAG(db_path, config=config, read_only=read_only) as rag:
+                return await rag.search(query, limit=limit)
         except Exception:
             return []
 
@@ -110,7 +114,7 @@ def create_mcp_server(db_path: Path, config: AppConfig = Config) -> FastMCP:
     async def get_document(document_id: str) -> DocumentResult | None:
         """Get a document by its ID."""
         try:
-            async with HaikuRAG(db_path, config=config) as rag:
+            async with HaikuRAG(db_path, config=config, read_only=read_only) as rag:
                 document = await rag.get_document_by_id(document_id)
 
                 if document is None:
@@ -145,7 +149,7 @@ def create_mcp_server(db_path: Path, config: AppConfig = Config) -> FastMCP:
             List of DocumentResult instances matching the criteria.
         """
         try:
-            async with HaikuRAG(db_path, config=config) as rag:
+            async with HaikuRAG(db_path, config=config, read_only=read_only) as rag:
                 documents = await rag.list_documents(limit, offset, filter)
 
                 return [
@@ -163,15 +167,6 @@ def create_mcp_server(db_path: Path, config: AppConfig = Config) -> FastMCP:
         except Exception:
             return []
 
-    @mcp.tool()
-    async def delete_document(document_id: str) -> bool:
-        """Delete a document by its ID."""
-        try:
-            async with HaikuRAG(db_path, config=config) as rag:
-                return await rag.delete_document(document_id)
-        except Exception:
-            return False
-
     @mcp.tool()
     async def ask_question(
         question: str,
@@ -189,23 +184,32 @@ def create_mcp_server(db_path: Path, config: AppConfig = Config) -> FastMCP:
            The answer as a string.
        """
        try:
-            async with HaikuRAG(db_path, config=config) as rag:
+            async with HaikuRAG(db_path, config=config, read_only=read_only) as rag:
                if deep:
-                    from haiku.rag.graph.deep_qa.dependencies import DeepQAContext
-                    from haiku.rag.graph.deep_qa.graph import build_deep_qa_graph
-                    from haiku.rag.graph.deep_qa.state import DeepQADeps, DeepQAState
+                    from haiku.rag.graph.research.dependencies import ResearchContext
+                    from haiku.rag.graph.research.graph import build_research_graph
+                    from haiku.rag.graph.research.state import (
+                        ResearchDeps,
+                        ResearchState,
+                    )
 
-                    graph = build_deep_qa_graph(config=config)
-                    context = DeepQAContext(
-                        original_question=question, use_citations=cite
+                    graph = build_research_graph(config=config)
+                    context = ResearchContext(original_question=question)
+                    state = ResearchState.from_config(
+                        context=context,
+                        config=config,
+                        max_iterations=2,
+                        confidence_threshold=0.0,
                    )
-                    state = DeepQAState.from_config(context=context, config=config)
-                    deps = DeepQADeps(client=rag)
+                    deps = ResearchDeps(client=rag)
 
                    result = await graph.run(state=state, deps=deps)
-                    answer = result.answer
+                    answer = result.executive_summary
+                    citations = []
                else:
-                    answer = await rag.ask(question, cite=cite)
+                    answer, citations = await rag.ask(question)
+                    if cite and citations:
+                        answer += "\n\n" + format_citations(citations)
                return answer
        except Exception as e:
            return f"Error answering question: {e!s}"
@@ -230,7 +234,7 @@ def create_mcp_server(db_path: Path, config: AppConfig = Config) -> FastMCP:
            from haiku.rag.graph.research.graph import build_research_graph
            from haiku.rag.graph.research.state import ResearchDeps, ResearchState
 
-            async with HaikuRAG(db_path, config=config) as rag:
+            async with HaikuRAG(db_path, config=config, read_only=read_only) as rag:
                graph = build_research_graph(config=config)
                context = ResearchContext(original_question=question)
                state = ResearchState.from_config(context=context, config=config)
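
The net effect of the mcp.py changes: create_mcp_server gains a read_only flag that gates registration of the write tools and is threaded into every HaikuRAG client it opens. A minimal sketch of how a consumer might use both modes (the database path is illustrative, and mcp.run() is assumed from FastMCP's usual API):

```python
from pathlib import Path

from haiku.rag.mcp import create_mcp_server

# Full server: write tools (add_document_*, delete_document) are registered.
mcp = create_mcp_server(Path("data/haiku.db"))

# Read-only server: only search/get/list/ask/research tools are exposed, and
# every HaikuRAG client it opens is constructed with read_only=True.
mcp_readonly = create_mcp_server(Path("data/haiku.db"), read_only=True)

if __name__ == "__main__":
    mcp_readonly.run()  # assumed FastMCP entry point (stdio by default)
```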
haiku/rag/monitor.py CHANGED
@@ -23,11 +23,19 @@ class FileFilter(DefaultFilter):
         *,
         ignore_patterns: list[str] | None = None,
         include_patterns: list[str] | None = None,
+        supported_extensions: list[str] | None = None,
     ) -> None:
-        # Lazy import to avoid loading docling
-        from haiku.rag.reader import FileReader
+        if supported_extensions is None:
+            # Default to docling-local extensions if not provided
+            from haiku.rag.converters.docling_local import DoclingLocalConverter
+            from haiku.rag.converters.text_utils import TextFileHandler
+
+            supported_extensions = (
+                DoclingLocalConverter.docling_extensions
+                + TextFileHandler.text_extensions
+            )
 
-        self.extensions = tuple(FileReader.extensions)
+        self.extensions = tuple(supported_extensions)
         self.ignore_spec = (
             pathspec.PathSpec.from_lines(GitWildMatchPattern, ignore_patterns)
             if ignore_patterns
@@ -72,16 +80,33 @@ class FileWatcher:
         client: HaikuRAG,
         config: AppConfig = Config,
     ):
+        from haiku.rag.converters import get_converter
+
         self.paths = config.monitor.directories
         self.client = client
         self.ignore_patterns = config.monitor.ignore_patterns or None
         self.include_patterns = config.monitor.include_patterns or None
         self.delete_orphans = config.monitor.delete_orphans
+        self.supported_extensions = get_converter(config).supported_extensions
 
     async def observe(self):
+        if not self.paths:
+            logger.warning("No directories configured for monitoring")
+            return
+
+        # Validate all paths exist before attempting to watch
+        missing_paths = [p for p in self.paths if not Path(p).exists()]
+        if missing_paths:
+            raise FileNotFoundError(
+                f"Monitor directories do not exist: {missing_paths}. "
+                "Check your haiku.rag.yaml configuration."
+            )
+
         logger.info(f"Watching files in {self.paths}")
         filter = FileFilter(
-            ignore_patterns=self.ignore_patterns, include_patterns=self.include_patterns
+            ignore_patterns=self.ignore_patterns,
+            include_patterns=self.include_patterns,
+            supported_extensions=self.supported_extensions,
         )
         await self.refresh()
 
@@ -96,9 +121,6 @@
                 await self._delete_document(Path(path))
 
     async def refresh(self):
-        # Lazy import to avoid loading docling
-        from haiku.rag.reader import FileReader
-
         # Delete orphaned documents in background if enabled
         if self.delete_orphans:
             logger.info("Starting orphan cleanup in background")
@@ -106,12 +128,14 @@
 
         # Create filter to apply same logic as observe()
         filter = FileFilter(
-            ignore_patterns=self.ignore_patterns, include_patterns=self.include_patterns
+            ignore_patterns=self.ignore_patterns,
+            include_patterns=self.include_patterns,
+            supported_extensions=self.supported_extensions,
         )
 
         for path in self.paths:
             for f in Path(path).rglob("**/*"):
-                if f.is_file() and f.suffix in FileReader.extensions:
+                if f.is_file() and f.suffix in self.supported_extensions:
                     # Apply pattern filters
                     if filter(Change.added, str(f)):
                         await self._upsert_document(f)
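
With this change, FileFilter no longer hard-codes FileReader.extensions; the watcher injects the configured converter's supported_extensions instead. A rough sketch of the filtering behavior, assuming FileFilter can be constructed with just these keyword arguments (the extension lists and paths are made up):

```python
from watchfiles import Change

from haiku.rag.monitor import FileFilter

file_filter = FileFilter(
    ignore_patterns=["drafts/**"],
    include_patterns=None,
    supported_extensions=[".md", ".pdf"],  # hypothetical converter extensions
)

# Filters are callables (change, path) -> bool, as used in FileWatcher above.
print(file_filter(Change.added, "/docs/guide.md"))   # expected: True
print(file_filter(Change.added, "/docs/notes.xyz"))  # expected: False (unsupported extension)
print(file_filter(Change.added, "/drafts/wip.md"))   # expected: False (matches ignore pattern)
```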
haiku/rag/providers/__init__.py ADDED
@@ -0,0 +1,5 @@
+"""Provider clients for external services."""
+
+from haiku.rag.providers.docling_serve import DoclingServeClient
+
+__all__ = ["DoclingServeClient"]
haiku/rag/providers/docling_serve.py ADDED
@@ -0,0 +1,108 @@
+"""Shared client for docling-serve async API."""
+
+import asyncio
+from typing import Any
+
+import httpx
+
+
+class DoclingServeClient:
+    """Client for docling-serve async workflow.
+
+    Handles the submit → poll → fetch pattern used by both conversion and chunking.
+    """
+
+    def __init__(self, base_url: str, api_key: str | None = None, timeout: float = 300):
+        self.base_url = base_url.rstrip("/")
+        self.api_key = api_key
+        self.timeout = timeout
+
+    def _get_headers(self) -> dict[str, str]:
+        """Get headers for API requests."""
+        headers: dict[str, str] = {}
+        if self.api_key:
+            headers["X-Api-Key"] = self.api_key
+        return headers
+
+    async def submit_and_poll(
+        self,
+        endpoint: str,
+        files: dict[str, Any],
+        data: dict[str, Any],
+        name: str = "document",
+    ) -> dict[str, Any]:
+        """Submit a task and poll until completion.
+
+        Args:
+            endpoint: The async endpoint path (e.g., "/v1/convert/file/async")
+            files: Files to upload
+            data: Form data parameters
+            name: Name for error messages
+
+        Returns:
+            The result dictionary from the completed task
+
+        Raises:
+            ValueError: If the task fails or service is unavailable
+        """
+        headers = self._get_headers()
+
+        try:
+            async with httpx.AsyncClient(timeout=self.timeout) as client:
+                # Submit async task
+                submit_url = f"{self.base_url}{endpoint}"
+                response = await client.post(
+                    submit_url,
+                    files=files,
+                    data=data,
+                    headers=headers,
+                )
+                response.raise_for_status()
+                submit_result = response.json()
+                task_id = submit_result.get("task_id")
+
+                if not task_id:
+                    raise ValueError("docling-serve did not return a task_id")
+
+                # Poll for completion
+                poll_url = f"{self.base_url}/v1/status/poll/{task_id}"
+                while True:
+                    poll_response = await client.get(poll_url, headers=headers)
+                    poll_response.raise_for_status()
+                    poll_result = poll_response.json()
+                    status = poll_result.get("task_status")
+
+                    if status == "success":
+                        break
+                    elif status in ("failure", "error"):
+                        raise ValueError(
+                            f"docling-serve task failed for {name}: {poll_result}"
+                        )
+
+                    await asyncio.sleep(1)
+
+                # Fetch result
+                result_url = f"{self.base_url}/v1/result/{task_id}"
+                result_response = await client.get(result_url, headers=headers)
+                result_response.raise_for_status()
+                return result_response.json()
+
+        except httpx.ConnectError as e:
+            raise ValueError(
+                f"Could not connect to docling-serve at {self.base_url}. "
+                f"Ensure the service is running and accessible. Error: {e}"
+            )
+        except httpx.TimeoutException as e:
+            raise ValueError(
+                f"Request to docling-serve timed out after {self.timeout}s. Error: {e}"
+            )
+        except httpx.HTTPStatusError as e:
+            if e.response.status_code == 401:
+                raise ValueError(
+                    "Authentication failed. Check your API key configuration."
+                )
+            raise ValueError(f"HTTP error from docling-serve: {e}")
+        except ValueError:
+            raise
+        except Exception as e:
+            raise ValueError(f"Failed to process via docling-serve: {e}")
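
For reference, the submit → poll → fetch flow above could be driven like this. The endpoint path comes from the docstring; the multipart field name, the to_formats form parameter, the port, and the shape of the returned JSON are assumptions about the docling-serve API rather than something this diff confirms:

```python
import asyncio

from haiku.rag.providers import DoclingServeClient


async def convert_pdf() -> dict:
    client = DoclingServeClient("http://localhost:5001", timeout=300)
    with open("report.pdf", "rb") as fh:  # hypothetical input file
        return await client.submit_and_poll(
            endpoint="/v1/convert/file/async",
            files={"files": ("report.pdf", fh.read(), "application/pdf")},
            data={"to_formats": "md"},  # assumed docling-serve form field
            name="report.pdf",
        )


result = asyncio.run(convert_pdf())
print(result)  # raw payload from /v1/result/{task_id}
```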
haiku/rag/qa/__init__.py CHANGED
@@ -1,33 +1,35 @@
 from haiku.rag.client import HaikuRAG
 from haiku.rag.config import AppConfig, Config
 from haiku.rag.qa.agent import QuestionAnswerAgent
+from haiku.rag.qa.prompts import QA_SYSTEM_PROMPT
+from haiku.rag.utils import build_prompt
 
 
 def get_qa_agent(
     client: HaikuRAG,
     config: AppConfig = Config,
-    use_citations: bool = False,
     system_prompt: str | None = None,
 ) -> QuestionAnswerAgent:
-    """
-    Factory function to get a QA agent based on the configuration.
+    """Factory function to get a QA agent based on the configuration.
 
     Args:
         client: HaikuRAG client instance.
         config: Configuration to use. Defaults to global Config.
-        use_citations: Whether to include citations in responses.
-        system_prompt: Optional custom system prompt.
+        system_prompt: Optional custom system prompt (overrides config).
 
     Returns:
         A configured QuestionAnswerAgent instance.
     """
-    provider = config.qa.provider
-    model_name = config.qa.model
+    # Determine the base prompt: explicit > config > default
+    if system_prompt is None:
+        system_prompt = config.prompts.qa or QA_SYSTEM_PROMPT
+
+    # Prepend system_context if configured
+    system_prompt = build_prompt(system_prompt, config)
 
     return QuestionAnswerAgent(
         client=client,
-        provider=provider,
-        model=model_name,
-        use_citations=use_citations,
+        model_config=config.qa.model,
+        config=config,
         system_prompt=system_prompt,
     )
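
The prompt now resolves in a fixed order: an explicit system_prompt argument wins, then config.prompts.qa, then the built-in QA_SYSTEM_PROMPT, with build_prompt prepending any configured system context. A sketch of the two call styles (the prompt text is illustrative):

```python
from haiku.rag.client import HaikuRAG
from haiku.rag.config import Config
from haiku.rag.qa import get_qa_agent


def make_agents(rag: HaikuRAG):
    # Falls back to config.prompts.qa, then QA_SYSTEM_PROMPT.
    default_agent = get_qa_agent(rag, config=Config)

    # An explicit prompt overrides both config and default.
    strict_agent = get_qa_agent(
        rag,
        config=Config,
        system_prompt="Answer strictly from the retrieved context.",
    )
    return default_agent, strict_agent
```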
haiku/rag/qa/agent.py CHANGED
@@ -1,49 +1,38 @@
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
 from pydantic_ai import Agent, RunContext
-from pydantic_ai.models.openai import OpenAIChatModel
-from pydantic_ai.providers.ollama import OllamaProvider
-from pydantic_ai.providers.openai import OpenAIProvider
+from pydantic_ai.output import ToolOutput
 
 from haiku.rag.client import HaikuRAG
-from haiku.rag.config import Config
-from haiku.rag.qa.prompts import QA_SYSTEM_PROMPT, QA_SYSTEM_PROMPT_WITH_CITATIONS
-
-
-class SearchResult(BaseModel):
-    content: str = Field(description="The document text content")
-    score: float = Field(description="Relevance score (higher is more relevant)")
-    document_uri: str = Field(
-        description="Source title (if available) or URI/path of the document"
-    )
+from haiku.rag.config.models import AppConfig, ModelConfig
+from haiku.rag.graph.research.models import Citation, RawSearchAnswer, resolve_citations
+from haiku.rag.qa.prompts import QA_SYSTEM_PROMPT
+from haiku.rag.store.models import SearchResult
+from haiku.rag.utils import get_model
 
 
 class Dependencies(BaseModel):
     model_config = {"arbitrary_types_allowed": True}
     client: HaikuRAG
+    search_results: list[SearchResult] = []
+    search_filter: str | None = None
 
 
 class QuestionAnswerAgent:
     def __init__(
         self,
         client: HaikuRAG,
-        provider: str,
-        model: str,
-        use_citations: bool = False,
-        q: float = 0.0,
+        model_config: ModelConfig,
+        config: AppConfig | None = None,
         system_prompt: str | None = None,
     ):
         self._client = client
-
-        if system_prompt is None:
-            system_prompt = (
-                QA_SYSTEM_PROMPT_WITH_CITATIONS if use_citations else QA_SYSTEM_PROMPT
-            )
-        model_obj = self._get_model(provider, model)
+        model_obj = get_model(model_config, config)
 
         self._agent = Agent(
             model=model_obj,
             deps_type=Dependencies,
-            system_prompt=system_prompt,
+            output_type=ToolOutput(RawSearchAnswer, max_retries=3),
+            instructions=system_prompt or QA_SYSTEM_PROMPT,
             retries=3,
         )
@@ -51,43 +40,36 @@ class QuestionAnswerAgent:
         async def search_documents(
             ctx: RunContext[Dependencies],
             query: str,
-            limit: int = 3,
-        ) -> list[SearchResult]:
-            """Search the knowledge base for relevant documents."""
-            search_results = await ctx.deps.client.search(query, limit=limit)
-            expanded_results = await ctx.deps.client.expand_context(search_results)
+            limit: int | None = None,
+        ) -> str:
+            """Search the knowledge base for relevant documents.
 
-            return [
-                SearchResult(
-                    content=chunk.content,
-                    score=score,
-                    document_uri=(chunk.document_title or chunk.document_uri or ""),
-                )
-                for chunk, score in expanded_results
-            ]
-
-    def _get_model(self, provider: str, model: str):
-        """Get the appropriate model object for the provider."""
-        if provider == "ollama":
-            return OpenAIChatModel(
-                model_name=model,
-                provider=OllamaProvider(
-                    base_url=f"{Config.providers.ollama.base_url}/v1"
-                ),
-            )
-        elif provider == "vllm":
-            return OpenAIChatModel(
-                model_name=model,
-                provider=OpenAIProvider(
-                    base_url=f"{Config.providers.vllm.qa_base_url}/v1", api_key="none"
-                ),
+            Returns results with chunk IDs and relevance scores.
+            Reference results by their chunk_id in cited_chunks.
+            """
+            results = await ctx.deps.client.search(
+                query, limit=limit, filter=ctx.deps.search_filter
             )
-        else:
-            # For all other providers, use the provider:model format
-            return f"{provider}:{model}"
+            results = await ctx.deps.client.expand_context(results)
+            # Store results for citation resolution
+            ctx.deps.search_results = results
+            # Format with metadata for agent context
+            parts = [r.format_for_agent() for r in results]
+            return "\n\n".join(parts) if parts else "No results found."
+
+    async def answer(
+        self, question: str, filter: str | None = None
+    ) -> tuple[str, list[Citation]]:
+        """Answer a question using the RAG system.
+
+        Args:
+            question: The question to answer
+            filter: SQL WHERE clause to filter documents
 
-    async def answer(self, question: str) -> str:
-        """Answer a question using the RAG system."""
-        deps = Dependencies(client=self._client)
+        Returns:
+            Tuple of (answer text, list of resolved citations)
+        """
+        deps = Dependencies(client=self._client, search_filter=filter)
         result = await self._agent.run(question, deps=deps)
-        return result.output
+        citations = resolve_citations(result.output.cited_chunks, deps.search_results)
+        return result.output.answer, citations
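
Callers of answer() now receive resolved citations alongside the text instead of opting into a citation-flavored prompt up front. A sketch of the new call shape (the database path and filter expression are illustrative):

```python
import asyncio

from haiku.rag.client import HaikuRAG
from haiku.rag.config import Config
from haiku.rag.qa import get_qa_agent


async def main() -> None:
    async with HaikuRAG("data/haiku.db", config=Config) as rag:
        agent = get_qa_agent(rag, config=Config)
        # answer() returns (answer text, list[Citation]) instead of a string.
        answer, citations = await agent.answer(
            "What does the chunking pipeline do?",
            filter="uri LIKE 'docs/%'",  # hypothetical SQL WHERE clause
        )
        print(answer)
        for citation in citations:
            print("-", citation)


asyncio.run(main())
```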