haiku.rag 0.5.5__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of haiku.rag might be problematic. Click here for more details.

haiku/rag/migration.py ADDED
@@ -0,0 +1,316 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Migration script to migrate from SQLite to LanceDB.
4
+
5
+ This script will:
6
+ 1. Read data from an existing SQLite database
7
+ 2. Create a new LanceDB database with the same data
8
+ 3. Preserve all documents, chunks, embeddings, and settings
9
+ """
10
+
11
+ import json
12
+ import sqlite3
13
+ import struct
14
+ from pathlib import Path
15
+ from uuid import uuid4
16
+
17
+ from rich.console import Console
18
+ from rich.progress import Progress, TaskID
19
+
20
+ from haiku.rag.store.engine import Store
21
+
22
+
23
def deserialize_sqlite_embedding(data: bytes) -> list[float]:
    """Deserialize a packed float32 embedding blob into a list of floats.

    The bytes are read as consecutive 4-byte floats using the struct ``f``
    code (native byte order), which is how the migration reads the vectors
    stored by sqlite-vec.

    Args:
        data: Raw embedding bytes. Empty bytes or ``None`` yield ``[]``.

    Returns:
        One float per 4-byte group of ``data``.

    Raises:
        ValueError: If ``len(data)`` is not a multiple of 4, i.e. the blob
            cannot hold a whole number of float32 values.
    """
    if not data:
        return []
    num_floats, remainder = divmod(len(data), 4)
    if remainder:
        # Without this check a truncated/corrupt blob surfaces as an opaque
        # struct.error; report the real problem instead.
        raise ValueError(
            f"Embedding blob length {len(data)} is not a multiple of 4 bytes"
        )
    return list(struct.unpack(f"{num_floats}f", data))
30
+
31
+
32
class SQLiteToLanceDBMigrator:
    """Migrates data from a SQLite haiku.rag database to LanceDB.

    Documents and chunks receive freshly generated UUID string IDs; chunks are
    re-linked to their parent document through the old-integer-ID -> new-UUID
    mapping produced while migrating documents.
    """

    def __init__(self, sqlite_path: Path, lancedb_path: Path):
        """Remember the source/target locations and create the output console.

        Args:
            sqlite_path: Existing SQLite database to read from.
            lancedb_path: Location passed to ``Store`` for the LanceDB database.
        """
        self.sqlite_path = sqlite_path
        self.lancedb_path = lancedb_path
        self.console = Console()

    def migrate(self) -> bool:
        """Perform the migration.

        Runs the document, chunk/embedding and settings steps in that order,
        then tries to optimize the LanceDB chunks table (a failure there is
        reported as a warning only).

        Returns:
            True on success; False if the SQLite file does not exist or any
            migration step raised (the error and traceback are printed).
        """
        sqlite_conn = None
        lance_store = None
        try:
            self.console.print(
                f"[blue]Starting migration from {self._plain(self.sqlite_path)} "
                f"to {self._plain(self.lancedb_path)}[/blue]"
            )

            if not self.sqlite_path.exists():
                self.console.print(
                    f"[red]SQLite database not found: {self._plain(self.sqlite_path)}[/red]"
                )
                return False

            sqlite_conn = sqlite3.connect(self.sqlite_path)
            sqlite_conn.row_factory = sqlite3.Row

            lance_store = Store(self.lancedb_path, skip_validation=True)

            with Progress() as progress:
                doc_task = progress.add_task(
                    "[green]Migrating documents...", total=None
                )
                document_id_mapping = self._migrate_documents(
                    sqlite_conn, lance_store, progress, doc_task
                )

                chunk_task = progress.add_task(
                    "[yellow]Migrating chunks and embeddings...", total=None
                )
                self._migrate_chunks(
                    sqlite_conn, lance_store, progress, chunk_task, document_id_mapping
                )

                settings_task = progress.add_task(
                    "[blue]Migrating settings...", total=None
                )
                self._migrate_settings(
                    sqlite_conn, lance_store, progress, settings_task
                )

            sqlite_conn.close()
            sqlite_conn = None  # closed; nothing left for ``finally`` to release

            # Optimize the chunks table after migration (best effort).
            self.console.print("[blue]Optimizing LanceDB...[/blue]")
            try:
                lance_store.chunks_table.optimize()
                self.console.print("[green]✅ Optimization completed[/green]")
            except Exception as e:
                self.console.print(
                    f"[yellow]Warning: Optimization failed: {self._plain(e)}[/yellow]"
                )

            lance_store.close()
            lance_store = None

            self.console.print("[green]✅ Migration completed successfully![/green]")
            self.console.print(
                f"[green]✅ Migrated {len(document_id_mapping)} documents[/green]"
            )
            return True

        except Exception as e:
            import traceback

            # Escape: exception text and tracebacks may contain "[...]" that
            # rich would otherwise parse as markup (or reject, raising here).
            self.console.print(f"[red]❌ Migration failed: {self._plain(e)}[/red]")
            self.console.print(f"[red]{self._plain(traceback.format_exc())}[/red]")
            return False

        finally:
            # Handles are still set only when a step above raised; release
            # them so the SQLite connection / LanceDB store are not leaked.
            self._release(sqlite_conn, lance_store)

    def _release(self, *resources) -> None:
        """Best-effort ``close()`` of handles left open by a failed migration."""
        for resource in resources:
            if resource is None:
                continue
            try:
                resource.close()
            except Exception as e:
                # The failure has already been reported; a secondary cleanup
                # error must not turn the ``False`` result into a raise.
                self.console.print(
                    f"[yellow]Warning: cleanup failed: {self._plain(e)}[/yellow]"
                )

    @staticmethod
    def _plain(value: object) -> str:
        """Return ``str(value)`` escaped so rich prints it literally."""
        from rich.markup import escape

        return escape(str(value))

    @staticmethod
    def _normalize_metadata(raw: str | None) -> str:
        """Re-serialize a stored JSON metadata string; empty/NULL becomes ``"{}"``.

        Parsing first validates the stored JSON and normalizes its formatting.
        """
        return json.dumps(json.loads(raw) if raw else {})

    def _migrate_documents(
        self,
        sqlite_conn: sqlite3.Connection,
        lance_store: Store,
        progress: Progress,
        task: TaskID,
    ) -> dict[int, str]:
        """Copy every document into LanceDB and return the ID mapping.

        Returns:
            Mapping from each SQLite integer document ID to the UUID string
            that document received in LanceDB.
        """
        from haiku.rag.store.engine import DocumentRecord

        cursor = sqlite_conn.cursor()
        cursor.execute(
            "SELECT id, content, uri, metadata, created_at, updated_at FROM documents ORDER BY id"
        )

        id_mapping: dict[int, str] = {}  # old integer ID -> new UUID
        doc_records = []
        for row in cursor.fetchall():
            new_uuid = str(uuid4())
            id_mapping[row["id"]] = new_uuid
            doc_records.append(
                DocumentRecord(
                    id=new_uuid,
                    content=row["content"],
                    uri=row["uri"],
                    metadata=self._normalize_metadata(row["metadata"]),
                    created_at=row["created_at"],
                    updated_at=row["updated_at"],
                )
            )

        # Batch insert documents to LanceDB.
        if doc_records:
            lance_store.documents_table.add(doc_records)

        progress.update(task, completed=len(doc_records), total=len(doc_records))
        return id_mapping

    def _load_embeddings(
        self, sqlite_conn: sqlite3.Connection, vector_dim: int
    ) -> dict[int, bytes]:
        """Read raw float32 embedding bytes from sqlite-vec's vec0 shadow tables.

        The sqlite-vec extension is not loaded here, so the ``chunk_embeddings``
        virtual table cannot be queried directly. Two of its shadow tables are
        read instead:

        * ``chunk_embeddings_rowids``: one row per stored vector, giving its
          rowid plus the storage chunk (``chunk_id``) and slot
          (``chunk_offset``) that hold the vector;
        * ``chunk_embeddings_vector_chunks00``: one row per storage chunk whose
          ``vectors`` blob packs many vectors back to back.

        A vector is therefore the ``vector_dim * 4``-byte slice starting at
        ``chunk_offset * vector_dim * 4`` in its storage chunk's blob.

        NOTE(review): this follows sqlite-vec's vec0 shadow-table layout and
        assumes the vec0 rowid is the haiku chunk id (INTEGER primary key) —
        confirm against the sqlite-vec version that wrote the database.

        Args:
            sqlite_conn: Open connection to the source database.
            vector_dim: Number of float32 values per embedding.

        Returns:
            Mapping from SQLite chunk id to that chunk's raw embedding bytes.

        Raises:
            sqlite3.OperationalError: If the shadow tables do not exist.
        """
        stride = vector_dim * 4  # bytes per vector (float32)
        cursor = sqlite_conn.cursor()

        # storage chunk id -> [(haiku chunk id, slot within that storage chunk)]
        slots_by_storage_chunk: dict[int, list[tuple[int, int]]] = {}
        cursor.execute(
            "SELECT rowid, chunk_id, chunk_offset FROM chunk_embeddings_rowids"
        )
        for row in cursor.fetchall():
            chunk_id, storage_chunk, offset = row[0], row[1], row[2]
            if storage_chunk is None or offset is None:
                continue
            slots_by_storage_chunk.setdefault(storage_chunk, []).append(
                (chunk_id, offset)
            )

        embeddings: dict[int, bytes] = {}
        # Storage blobs can be large: fetch each once and slice all of its
        # vectors out, instead of re-fetching the blob once per vector.
        cursor.execute("SELECT rowid, vectors FROM chunk_embeddings_vector_chunks00")
        for row in cursor:
            blob = row[1]
            if not blob:
                continue
            for chunk_id, offset in slots_by_storage_chunk.get(row[0], ()):
                start = offset * stride
                vector_bytes = blob[start : start + stride]
                if len(vector_bytes) == stride and chunk_id not in embeddings:
                    embeddings[chunk_id] = vector_bytes
        return embeddings

    def _migrate_chunks(
        self,
        sqlite_conn: sqlite3.Connection,
        lance_store: Store,
        progress: Progress,
        task: TaskID,
        document_id_mapping: dict[int, str],
    ):
        """Copy chunks and their embeddings into LanceDB.

        Chunks whose document is missing from ``document_id_mapping`` are
        skipped with a warning. Chunks without a usable embedding get a zero
        vector of the store's embedding dimension.
        """
        cursor = sqlite_conn.cursor()
        cursor.execute("""
            SELECT id, document_id, content, metadata
            FROM chunks
            ORDER BY id
        """)
        chunk_rows = cursor.fetchall()

        vector_dim = lance_store.embedder._vector_dim

        try:
            embeddings_map = self._load_embeddings(sqlite_conn, vector_dim)
        except sqlite3.OperationalError as e:
            self.console.print(
                f"[yellow]Warning: Could not extract embeddings: {self._plain(e)}[/yellow]"
            )
            self.console.print(
                "[yellow]Continuing migration without embeddings...[/yellow]"
            )
            embeddings_map = {}

        chunk_records = []
        missing_embeddings = 0
        for row in chunk_rows:
            # Map the old document_id to the document's new UUID.
            document_uuid = document_id_mapping.get(row["document_id"])
            if not document_uuid:
                self.console.print(
                    f"[yellow]Warning: Document ID {row['document_id']} not found in mapping for chunk {row['id']}[/yellow]"
                )
                continue

            embedding_blob = embeddings_map.get(row["id"])
            if not embedding_blob:
                missing_embeddings += 1
                embedding = [0.0] * vector_dim
            else:
                try:
                    embedding = deserialize_sqlite_embedding(embedding_blob)
                except Exception as e:
                    self.console.print(
                        f"[yellow]Warning: Failed to deserialize embedding for chunk {row['id']}: {self._plain(e)}[/yellow]"
                    )
                    embedding = [0.0] * vector_dim

            chunk_records.append(
                lance_store.ChunkRecord(
                    id=str(uuid4()),
                    document_id=document_uuid,
                    content=row["content"],
                    metadata=self._normalize_metadata(row["metadata"]),
                    vector=embedding,
                )
            )

        if missing_embeddings:
            self.console.print(
                f"[yellow]Warning: {missing_embeddings} chunk(s) had no stored "
                "embedding and were given zero vectors[/yellow]"
            )

        # Batch insert chunks to LanceDB.
        if chunk_records:
            lance_store.chunks_table.add(chunk_records)

        progress.update(task, completed=len(chunk_records), total=len(chunk_records))

    def _migrate_settings(
        self,
        sqlite_conn: sqlite3.Connection,
        lance_store: Store,
        progress: Progress,
        task: TaskID,
    ):
        """Copy the settings row (``id = 1``) into LanceDB's ``'settings'`` row.

        A missing settings table is reported and otherwise ignored.
        """
        cursor = sqlite_conn.cursor()

        try:
            cursor.execute("SELECT id, settings FROM settings WHERE id = 1")
            row = cursor.fetchone()
        except sqlite3.OperationalError:
            # Settings table doesn't exist in old SQLite database.
            self.console.print(
                "[yellow]No settings table found in SQLite database[/yellow]"
            )
            progress.update(task, completed=0, total=0)
            return

        if not row:
            progress.update(task, completed=0, total=0)
            return

        settings_data = json.loads(row["settings"]) if row["settings"] else {}

        # Update the existing settings in LanceDB (string ID).
        lance_store.settings_table.update(
            where="id = 'settings'",
            values={"settings": json.dumps(settings_data)},
        )
        progress.update(task, completed=1, total=1)
296
+
297
+
298
async def migrate_sqlite_to_lancedb(
    sqlite_path: Path, lancedb_path: Path | None = None
) -> bool:
    """
    Convert an existing SQLite haiku.rag database into a LanceDB database.

    Args:
        sqlite_path: Location of the SQLite database to read.
        lancedb_path: Destination for the LanceDB database. When omitted it is
            placed beside ``sqlite_path``, reusing its stem with a ``.lancedb``
            suffix (e.g. ``docs.db`` -> ``docs.lancedb``).

    Returns:
        Whether the migration succeeded.

    Note:
        The migration runs synchronously, so awaiting this coroutine blocks
        the event loop until it finishes.
    """
    target_path = (
        lancedb_path
        if lancedb_path is not None
        else sqlite_path.parent / f"{sqlite_path.stem}.lancedb"
    )
    return SQLiteToLanceDBMigrator(sqlite_path, target_path).migrate()
haiku/rag/qa/__init__.py CHANGED
@@ -1,44 +1,15 @@
1
1
  from haiku.rag.client import HaikuRAG
2
2
  from haiku.rag.config import Config
3
- from haiku.rag.qa.base import QuestionAnswerAgentBase
4
- from haiku.rag.qa.ollama import QuestionAnswerOllamaAgent
3
+ from haiku.rag.qa.agent import QuestionAnswerAgent
5
4
 
6
5
 
7
def get_qa_agent(client: HaikuRAG, use_citations: bool = False) -> QuestionAnswerAgent:
    """Build a QuestionAnswerAgent configured from the global settings.

    Provider and model are taken from ``Config.QA_PROVIDER`` and
    ``Config.QA_MODEL``.

    Args:
        client: RAG client that the agent's search tool queries.
        use_citations: Select the citation-producing system prompt.

    Returns:
        A newly constructed QuestionAnswerAgent.
    """
    return QuestionAnswerAgent(
        client=client,
        provider=Config.QA_PROVIDER,
        model=Config.QA_MODEL,
        use_citations=use_citations,
    )
haiku/rag/qa/agent.py ADDED
@@ -0,0 +1,76 @@
1
+ from pydantic import BaseModel, Field
2
+ from pydantic_ai import Agent, RunContext
3
+ from pydantic_ai.models.openai import OpenAIModel
4
+ from pydantic_ai.providers.ollama import OllamaProvider
5
+
6
+ from haiku.rag.client import HaikuRAG
7
+ from haiku.rag.config import Config
8
+ from haiku.rag.qa.prompts import SYSTEM_PROMPT, SYSTEM_PROMPT_WITH_CITATIONS
9
+
10
+
11
# One search hit as returned to the model by the ``search_documents`` tool.
# Documented with comments rather than a class docstring: pydantic uses a
# model's docstring as its JSON-schema description, so adding one would
# change the schema.
class SearchResult(BaseModel):
    content: str = Field(description="The document text content")
    score: float = Field(description="Relevance score (higher is more relevant)")
    document_uri: str = Field(description="Source URI/path of the document")
15
+
16
+
17
class Dependencies(BaseModel):
    """Per-run dependencies passed to the agent's tools as ``ctx.deps``."""

    # HaikuRAG is not a pydantic type, so arbitrary types must be allowed.
    model_config = {"arbitrary_types_allowed": True}
    client: HaikuRAG  # client used by the search tool
