haiku.rag 0.5.5__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of haiku.rag might be problematic. Click here for more details.
- haiku/rag/app.py +4 -4
- haiku/rag/cli.py +38 -27
- haiku/rag/client.py +19 -23
- haiku/rag/config.py +6 -2
- haiku/rag/embeddings/__init__.py +3 -9
- haiku/rag/embeddings/openai.py +10 -13
- haiku/rag/logging.py +4 -0
- haiku/rag/mcp.py +12 -9
- haiku/rag/migration.py +316 -0
- haiku/rag/qa/__init__.py +10 -39
- haiku/rag/qa/agent.py +76 -0
- haiku/rag/qa/prompts.py +2 -0
- haiku/rag/reranking/__init__.py +0 -6
- haiku/rag/store/engine.py +173 -141
- haiku/rag/store/models/chunk.py +2 -2
- haiku/rag/store/models/document.py +1 -1
- haiku/rag/store/repositories/__init__.py +6 -2
- haiku/rag/store/repositories/chunk.py +279 -414
- haiku/rag/store/repositories/document.py +171 -205
- haiku/rag/store/repositories/settings.py +115 -49
- haiku/rag/store/upgrades/__init__.py +1 -3
- haiku/rag/utils.py +39 -31
- {haiku_rag-0.5.5.dist-info → haiku_rag-0.7.0.dist-info}/METADATA +22 -22
- haiku_rag-0.7.0.dist-info/RECORD +39 -0
- haiku/rag/qa/anthropic.py +0 -108
- haiku/rag/qa/base.py +0 -89
- haiku/rag/qa/ollama.py +0 -60
- haiku/rag/qa/openai.py +0 -97
- haiku/rag/reranking/ollama.py +0 -84
- haiku/rag/store/repositories/base.py +0 -40
- haiku/rag/store/upgrades/v0_3_4.py +0 -26
- haiku_rag-0.5.5.dist-info/RECORD +0 -44
- {haiku_rag-0.5.5.dist-info → haiku_rag-0.7.0.dist-info}/WHEEL +0 -0
- {haiku_rag-0.5.5.dist-info → haiku_rag-0.7.0.dist-info}/entry_points.txt +0 -0
- {haiku_rag-0.5.5.dist-info → haiku_rag-0.7.0.dist-info}/licenses/LICENSE +0 -0
haiku/rag/migration.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Migration script to migrate from SQLite to LanceDB.
|
|
4
|
+
|
|
5
|
+
This script will:
|
|
6
|
+
1. Read data from an existing SQLite database
|
|
7
|
+
2. Create a new LanceDB database with the same data
|
|
8
|
+
3. Preserve all documents, chunks, embeddings, and settings
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import json
|
|
12
|
+
import sqlite3
|
|
13
|
+
import struct
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from uuid import uuid4
|
|
16
|
+
|
|
17
|
+
from rich.console import Console
|
|
18
|
+
from rich.progress import Progress, TaskID
|
|
19
|
+
|
|
20
|
+
from haiku.rag.store.engine import Store
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def deserialize_sqlite_embedding(data: bytes) -> list[float]:
|
|
24
|
+
"""Deserialize sqlite-vec embedding from bytes."""
|
|
25
|
+
if not data:
|
|
26
|
+
return []
|
|
27
|
+
# sqlite-vec stores embeddings as float32 arrays
|
|
28
|
+
num_floats = len(data) // 4
|
|
29
|
+
return list(struct.unpack(f"{num_floats}f", data))
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class SQLiteToLanceDBMigrator:
    """Copy documents, chunks, embeddings and settings from SQLite to LanceDB.

    Integer primary keys from the SQLite schema are replaced by freshly
    generated UUID strings; chunk rows are re-linked to their documents via
    the old-id -> new-UUID mapping built while migrating documents.
    """

    def __init__(self, sqlite_path: Path, lancedb_path: Path):
        """Remember source/destination paths and create a console for output.

        Args:
            sqlite_path: Location of the existing SQLite database file.
            lancedb_path: Location where the LanceDB store is opened/created.
        """
        self.sqlite_path = sqlite_path
        self.lancedb_path = lancedb_path
        self.console = Console()

    def migrate(self) -> bool:
        """Run the full migration.

        Returns:
            True when every step completed, False when the SQLite file is
            missing or any step raised (the error and traceback are printed
            to the console rather than propagated).
        """
        sqlite_conn = None
        lance_store = None
        try:
            self.console.print(
                f"[blue]Starting migration from {self.sqlite_path} to {self.lancedb_path}[/blue]"
            )

            # Check if SQLite database exists
            if not self.sqlite_path.exists():
                self.console.print(
                    f"[red]SQLite database not found: {self.sqlite_path}[/red]"
                )
                return False

            # Connect to SQLite database; Row factory allows access by column name
            sqlite_conn = sqlite3.connect(self.sqlite_path)
            sqlite_conn.row_factory = sqlite3.Row

            # Create LanceDB store
            lance_store = Store(self.lancedb_path, skip_validation=True)

            with Progress() as progress:
                # Migrate documents
                doc_task = progress.add_task(
                    "[green]Migrating documents...", total=None
                )
                document_id_mapping = self._migrate_documents(
                    sqlite_conn, lance_store, progress, doc_task
                )

                # Migrate chunks and embeddings
                chunk_task = progress.add_task(
                    "[yellow]Migrating chunks and embeddings...", total=None
                )
                self._migrate_chunks(
                    sqlite_conn, lance_store, progress, chunk_task, document_id_mapping
                )

                # Migrate settings
                settings_task = progress.add_task(
                    "[blue]Migrating settings...", total=None
                )
                self._migrate_settings(
                    sqlite_conn, lance_store, progress, settings_task
                )

            # All reads are done; release the SQLite handle before optimizing.
            conn, sqlite_conn = sqlite_conn, None
            conn.close()

            # Optimize the chunks table after migration (best effort)
            self.console.print("[blue]Optimizing LanceDB...[/blue]")
            try:
                lance_store.chunks_table.optimize()
                self.console.print("[green]✅ Optimization completed[/green]")
            except Exception as e:
                self.console.print(
                    f"[yellow]Warning: Optimization failed: {e}[/yellow]"
                )

            # Clear the reference first so the ``finally`` clause never closes
            # the store a second time, even if close() itself raises.
            store, lance_store = lance_store, None
            store.close()

            self.console.print("[green]✅ Migration completed successfully![/green]")
            self.console.print(
                f"[green]✅ Migrated {len(document_id_mapping)} documents[/green]"
            )
            return True

        except Exception as e:
            self.console.print(f"[red]❌ Migration failed: {e}[/red]")
            import traceback

            self.console.print(f"[red]{traceback.format_exc()}[/red]")
            return False

        finally:
            # Failure path: make sure neither handle is leaked.
            if sqlite_conn is not None:
                sqlite_conn.close()
            if lance_store is not None:
                try:
                    lance_store.close()
                except Exception as e:
                    # Keep the "never raises, returns bool" contract.
                    self.console.print(
                        f"[yellow]Warning: Failed to close LanceDB store: {e}[/yellow]"
                    )

    @staticmethod
    def _normalize_metadata(raw: str | None) -> str:
        """Return *raw* re-serialized as JSON, or ``"{}"`` when it is empty."""
        return json.dumps(json.loads(raw) if raw else {})

    def _migrate_documents(
        self,
        sqlite_conn: sqlite3.Connection,
        lance_store: Store,
        progress: Progress,
        task: TaskID,
    ) -> dict[int, str]:
        """Migrate documents from SQLite to LanceDB and return ID mapping.

        Args:
            sqlite_conn: Open connection whose rows support access by name.
            lance_store: Destination store; records go to ``documents_table``.
            progress: Progress display to update once the batch is written.
            task: Progress task representing the document step.

        Returns:
            Mapping of old integer document id to the new UUID string.
        """
        cursor = sqlite_conn.cursor()
        cursor.execute(
            "SELECT id, content, uri, metadata, created_at, updated_at FROM documents ORDER BY id"
        )
        rows = cursor.fetchall()

        id_mapping: dict[int, str] = {}  # Maps old integer ID to new UUID

        if rows:
            # Imported lazily so an empty database never needs the record type.
            from haiku.rag.store.engine import DocumentRecord

            doc_records = []
            for row in rows:
                new_uuid = str(uuid4())
                id_mapping[row["id"]] = new_uuid
                doc_records.append(
                    DocumentRecord(
                        id=new_uuid,
                        content=row["content"],
                        uri=row["uri"],
                        metadata=self._normalize_metadata(row["metadata"]),
                        created_at=row["created_at"],
                        updated_at=row["updated_at"],
                    )
                )

            # Batch insert documents to LanceDB
            lance_store.documents_table.add(doc_records)

        progress.update(task, completed=len(rows), total=len(rows))
        return id_mapping

    def _load_embedding_blobs(self, cursor: sqlite3.Cursor) -> dict[int, bytes]:
        """Read raw embedding blobs keyed by old chunk id.

        Queries the vec0 shadow tables directly. If they cannot be queried
        (``sqlite3.OperationalError``), a warning is printed and whatever was
        collected so far (normally nothing) is returned.
        """
        blobs: dict[int, bytes] = {}
        try:
            # Try to get embeddings from the vec0 tables directly
            cursor.execute("""
                SELECT
                    r.chunk_id,
                    v.vectors
                FROM chunk_embeddings_rowids r
                JOIN chunk_embeddings_vector_chunks00 v ON r.rowid = v.rowid
            """)

            for row in cursor.fetchall():
                chunk_id, vectors_blob = row[0], row[1]
                # First non-empty blob per chunk wins.
                if vectors_blob and chunk_id not in blobs:
                    blobs[chunk_id] = vectors_blob

        except sqlite3.OperationalError as e:
            self.console.print(
                f"[yellow]Warning: Could not extract embeddings: {e}[/yellow]"
            )
            self.console.print(
                "[yellow]Continuing migration without embeddings...[/yellow]"
            )
        return blobs

    def _resolve_embedding(
        self, chunk_id: int, blob: bytes | None, lance_store: Store
    ) -> list[float]:
        """Decode *blob*, falling back to a zero vector of the store's dimension.

        The fallback is used both when no blob exists for the chunk and when
        decoding fails (a warning is printed in the latter case).
        """
        if blob:
            try:
                return deserialize_sqlite_embedding(blob)
            except Exception as e:
                self.console.print(
                    f"[yellow]Warning: Failed to deserialize embedding for chunk {chunk_id}: {e}[/yellow]"
                )
        # No usable embedding: generate a zero vector of the expected dimension
        return [0.0] * lance_store.embedder._vector_dim

    def _migrate_chunks(
        self,
        sqlite_conn: sqlite3.Connection,
        lance_store: Store,
        progress: Progress,
        task: TaskID,
        document_id_mapping: dict[int, str],
    ) -> None:
        """Migrate chunks and embeddings from SQLite to LanceDB.

        Chunks whose ``document_id`` is absent from *document_id_mapping* are
        skipped with a warning. Every migrated chunk receives a new UUID.

        Args:
            sqlite_conn: Open connection whose rows support access by name.
            lance_store: Destination store; records go to ``chunks_table``.
            progress: Progress display to update once the batch is written.
            task: Progress task representing the chunk step.
            document_id_mapping: Old document id -> new document UUID.
        """
        cursor = sqlite_conn.cursor()

        # Get chunks first
        cursor.execute("""
            SELECT id, document_id, content, metadata
            FROM chunks
            ORDER BY id
        """)
        chunks_data = cursor.fetchall()

        # Get embeddings separately to avoid vec0 virtual table issues
        embeddings_map = self._load_embedding_blobs(cursor)

        chunk_records = []
        for row in chunks_data:
            # Map the old document_id to new UUID
            document_uuid = document_id_mapping.get(row["document_id"])
            if not document_uuid:
                self.console.print(
                    f"[yellow]Warning: Document ID {row['document_id']} not found in mapping for chunk {row['id']}[/yellow]"
                )
                continue

            chunk_records.append(
                lance_store.ChunkRecord(
                    id=str(uuid4()),  # new UUID for the chunk
                    document_id=document_uuid,
                    content=row["content"],
                    metadata=self._normalize_metadata(row["metadata"]),
                    vector=self._resolve_embedding(
                        row["id"], embeddings_map.get(row["id"]), lance_store
                    ),
                )
            )

        # Batch insert chunks to LanceDB
        if chunk_records:
            lance_store.chunks_table.add(chunk_records)

        progress.update(task, completed=len(chunk_records), total=len(chunk_records))

    def _migrate_settings(
        self,
        sqlite_conn: sqlite3.Connection,
        lance_store: Store,
        progress: Progress,
        task: TaskID,
    ) -> None:
        """Migrate settings from SQLite to LanceDB.

        Reads the row with ``id = 1`` from the SQLite ``settings`` table and
        writes its JSON payload onto the LanceDB row whose id is the string
        ``'settings'``. A missing table or row is not an error.
        """
        cursor = sqlite_conn.cursor()

        try:
            cursor.execute("SELECT id, settings FROM settings WHERE id = 1")
            row = cursor.fetchone()

            if row:
                settings_data = json.loads(row["settings"]) if row["settings"] else {}

                # Update the existing settings in LanceDB (use string ID)
                lance_store.settings_table.update(
                    where="id = 'settings'",
                    values={"settings": json.dumps(settings_data)},
                )

                progress.update(task, completed=1, total=1)
            else:
                progress.update(task, completed=0, total=0)

        except sqlite3.OperationalError:
            # Settings table doesn't exist in old SQLite database
            self.console.print(
                "[yellow]No settings table found in SQLite database[/yellow]"
            )
            progress.update(task, completed=0, total=0)
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
async def migrate_sqlite_to_lancedb(
    sqlite_path: Path, lancedb_path: Path | None = None
) -> bool:
    """Migrate an existing SQLite database to LanceDB.

    Args:
        sqlite_path: Path to the existing SQLite database.
        lancedb_path: Destination for the LanceDB database. When omitted, a
            sibling of ``sqlite_path`` named ``<stem>.lancedb`` is used.

    Returns:
        True if the migration succeeded, False otherwise.
    """
    destination = lancedb_path
    if destination is None:
        # Derive the target next to the source file: same stem, new suffix.
        destination = sqlite_path.parent / (sqlite_path.stem + ".lancedb")

    return SQLiteToLanceDBMigrator(sqlite_path, destination).migrate()
