haiku.rag 0.10.1__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of haiku.rag has been flagged as potentially problematic; consult the registry's advisory page for details.
- haiku/rag/app.py +152 -28
- haiku/rag/cli.py +72 -2
- haiku/rag/migration.py +2 -2
- haiku/rag/research/__init__.py +8 -0
- haiku/rag/research/common.py +71 -6
- haiku/rag/research/dependencies.py +179 -11
- haiku/rag/research/graph.py +5 -3
- haiku/rag/research/models.py +134 -1
- haiku/rag/research/nodes/analysis.py +181 -0
- haiku/rag/research/nodes/plan.py +16 -9
- haiku/rag/research/nodes/search.py +14 -11
- haiku/rag/research/nodes/synthesize.py +7 -3
- haiku/rag/research/prompts.py +67 -28
- haiku/rag/research/state.py +11 -4
- haiku/rag/research/stream.py +177 -0
- haiku/rag/store/__init__.py +1 -1
- haiku/rag/store/models/__init__.py +1 -1
- haiku/rag/utils.py +34 -0
- {haiku_rag-0.10.1.dist-info → haiku_rag-0.11.0.dist-info}/METADATA +34 -14
- {haiku_rag-0.10.1.dist-info → haiku_rag-0.11.0.dist-info}/RECORD +23 -22
- haiku/rag/research/nodes/evaluate.py +0 -80
- {haiku_rag-0.10.1.dist-info → haiku_rag-0.11.0.dist-info}/WHEEL +0 -0
- {haiku_rag-0.10.1.dist-info → haiku_rag-0.11.0.dist-info}/entry_points.txt +0 -0
- {haiku_rag-0.10.1.dist-info → haiku_rag-0.11.0.dist-info}/licenses/LICENSE +0 -0
haiku/rag/app.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
from importlib.metadata import version as pkg_version
|
|
2
4
|
from pathlib import Path
|
|
3
5
|
|
|
4
6
|
from rich.console import Console
|
|
@@ -16,6 +18,7 @@ from haiku.rag.research.graph import (
|
|
|
16
18
|
ResearchState,
|
|
17
19
|
build_research_graph,
|
|
18
20
|
)
|
|
21
|
+
from haiku.rag.research.stream import stream_research_graph
|
|
19
22
|
from haiku.rag.store.models.chunk import Chunk
|
|
20
23
|
from haiku.rag.store.models.document import Document
|
|
21
24
|
|
|
@@ -25,26 +28,141 @@ class HaikuRAGApp:
|
|
|
25
28
|
self.db_path = db_path
|
|
26
29
|
self.console = Console()
|
|
27
30
|
|
|
31
|
+
async def info(self):
|
|
32
|
+
"""Display read-only information about the database without modifying it."""
|
|
33
|
+
|
|
34
|
+
import lancedb
|
|
35
|
+
|
|
36
|
+
# Basic: show path
|
|
37
|
+
self.console.print("[bold]haiku.rag database info[/bold]")
|
|
38
|
+
self.console.print(
|
|
39
|
+
f" [repr.attrib_name]path[/repr.attrib_name]: {self.db_path}"
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
if not self.db_path.exists():
|
|
43
|
+
self.console.print("[red]Database path does not exist.[/red]")
|
|
44
|
+
return
|
|
45
|
+
|
|
46
|
+
# Connect without going through Store to avoid upgrades/validation writes
|
|
47
|
+
try:
|
|
48
|
+
db = lancedb.connect(self.db_path)
|
|
49
|
+
table_names = set(db.table_names())
|
|
50
|
+
except Exception as e:
|
|
51
|
+
self.console.print(f"[red]Failed to open database: {e}[/red]")
|
|
52
|
+
return
|
|
53
|
+
|
|
54
|
+
try:
|
|
55
|
+
ldb_version = pkg_version("lancedb")
|
|
56
|
+
except Exception:
|
|
57
|
+
ldb_version = "unknown"
|
|
58
|
+
try:
|
|
59
|
+
hr_version = pkg_version("haiku.rag")
|
|
60
|
+
except Exception:
|
|
61
|
+
hr_version = "unknown"
|
|
62
|
+
try:
|
|
63
|
+
docling_version = pkg_version("docling")
|
|
64
|
+
except Exception:
|
|
65
|
+
docling_version = "unknown"
|
|
66
|
+
|
|
67
|
+
# Read settings (if present) to find stored haiku.rag version and embedding config
|
|
68
|
+
stored_version = "unknown"
|
|
69
|
+
embed_provider: str | None = None
|
|
70
|
+
embed_model: str | None = None
|
|
71
|
+
vector_dim: int | None = None
|
|
72
|
+
|
|
73
|
+
if "settings" in table_names:
|
|
74
|
+
settings_tbl = db.open_table("settings")
|
|
75
|
+
arrow = settings_tbl.search().where("id = 'settings'").limit(1).to_arrow()
|
|
76
|
+
rows = arrow.to_pylist() if arrow is not None else []
|
|
77
|
+
if rows:
|
|
78
|
+
raw = rows[0].get("settings") or "{}"
|
|
79
|
+
data = json.loads(raw) if isinstance(raw, str) else (raw or {})
|
|
80
|
+
stored_version = str(data.get("version", stored_version))
|
|
81
|
+
embed_provider = data.get("EMBEDDINGS_PROVIDER")
|
|
82
|
+
embed_model = data.get("EMBEDDINGS_MODEL")
|
|
83
|
+
vector_dim = (
|
|
84
|
+
int(data.get("EMBEDDINGS_VECTOR_DIM")) # pyright: ignore[reportArgumentType]
|
|
85
|
+
if data.get("EMBEDDINGS_VECTOR_DIM") is not None
|
|
86
|
+
else None
|
|
87
|
+
)
|
|
88
|
+
|
|
89
|
+
num_docs = 0
|
|
90
|
+
if "documents" in table_names:
|
|
91
|
+
docs_tbl = db.open_table("documents")
|
|
92
|
+
num_docs = int(docs_tbl.count_rows()) # type: ignore[attr-defined]
|
|
93
|
+
|
|
94
|
+
# Table versions per table (direct API)
|
|
95
|
+
doc_versions = (
|
|
96
|
+
len(list(db.open_table("documents").list_versions()))
|
|
97
|
+
if "documents" in table_names
|
|
98
|
+
else 0
|
|
99
|
+
)
|
|
100
|
+
chunk_versions = (
|
|
101
|
+
len(list(db.open_table("chunks").list_versions()))
|
|
102
|
+
if "chunks" in table_names
|
|
103
|
+
else 0
|
|
104
|
+
)
|
|
105
|
+
|
|
106
|
+
self.console.print(
|
|
107
|
+
f" [repr.attrib_name]haiku.rag version (db)[/repr.attrib_name]: {stored_version}"
|
|
108
|
+
)
|
|
109
|
+
if embed_provider or embed_model or vector_dim:
|
|
110
|
+
provider_part = embed_provider or "unknown"
|
|
111
|
+
model_part = embed_model or "unknown"
|
|
112
|
+
dim_part = f"{vector_dim}" if vector_dim is not None else "unknown"
|
|
113
|
+
self.console.print(
|
|
114
|
+
" [repr.attrib_name]embeddings[/repr.attrib_name]: "
|
|
115
|
+
f"{provider_part}/{model_part} (dim: {dim_part})"
|
|
116
|
+
)
|
|
117
|
+
else:
|
|
118
|
+
self.console.print(
|
|
119
|
+
" [repr.attrib_name]embeddings[/repr.attrib_name]: unknown"
|
|
120
|
+
)
|
|
121
|
+
self.console.print(
|
|
122
|
+
f" [repr.attrib_name]documents[/repr.attrib_name]: {num_docs}"
|
|
123
|
+
)
|
|
124
|
+
self.console.print(
|
|
125
|
+
f" [repr.attrib_name]versions (documents)[/repr.attrib_name]: {doc_versions}"
|
|
126
|
+
)
|
|
127
|
+
self.console.print(
|
|
128
|
+
f" [repr.attrib_name]versions (chunks)[/repr.attrib_name]: {chunk_versions}"
|
|
129
|
+
)
|
|
130
|
+
self.console.rule()
|
|
131
|
+
self.console.print("[bold]Versions[/bold]")
|
|
132
|
+
self.console.print(
|
|
133
|
+
f" [repr.attrib_name]haiku.rag[/repr.attrib_name]: {hr_version}"
|
|
134
|
+
)
|
|
135
|
+
self.console.print(
|
|
136
|
+
f" [repr.attrib_name]lancedb[/repr.attrib_name]: {ldb_version}"
|
|
137
|
+
)
|
|
138
|
+
self.console.print(
|
|
139
|
+
f" [repr.attrib_name]docling[/repr.attrib_name]: {docling_version}"
|
|
140
|
+
)
|
|
141
|
+
|
|
28
142
|
async def list_documents(self):
|
|
29
143
|
async with HaikuRAG(db_path=self.db_path) as self.client:
|
|
30
144
|
documents = await self.client.list_documents()
|
|
31
145
|
for doc in documents:
|
|
32
146
|
self._rich_print_document(doc, truncate=True)
|
|
33
147
|
|
|
34
|
-
async def add_document_from_text(self, text: str):
|
|
148
|
+
async def add_document_from_text(self, text: str, metadata: dict | None = None):
|
|
35
149
|
async with HaikuRAG(db_path=self.db_path) as self.client:
|
|
36
|
-
doc = await self.client.create_document(text)
|
|
150
|
+
doc = await self.client.create_document(text, metadata=metadata)
|
|
37
151
|
self._rich_print_document(doc, truncate=True)
|
|
38
152
|
self.console.print(
|
|
39
|
-
f"[
|
|
153
|
+
f"[bold green]Document {doc.id} added successfully.[/bold green]"
|
|
40
154
|
)
|
|
41
155
|
|
|
42
|
-
async def add_document_from_source(
|
|
156
|
+
async def add_document_from_source(
|
|
157
|
+
self, source: str, title: str | None = None, metadata: dict | None = None
|
|
158
|
+
):
|
|
43
159
|
async with HaikuRAG(db_path=self.db_path) as self.client:
|
|
44
|
-
doc = await self.client.create_document_from_source(
|
|
160
|
+
doc = await self.client.create_document_from_source(
|
|
161
|
+
source, title=title, metadata=metadata
|
|
162
|
+
)
|
|
45
163
|
self._rich_print_document(doc, truncate=True)
|
|
46
164
|
self.console.print(
|
|
47
|
-
f"[
|
|
165
|
+
f"[bold green]Document {doc.id} added successfully.[/bold green]"
|
|
48
166
|
)
|
|
49
167
|
|
|
50
168
|
async def get_document(self, doc_id: str):
|
|
@@ -59,7 +177,9 @@ class HaikuRAGApp:
|
|
|
59
177
|
async with HaikuRAG(db_path=self.db_path) as self.client:
|
|
60
178
|
deleted = await self.client.delete_document(doc_id)
|
|
61
179
|
if deleted:
|
|
62
|
-
self.console.print(
|
|
180
|
+
self.console.print(
|
|
181
|
+
f"[bold green]Document {doc_id} deleted successfully.[/bold green]"
|
|
182
|
+
)
|
|
63
183
|
else:
|
|
64
184
|
self.console.print(
|
|
65
185
|
f"[yellow]Document with id {doc_id} not found.[/yellow]"
|
|
@@ -69,7 +189,7 @@ class HaikuRAGApp:
|
|
|
69
189
|
async with HaikuRAG(db_path=self.db_path) as self.client:
|
|
70
190
|
results = await self.client.search(query, limit=limit)
|
|
71
191
|
if not results:
|
|
72
|
-
self.console.print("[
|
|
192
|
+
self.console.print("[yellow]No results found.[/yellow]")
|
|
73
193
|
return
|
|
74
194
|
for chunk, score in results:
|
|
75
195
|
self._rich_print_search_result(chunk, score)
|
|
@@ -102,9 +222,9 @@ class HaikuRAGApp:
|
|
|
102
222
|
self.console.print()
|
|
103
223
|
|
|
104
224
|
graph = build_research_graph()
|
|
225
|
+
context = ResearchContext(original_question=question)
|
|
105
226
|
state = ResearchState(
|
|
106
|
-
|
|
107
|
-
context=ResearchContext(original_question=question),
|
|
227
|
+
context=context,
|
|
108
228
|
max_iterations=max_iterations,
|
|
109
229
|
confidence_threshold=confidence_threshold,
|
|
110
230
|
max_concurrency=max_concurrency,
|
|
@@ -117,22 +237,20 @@ class HaikuRAGApp:
|
|
|
117
237
|
provider=Config.RESEARCH_PROVIDER or Config.QA_PROVIDER,
|
|
118
238
|
model=Config.RESEARCH_MODEL or Config.QA_MODEL,
|
|
119
239
|
)
|
|
120
|
-
# Prefer graph.run; fall back to iter if unavailable
|
|
121
240
|
report = None
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
if run.result:
|
|
133
|
-
report = run.result.output
|
|
241
|
+
async for event in stream_research_graph(graph, start, state, deps):
|
|
242
|
+
if event.type == "report":
|
|
243
|
+
report = event.report
|
|
244
|
+
break
|
|
245
|
+
if event.type == "error":
|
|
246
|
+
self.console.print(
|
|
247
|
+
f"[red]Error during research: {event.message}[/red]"
|
|
248
|
+
)
|
|
249
|
+
return
|
|
250
|
+
|
|
134
251
|
if report is None:
|
|
135
|
-
|
|
252
|
+
self.console.print("[red]Research did not produce a report.[/red]")
|
|
253
|
+
return
|
|
136
254
|
|
|
137
255
|
# Display the report
|
|
138
256
|
self.console.print("[bold green]Research Report[/bold green]")
|
|
@@ -202,14 +320,16 @@ class HaikuRAGApp:
|
|
|
202
320
|
return
|
|
203
321
|
|
|
204
322
|
self.console.print(
|
|
205
|
-
f"[
|
|
323
|
+
f"[bold cyan]Rebuilding database with {total_docs} documents...[/bold cyan]"
|
|
206
324
|
)
|
|
207
325
|
with Progress() as progress:
|
|
208
326
|
task = progress.add_task("Rebuilding...", total=total_docs)
|
|
209
327
|
async for _ in client.rebuild_database():
|
|
210
328
|
progress.update(task, advance=1)
|
|
211
329
|
|
|
212
|
-
self.console.print(
|
|
330
|
+
self.console.print(
|
|
331
|
+
"[bold green]Database rebuild completed successfully.[/bold green]"
|
|
332
|
+
)
|
|
213
333
|
except Exception as e:
|
|
214
334
|
self.console.print(f"[red]Error rebuilding database: {e}[/red]")
|
|
215
335
|
|
|
@@ -218,7 +338,9 @@ class HaikuRAGApp:
|
|
|
218
338
|
try:
|
|
219
339
|
async with HaikuRAG(db_path=self.db_path, skip_validation=True) as client:
|
|
220
340
|
await client.vacuum()
|
|
221
|
-
self.console.print(
|
|
341
|
+
self.console.print(
|
|
342
|
+
"[bold green]Vacuum completed successfully.[/bold green]"
|
|
343
|
+
)
|
|
222
344
|
except Exception as e:
|
|
223
345
|
self.console.print(f"[red]Error during vacuum: {e}[/red]")
|
|
224
346
|
|
|
@@ -240,7 +362,9 @@ class HaikuRAGApp:
|
|
|
240
362
|
else:
|
|
241
363
|
display_value = field_value
|
|
242
364
|
|
|
243
|
-
self.console.print(
|
|
365
|
+
self.console.print(
|
|
366
|
+
f" [repr.attrib_name]{field_name}[/repr.attrib_name]: {display_value}"
|
|
367
|
+
)
|
|
244
368
|
|
|
245
369
|
def _rich_print_document(self, doc: Document, truncate: bool = False):
|
|
246
370
|
"""Format a document for display."""
|
haiku/rag/cli.py
CHANGED
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
+
import json
|
|
2
3
|
import warnings
|
|
3
4
|
from importlib.metadata import version
|
|
4
5
|
from pathlib import Path
|
|
6
|
+
from typing import Any
|
|
5
7
|
|
|
6
8
|
import typer
|
|
7
9
|
|
|
@@ -137,11 +139,41 @@ def list_documents(
|
|
|
137
139
|
asyncio.run(app.list_documents())
|
|
138
140
|
|
|
139
141
|
|
|
142
|
+
def _parse_meta_options(meta: list[str] | None) -> dict[str, Any]:
|
|
143
|
+
"""Parse repeated --meta KEY=VALUE options into a dictionary.
|
|
144
|
+
|
|
145
|
+
Raises a Typer error if any entry is malformed.
|
|
146
|
+
"""
|
|
147
|
+
result: dict[str, Any] = {}
|
|
148
|
+
if not meta:
|
|
149
|
+
return result
|
|
150
|
+
for item in meta:
|
|
151
|
+
if "=" not in item:
|
|
152
|
+
raise typer.BadParameter("--meta must be in KEY=VALUE format")
|
|
153
|
+
key, value = item.split("=", 1)
|
|
154
|
+
if not key:
|
|
155
|
+
raise typer.BadParameter("--meta key cannot be empty")
|
|
156
|
+
# Best-effort JSON coercion: numbers, booleans, null, arrays/objects
|
|
157
|
+
try:
|
|
158
|
+
parsed = json.loads(value)
|
|
159
|
+
result[key] = parsed
|
|
160
|
+
except Exception:
|
|
161
|
+
# Leave as string if not valid JSON literal
|
|
162
|
+
result[key] = value
|
|
163
|
+
return result
|
|
164
|
+
|
|
165
|
+
|
|
140
166
|
@cli.command("add", help="Add a document from text input")
|
|
141
167
|
def add_document_text(
|
|
142
168
|
text: str = typer.Argument(
|
|
143
169
|
help="The text content of the document to add",
|
|
144
170
|
),
|
|
171
|
+
meta: list[str] | None = typer.Option(
|
|
172
|
+
None,
|
|
173
|
+
"--meta",
|
|
174
|
+
help="Metadata entries as KEY=VALUE (repeatable)",
|
|
175
|
+
metavar="KEY=VALUE",
|
|
176
|
+
),
|
|
145
177
|
db: Path = typer.Option(
|
|
146
178
|
Config.DEFAULT_DATA_DIR / "haiku.rag.lancedb",
|
|
147
179
|
"--db",
|
|
@@ -151,7 +183,8 @@ def add_document_text(
|
|
|
151
183
|
from haiku.rag.app import HaikuRAGApp
|
|
152
184
|
|
|
153
185
|
app = HaikuRAGApp(db_path=db)
|
|
154
|
-
|
|
186
|
+
metadata = _parse_meta_options(meta)
|
|
187
|
+
asyncio.run(app.add_document_from_text(text=text, metadata=metadata or None))
|
|
155
188
|
|
|
156
189
|
|
|
157
190
|
@cli.command("add-src", help="Add a document from a file path or URL")
|
|
@@ -165,6 +198,12 @@ def add_document_src(
|
|
|
165
198
|
"--title",
|
|
166
199
|
help="Optional human-readable title to store with the document",
|
|
167
200
|
),
|
|
201
|
+
meta: list[str] | None = typer.Option(
|
|
202
|
+
None,
|
|
203
|
+
"--meta",
|
|
204
|
+
help="Metadata entries as KEY=VALUE (repeatable)",
|
|
205
|
+
metavar="KEY=VALUE",
|
|
206
|
+
),
|
|
168
207
|
db: Path = typer.Option(
|
|
169
208
|
Config.DEFAULT_DATA_DIR / "haiku.rag.lancedb",
|
|
170
209
|
"--db",
|
|
@@ -174,7 +213,12 @@ def add_document_src(
|
|
|
174
213
|
from haiku.rag.app import HaikuRAGApp
|
|
175
214
|
|
|
176
215
|
app = HaikuRAGApp(db_path=db)
|
|
177
|
-
|
|
216
|
+
metadata = _parse_meta_options(meta)
|
|
217
|
+
asyncio.run(
|
|
218
|
+
app.add_document_from_source(
|
|
219
|
+
source=source, title=title, metadata=metadata or None
|
|
220
|
+
)
|
|
221
|
+
)
|
|
178
222
|
|
|
179
223
|
|
|
180
224
|
@cli.command("get", help="Get and display a document by its ID")
|
|
@@ -347,6 +391,32 @@ def vacuum(
|
|
|
347
391
|
asyncio.run(app.vacuum())
|
|
348
392
|
|
|
349
393
|
|
|
394
|
+
@cli.command("info", help="Show read-only database info (no upgrades or writes)")
|
|
395
|
+
def info(
|
|
396
|
+
db: Path = typer.Option(
|
|
397
|
+
Config.DEFAULT_DATA_DIR / "haiku.rag.lancedb",
|
|
398
|
+
"--db",
|
|
399
|
+
help="Path to the LanceDB database file",
|
|
400
|
+
),
|
|
401
|
+
):
|
|
402
|
+
from haiku.rag.app import HaikuRAGApp
|
|
403
|
+
|
|
404
|
+
app = HaikuRAGApp(db_path=db)
|
|
405
|
+
asyncio.run(app.info())
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
@cli.command("download-models", help="Download Docling and Ollama models per config")
|
|
409
|
+
def download_models_cmd():
|
|
410
|
+
from haiku.rag.utils import prefetch_models
|
|
411
|
+
|
|
412
|
+
try:
|
|
413
|
+
prefetch_models()
|
|
414
|
+
typer.echo("Models downloaded successfully.")
|
|
415
|
+
except Exception as e:
|
|
416
|
+
typer.echo(f"Error downloading models: {e}")
|
|
417
|
+
raise typer.Exit(1)
|
|
418
|
+
|
|
419
|
+
|
|
350
420
|
@cli.command(
|
|
351
421
|
"serve", help="Start the haiku.rag MCP server (by default in streamable HTTP mode)"
|
|
352
422
|
)
|
haiku/rag/migration.py
CHANGED
|
@@ -51,7 +51,7 @@ class SQLiteToLanceDBMigrator:
|
|
|
51
51
|
|
|
52
52
|
sqlite_conn.enable_load_extension(True)
|
|
53
53
|
sqlite_vec.load(sqlite_conn)
|
|
54
|
-
self.console.print("[
|
|
54
|
+
self.console.print("[cyan]Loaded sqlite-vec extension[/cyan]")
|
|
55
55
|
except Exception as e:
|
|
56
56
|
self.console.print(
|
|
57
57
|
f"[yellow]Warning: Could not load sqlite-vec extension: {e}[/yellow]"
|
|
@@ -92,7 +92,7 @@ class SQLiteToLanceDBMigrator:
|
|
|
92
92
|
sqlite_conn.close()
|
|
93
93
|
|
|
94
94
|
# Optimize and cleanup using centralized vacuum
|
|
95
|
-
self.console.print("[
|
|
95
|
+
self.console.print("[cyan]Optimizing LanceDB...[/cyan]")
|
|
96
96
|
try:
|
|
97
97
|
lance_store.vacuum()
|
|
98
98
|
self.console.print("[green]✅ Optimization completed[/green]")
|
haiku/rag/research/__init__.py
CHANGED
|
@@ -6,6 +6,11 @@ from haiku.rag.research.graph import (
|
|
|
6
6
|
build_research_graph,
|
|
7
7
|
)
|
|
8
8
|
from haiku.rag.research.models import EvaluationResult, ResearchReport, SearchAnswer
|
|
9
|
+
from haiku.rag.research.stream import (
|
|
10
|
+
ResearchStateSnapshot,
|
|
11
|
+
ResearchStreamEvent,
|
|
12
|
+
stream_research_graph,
|
|
13
|
+
)
|
|
9
14
|
|
|
10
15
|
__all__ = [
|
|
11
16
|
"ResearchDependencies",
|
|
@@ -17,4 +22,7 @@ __all__ = [
|
|
|
17
22
|
"ResearchState",
|
|
18
23
|
"PlanNode",
|
|
19
24
|
"build_research_graph",
|
|
25
|
+
"stream_research_graph",
|
|
26
|
+
"ResearchStreamEvent",
|
|
27
|
+
"ResearchStateSnapshot",
|
|
20
28
|
]
|
haiku/rag/research/common.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Any
|
|
1
|
+
from typing import TYPE_CHECKING, Any
|
|
2
2
|
|
|
3
3
|
from pydantic_ai import format_as_xml
|
|
4
4
|
from pydantic_ai.models.openai import OpenAIChatModel
|
|
@@ -7,6 +7,10 @@ from pydantic_ai.providers.openai import OpenAIProvider
|
|
|
7
7
|
|
|
8
8
|
from haiku.rag.config import Config
|
|
9
9
|
from haiku.rag.research.dependencies import ResearchContext
|
|
10
|
+
from haiku.rag.research.models import InsightAnalysis
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING: # pragma: no cover
|
|
13
|
+
from haiku.rag.research.state import ResearchDeps, ResearchState
|
|
10
14
|
|
|
11
15
|
|
|
12
16
|
def get_model(provider: str, model: str) -> Any:
|
|
@@ -27,9 +31,8 @@ def get_model(provider: str, model: str) -> Any:
|
|
|
27
31
|
return f"{provider}:{model}"
|
|
28
32
|
|
|
29
33
|
|
|
30
|
-
def log(
|
|
31
|
-
|
|
32
|
-
console.print(msg)
|
|
34
|
+
def log(deps: "ResearchDeps", state: "ResearchState", msg: str) -> None:
|
|
35
|
+
deps.emit_log(msg, state)
|
|
33
36
|
|
|
34
37
|
|
|
35
38
|
def format_context_for_prompt(context: ResearchContext) -> str:
|
|
@@ -47,7 +50,69 @@ def format_context_for_prompt(context: ResearchContext) -> str:
|
|
|
47
50
|
}
|
|
48
51
|
for qa in context.qa_responses
|
|
49
52
|
],
|
|
50
|
-
"insights":
|
|
51
|
-
|
|
53
|
+
"insights": [
|
|
54
|
+
{
|
|
55
|
+
"id": insight.id,
|
|
56
|
+
"summary": insight.summary,
|
|
57
|
+
"status": insight.status.value,
|
|
58
|
+
"supporting_sources": insight.supporting_sources,
|
|
59
|
+
"originating_questions": insight.originating_questions,
|
|
60
|
+
"notes": insight.notes,
|
|
61
|
+
}
|
|
62
|
+
for insight in context.insights
|
|
63
|
+
],
|
|
64
|
+
"gaps": [
|
|
65
|
+
{
|
|
66
|
+
"id": gap.id,
|
|
67
|
+
"description": gap.description,
|
|
68
|
+
"severity": gap.severity.value,
|
|
69
|
+
"blocking": gap.blocking,
|
|
70
|
+
"resolved": gap.resolved,
|
|
71
|
+
"resolved_by": gap.resolved_by,
|
|
72
|
+
"supporting_sources": gap.supporting_sources,
|
|
73
|
+
"notes": gap.notes,
|
|
74
|
+
}
|
|
75
|
+
for gap in context.gaps
|
|
76
|
+
],
|
|
52
77
|
}
|
|
53
78
|
return format_as_xml(context_data, root_tag="research_context")
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def format_analysis_for_prompt(
|
|
82
|
+
analysis: InsightAnalysis | None,
|
|
83
|
+
) -> str:
|
|
84
|
+
"""Format the latest insight analysis as XML for prompts."""
|
|
85
|
+
|
|
86
|
+
if analysis is None:
|
|
87
|
+
return "<latest_analysis />"
|
|
88
|
+
|
|
89
|
+
data = {
|
|
90
|
+
"commentary": analysis.commentary,
|
|
91
|
+
"highlights": [
|
|
92
|
+
{
|
|
93
|
+
"id": insight.id,
|
|
94
|
+
"summary": insight.summary,
|
|
95
|
+
"status": insight.status.value,
|
|
96
|
+
"supporting_sources": insight.supporting_sources,
|
|
97
|
+
"originating_questions": insight.originating_questions,
|
|
98
|
+
"notes": insight.notes,
|
|
99
|
+
}
|
|
100
|
+
for insight in analysis.highlights
|
|
101
|
+
],
|
|
102
|
+
"gap_assessments": [
|
|
103
|
+
{
|
|
104
|
+
"id": gap.id,
|
|
105
|
+
"description": gap.description,
|
|
106
|
+
"severity": gap.severity.value,
|
|
107
|
+
"blocking": gap.blocking,
|
|
108
|
+
"resolved": gap.resolved,
|
|
109
|
+
"resolved_by": gap.resolved_by,
|
|
110
|
+
"supporting_sources": gap.supporting_sources,
|
|
111
|
+
"notes": gap.notes,
|
|
112
|
+
}
|
|
113
|
+
for gap in analysis.gap_assessments
|
|
114
|
+
],
|
|
115
|
+
"resolved_gaps": analysis.resolved_gaps,
|
|
116
|
+
"new_questions": analysis.new_questions,
|
|
117
|
+
}
|
|
118
|
+
return format_as_xml(data, root_tag="latest_analysis")
|