20
+
21
+
22
class QuestionAnswerAgent:
    """Answer questions with a pydantic-ai agent backed by a HaikuRAG client.

    The agent has a single tool, ``search_documents``. It searches the client,
    expands the hits with surrounding context, and returns them to the model
    as ``SearchResult`` objects.
    """

    def __init__(
        self,
        client: HaikuRAG,
        provider: str,
        model: str,
        use_citations: bool = False,
        q: float = 0.0,
    ):
        """Create the underlying agent and register its search tool.

        Args:
            client: RAG client passed to every run as ``Dependencies.client``.
            provider: Model provider; ``"ollama"`` is served through
                ``Config.OLLAMA_BASE_URL``, anything else is passed to
                pydantic-ai as ``"provider:model"``.
            model: Model name for the provider.
            use_citations: Use the citation-producing system prompt.
            q: Currently unused; accepted for interface compatibility.
        """
        self._client = client

        system_prompt = SYSTEM_PROMPT_WITH_CITATIONS if use_citations else SYSTEM_PROMPT
        model_obj = self._get_model(provider, model)

        self._agent = Agent(
            model=model_obj,
            deps_type=Dependencies,
            system_prompt=system_prompt,
        )

        # The tool's name, parameters and docstring form the tool schema the
        # model sees, so they are intentionally left exactly as they are.
        @self._agent.tool
        async def search_documents(
            ctx: RunContext[Dependencies],
            query: str,
            limit: int = 3,
        ) -> list[SearchResult]:
            """Search the knowledge base for relevant documents."""
            search_results = await ctx.deps.client.search(query, limit=limit)
            expanded_results = await ctx.deps.client.expand_context(search_results)

            return [
                SearchResult(
                    content=chunk.content,
                    score=score,
                    document_uri=chunk.document_uri or "",
                )
                for chunk, score in expanded_results
            ]

    def _get_model(self, provider: str, model: str):
        """Get the appropriate model object (or model string) for the provider."""
        if provider == "ollama":
            # Strip any trailing slash so a configured "http://host:11434/"
            # does not produce ".../ /v1" with a doubled slash.
            base_url = str(Config.OLLAMA_BASE_URL).rstrip("/")
            return OpenAIModel(
                model_name=model,
                provider=OllamaProvider(base_url=f"{base_url}/v1"),
            )
        # For all other providers, use pydantic-ai's "provider:model" format.
        return f"{provider}:{model}"

    async def answer(self, question: str) -> str:
        """Answer a question using the RAG system and return the agent's output."""
        deps = Dependencies(client=self._client)
        result = await self._agent.run(question, deps=deps)
        return result.output
haiku/rag/qa/prompts.py CHANGED
@@ -18,6 +18,7 @@ Guidelines:
18
18
  - Stick to the answer, do not ellaborate or provide context unless explicitly asked for it.
19
19
 
20
20
  Be concise, and always maintain accuracy over completeness. Prefer short, direct answers that are well-supported by the documents.
21
+ /no_think
21
22
  """
22
23
 
23
24
  SYSTEM_PROMPT_WITH_CITATIONS = """
@@ -55,4 +56,5 @@ Citations:
55
56
  - /path/to/document2.pdf: "The manual provides guidance on military procedures and..."
56
57
 
57
58
  Be concise, and always maintain accuracy over completeness. Prefer short, direct answers that are well-supported by the documents.
59
+ /no_think
58
60
  """
@@ -31,10 +31,4 @@ def get_reranker() -> RerankerBase | None:
31
31
  except ImportError:
32
32
  return None
33
33
 
34
- if Config.RERANK_PROVIDER == "ollama":
35
- from haiku.rag.reranking.ollama import OllamaReranker
36
-
37
- _reranker = OllamaReranker()
38
- return _reranker
39
-
40
34
  return None