|
haiku/rag/qa/__init__.py
CHANGED
|
@@ -1,44 +1,15 @@
|
|
|
1
1
|
from haiku.rag.client import HaikuRAG
|
|
2
2
|
from haiku.rag.config import Config
|
|
3
|
-
from haiku.rag.qa.
|
|
4
|
-
from haiku.rag.qa.ollama import QuestionAnswerOllamaAgent
|
|
3
|
+
from haiku.rag.qa.agent import QuestionAnswerAgent
|
|
5
4
|
|
|
6
5
|
|
|
7
|
-
def get_qa_agent(
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
"""
|
|
11
|
-
Factory function to get the appropriate QA agent based on the configuration.
|
|
12
|
-
"""
|
|
13
|
-
if Config.QA_PROVIDER == "ollama":
|
|
14
|
-
return QuestionAnswerOllamaAgent(
|
|
15
|
-
client, model or Config.QA_MODEL, use_citations
|
|
16
|
-
)
|
|
6
|
+
def get_qa_agent(client: HaikuRAG, use_citations: bool = False) -> QuestionAnswerAgent:
|
|
7
|
+
provider = Config.QA_PROVIDER
|
|
8
|
+
model_name = Config.QA_MODEL
|
|
17
9
|
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
"Please install haiku.rag with the 'openai' extra:"
|
|
25
|
-
"uv pip install haiku.rag[openai]"
|
|
26
|
-
)
|
|
27
|
-
return QuestionAnswerOpenAIAgent(
|
|
28
|
-
client, model or Config.QA_MODEL, use_citations
|
|
29
|
-
)
|
|
30
|
-
|
|
31
|
-
if Config.QA_PROVIDER == "anthropic":
|
|
32
|
-
try:
|
|
33
|
-
from haiku.rag.qa.anthropic import QuestionAnswerAnthropicAgent
|
|
34
|
-
except ImportError:
|
|
35
|
-
raise ImportError(
|
|
36
|
-
"Anthropic QA agent requires the 'anthropic' package. "
|
|
37
|
-
"Please install haiku.rag with the 'anthropic' extra:"
|
|
38
|
-
"uv pip install haiku.rag[anthropic]"
|
|
39
|
-
)
|
|
40
|
-
return QuestionAnswerAnthropicAgent(
|
|
41
|
-
client, model or Config.QA_MODEL, use_citations
|
|
42
|
-
)
|
|
43
|
-
|
|
44
|
-
raise ValueError(f"Unsupported QA provider: {Config.QA_PROVIDER}")
|
|
10
|
+
return QuestionAnswerAgent(
|
|
11
|
+
client=client,
|
|
12
|
+
provider=provider,
|
|
13
|
+
model=model_name,
|
|
14
|
+
use_citations=use_citations,
|
|
15
|
+
)
|
haiku/rag/qa/agent.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
from pydantic import BaseModel, Field
|
|
2
|
+
from pydantic_ai import Agent, RunContext
|
|
3
|
+
from pydantic_ai.models.openai import OpenAIModel
|
|
4
|
+
from pydantic_ai.providers.ollama import OllamaProvider
|
|
5
|
+
|
|
6
|
+
from haiku.rag.client import HaikuRAG
|
|
7
|
+
from haiku.rag.config import Config
|
|
8
|
+
from haiku.rag.qa.prompts import SYSTEM_PROMPT, SYSTEM_PROMPT_WITH_CITATIONS
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class SearchResult(BaseModel):
    # One hit returned by the agent's search tool. Documented with comments
    # rather than a docstring so the generated model schema is unchanged.

    # Text of the matched chunk.
    content: str = Field(
        description="The document text content",
    )
    # Ranking score reported by the search backend.
    score: float = Field(
        description="Relevance score (higher is more relevant)",
    )
    # Where the chunk's parent document came from.
    document_uri: str = Field(
        description="Source URI/path of the document",
    )
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Dependencies(BaseModel):
    # Per-run dependency container handed to the agent's tools.
    # HaikuRAG is not a pydantic type, hence arbitrary types must be allowed.
    model_config = dict(arbitrary_types_allowed=True)

    # RAG client the tools query.
    client: HaikuRAG
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class QuestionAnswerAgent:
    """Question-answering agent that can search a HaikuRAG knowledge base."""

    def __init__(
        self,
        client: HaikuRAG,
        provider: str,
        model: str,
        use_citations: bool = False,
        q: float = 0.0,
    ):
        """Build the underlying agent and register its search tool.

        Args:
            client: RAG client used by the search tool.
            provider: Model provider name; ``"ollama"`` is special-cased.
            model: Model name for the chosen provider.
            use_citations: Select the citation-producing system prompt.
            q: Not used by this constructor.
        """
        self._client = client

        if use_citations:
            prompt_text = SYSTEM_PROMPT_WITH_CITATIONS
        else:
            prompt_text = SYSTEM_PROMPT

        self._agent = Agent(
            model=self._get_model(provider, model),
            deps_type=Dependencies,
            system_prompt=prompt_text,
        )

        @self._agent.tool
        async def search_documents(
            ctx: RunContext[Dependencies],
            query: str,
            limit: int = 3,
        ) -> list[SearchResult]:
            """Search the knowledge base for relevant documents."""
            rag = ctx.deps.client
            hits = await rag.search(query, limit=limit)
            expanded = await rag.expand_context(hits)

            results = []
            for chunk, score in expanded:
                results.append(
                    SearchResult(
                        content=chunk.content,
                        score=score,
                        # A missing URI is reported as an empty string.
                        document_uri=chunk.document_uri or "",
                    )
                )
            return results

    def _get_model(self, provider: str, model: str):
        """Return the model object (or ``provider:model`` string) to use."""
        if provider != "ollama":
            # For all other providers, use the provider:model format
            return f"{provider}:{model}"

        # Ollama is reached through its OpenAI-compatible ``/v1`` endpoint.
        ollama = OllamaProvider(base_url=f"{Config.OLLAMA_BASE_URL}/v1")
        return OpenAIModel(model_name=model, provider=ollama)

    async def answer(self, question: str) -> str:
        """Run the agent on *question* and return its output."""
        run = await self._agent.run(
            question, deps=Dependencies(client=self._client)
        )
        return run.output
|
haiku/rag/qa/prompts.py
CHANGED
|
@@ -18,6 +18,7 @@ Guidelines:
|
|
|
18
18
|
- Stick to the answer, do not ellaborate or provide context unless explicitly asked for it.
|
|
19
19
|
|
|
20
20
|
Be concise, and always maintain accuracy over completeness. Prefer short, direct answers that are well-supported by the documents.
|
|
21
|
+
/no_think
|
|
21
22
|
"""
|
|
22
23
|
|
|
23
24
|
SYSTEM_PROMPT_WITH_CITATIONS = """
|
|
@@ -55,4 +56,5 @@ Citations:
|
|
|
55
56
|
- /path/to/document2.pdf: "The manual provides guidance on military procedures and..."
|
|
56
57
|
|
|
57
58
|
Be concise, and always maintain accuracy over completeness. Prefer short, direct answers that are well-supported by the documents.
|
|
59
|
+
/no_think
|
|
58
60
|
"""
|
haiku/rag/reranking/__init__.py
CHANGED