haiku.rag 0.1.0 (haiku_rag-0.1.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of haiku.rag might be problematic.
- haiku/rag/__init__.py +0 -0
- haiku/rag/app.py +107 -0
- haiku/rag/chunker.py +76 -0
- haiku/rag/cli.py +153 -0
- haiku/rag/client.py +261 -0
- haiku/rag/config.py +28 -0
- haiku/rag/embeddings/__init__.py +24 -0
- haiku/rag/embeddings/base.py +12 -0
- haiku/rag/embeddings/ollama.py +14 -0
- haiku/rag/embeddings/voyageai.py +17 -0
- haiku/rag/mcp.py +141 -0
- haiku/rag/reader.py +52 -0
- haiku/rag/store/__init__.py +4 -0
- haiku/rag/store/engine.py +80 -0
- haiku/rag/store/models/__init__.py +4 -0
- haiku/rag/store/models/chunk.py +12 -0
- haiku/rag/store/models/document.py +16 -0
- haiku/rag/store/repositories/__init__.py +5 -0
- haiku/rag/store/repositories/base.py +40 -0
- haiku/rag/store/repositories/chunk.py +424 -0
- haiku/rag/store/repositories/document.py +210 -0
- haiku/rag/utils.py +25 -0
- haiku_rag-0.1.0.dist-info/METADATA +195 -0
- haiku_rag-0.1.0.dist-info/RECORD +27 -0
- haiku_rag-0.1.0.dist-info/WHEEL +4 -0
- haiku_rag-0.1.0.dist-info/entry_points.txt +2 -0
- haiku_rag-0.1.0.dist-info/licenses/LICENSE +7 -0
haiku/rag/store/repositories/chunk.py
ADDED

@@ -0,0 +1,424 @@

import json
import re

from haiku.rag.chunker import chunker
from haiku.rag.embeddings import get_embedder
from haiku.rag.store.models.chunk import Chunk
from haiku.rag.store.repositories.base import BaseRepository


class ChunkRepository(BaseRepository[Chunk]):
    """Repository for Chunk database operations."""

    def __init__(self, store):
        super().__init__(store)
        self.embedder = get_embedder()

    async def create(self, entity: Chunk, commit: bool = True) -> Chunk:
        """Create a chunk in the database."""
        if self.store._connection is None:
            raise ValueError("Store connection is not available")

        cursor = self.store._connection.cursor()
        cursor.execute(
            """
            INSERT INTO chunks (document_id, content, metadata)
            VALUES (:document_id, :content, :metadata)
            """,
            {
                "document_id": entity.document_id,
                "content": entity.content,
                "metadata": json.dumps(entity.metadata),
            },
        )

        entity.id = cursor.lastrowid

        # Generate and store embedding
        embedding = await self.embedder.embed(entity.content)
        serialized_embedding = self.store.serialize_embedding(embedding)
        cursor.execute(
            """
            INSERT INTO chunk_embeddings (chunk_id, embedding)
            VALUES (:chunk_id, :embedding)
            """,
            {"chunk_id": entity.id, "embedding": serialized_embedding},
        )

        # Insert into FTS5 table for full-text search
        cursor.execute(
            """
            INSERT INTO chunks_fts(rowid, content)
            VALUES (:rowid, :content)
            """,
            {"rowid": entity.id, "content": entity.content},
        )

        if commit:
            self.store._connection.commit()
        return entity

    async def get_by_id(self, entity_id: int) -> Chunk | None:
        """Get a chunk by its ID."""
        if self.store._connection is None:
            raise ValueError("Store connection is not available")

        cursor = self.store._connection.cursor()
        cursor.execute(
            """
            SELECT id, document_id, content, metadata
            FROM chunks WHERE id = :id
            """,
            {"id": entity_id},
        )

        row = cursor.fetchone()
        if row is None:
            return None

        chunk_id, document_id, content, metadata_json = row
        metadata = json.loads(metadata_json) if metadata_json else {}

        return Chunk(
            id=chunk_id, document_id=document_id, content=content, metadata=metadata
        )

    async def update(self, entity: Chunk) -> Chunk:
        """Update an existing chunk."""
        if self.store._connection is None:
            raise ValueError("Store connection is not available")
        if entity.id is None:
            raise ValueError("Chunk ID is required for update")

        cursor = self.store._connection.cursor()
        cursor.execute(
            """
            UPDATE chunks
            SET document_id = :document_id, content = :content, metadata = :metadata
            WHERE id = :id
            """,
            {
                "document_id": entity.document_id,
                "content": entity.content,
                "metadata": json.dumps(entity.metadata),
                "id": entity.id,
            },
        )

        # Regenerate and update embedding
        embedding = await self.embedder.embed(entity.content)
        serialized_embedding = self.store.serialize_embedding(embedding)
        cursor.execute(
            """
            UPDATE chunk_embeddings
            SET embedding = :embedding
            WHERE chunk_id = :chunk_id
            """,
            {"embedding": serialized_embedding, "chunk_id": entity.id},
        )

        # Update FTS5 table
        cursor.execute(
            """
            UPDATE chunks_fts
            SET content = :content
            WHERE rowid = :rowid
            """,
            {"content": entity.content, "rowid": entity.id},
        )

        self.store._connection.commit()
        return entity

    async def delete(self, entity_id: int, commit: bool = True) -> bool:
        """Delete a chunk by its ID."""
        if self.store._connection is None:
            raise ValueError("Store connection is not available")

        cursor = self.store._connection.cursor()

        # Delete from FTS5 table first
        cursor.execute(
            "DELETE FROM chunks_fts WHERE rowid = :rowid", {"rowid": entity_id}
        )

        # Delete the embedding
        cursor.execute(
            "DELETE FROM chunk_embeddings WHERE chunk_id = :chunk_id",
            {"chunk_id": entity_id},
        )

        # Delete the chunk
        cursor.execute("DELETE FROM chunks WHERE id = :id", {"id": entity_id})

        deleted = cursor.rowcount > 0
        if commit:
            self.store._connection.commit()
        return deleted

    async def list_all(
        self, limit: int | None = None, offset: int | None = None
    ) -> list[Chunk]:
        """List all chunks with optional pagination."""
        if self.store._connection is None:
            raise ValueError("Store connection is not available")

        cursor = self.store._connection.cursor()
        query = "SELECT id, document_id, content, metadata FROM chunks ORDER BY document_id, id"
        params = {}

        if limit is not None:
            query += " LIMIT :limit"
            params["limit"] = limit

        if offset is not None:
            query += " OFFSET :offset"
            params["offset"] = offset

        cursor.execute(query, params)
        rows = cursor.fetchall()

        return [
            Chunk(
                id=chunk_id,
                document_id=document_id,
                content=content,
                metadata=json.loads(metadata_json) if metadata_json else {},
            )
            for chunk_id, document_id, content, metadata_json in rows
        ]

    async def create_chunks_for_document(
        self, document_id: int, content: str, commit: bool = True
    ) -> list[Chunk]:
        """Create chunks and embeddings for a document."""
        # Chunk the document content
        chunk_texts = await chunker.chunk(content)
        created_chunks = []

        # Create chunks with embeddings using the create method
        for order, chunk_text in enumerate(chunk_texts):
            # Create chunk with order in metadata
            chunk = Chunk(
                document_id=document_id, content=chunk_text, metadata={"order": order}
            )

            created_chunk = await self.create(chunk, commit=commit)
            created_chunks.append(created_chunk)

        return created_chunks

    async def delete_by_document_id(
        self, document_id: int, commit: bool = True
    ) -> bool:
        """Delete all chunks for a document."""
        chunks = await self.get_by_document_id(document_id)

        deleted_any = False
        for chunk in chunks:
            if chunk.id is not None:
                deleted = await self.delete(chunk.id, commit=False)
                deleted_any = deleted_any or deleted

        if commit and deleted_any and self.store._connection:
            self.store._connection.commit()
        return deleted_any

    async def search_chunks(
        self, query: str, limit: int = 5
    ) -> list[tuple[Chunk, float]]:
        """Search for relevant chunks using vector similarity."""
        if self.store._connection is None:
            raise ValueError("Store connection is not available")

        cursor = self.store._connection.cursor()

        # Generate embedding for the query
        query_embedding = await self.embedder.embed(query)
        serialized_query_embedding = self.store.serialize_embedding(query_embedding)

        # Search for similar chunks using sqlite-vec
        cursor.execute(
            """
            SELECT c.id, c.document_id, c.content, c.metadata, distance
            FROM chunk_embeddings
            JOIN chunks c ON c.id = chunk_embeddings.chunk_id
            WHERE embedding MATCH :embedding AND k = :k
            ORDER BY distance
            """,
            {"embedding": serialized_query_embedding, "k": limit},
        )

        results = cursor.fetchall()
        return [
            (
                Chunk(
                    id=chunk_id,
                    document_id=document_id,
                    content=content,
                    metadata=json.loads(metadata_json) if metadata_json else {},
                ),
                1.0 / (1.0 + distance),
            )
            for chunk_id, document_id, content, metadata_json, distance in results
        ]

    async def search_chunks_fts(
        self, query: str, limit: int = 5
    ) -> list[tuple[Chunk, float]]:
        """Search for chunks using FTS5 full-text search."""
        if self.store._connection is None:
            raise ValueError("Store connection is not available")

        cursor = self.store._connection.cursor()

        # Clean the query for FTS5 - extract keywords for better matching
        # Remove special characters and split into words
        words = re.findall(r"\b\w+\b", query.lower())
        # Join with OR to find chunks containing any of the keywords
        fts_query = " OR ".join(words) if words else query

        # Search using FTS5
        cursor.execute(
            """
            SELECT c.id, c.document_id, c.content, c.metadata, rank
            FROM chunks_fts
            JOIN chunks c ON c.id = chunks_fts.rowid
            WHERE chunks_fts MATCH :query
            ORDER BY rank
            LIMIT :limit
            """,
            {"query": fts_query, "limit": limit},
        )

        results = cursor.fetchall()

        return [
            (
                Chunk(
                    id=chunk_id,
                    document_id=document_id,
                    content=content,
                    metadata=json.loads(metadata_json) if metadata_json else {},
                ),
                -rank,
            )
            for chunk_id, document_id, content, metadata_json, rank in results
            # FTS5 rank is negative BM25 score
        ]

    async def search_chunks_hybrid(
        self, query: str, limit: int = 5, k: int = 60
    ) -> list[tuple[Chunk, float]]:
        """Hybrid search using Reciprocal Rank Fusion (RRF) combining vector similarity and FTS5 full-text search."""
        if self.store._connection is None:
            raise ValueError("Store connection is not available")

        cursor = self.store._connection.cursor()

        # Generate embedding for the query
        query_embedding = await self.embedder.embed(query)
        serialized_query_embedding = self.store.serialize_embedding(query_embedding)

        # Clean the query for FTS5 - extract keywords for better matching
        # Remove special characters and split into words
        words = re.findall(r"\b\w+\b", query.lower())
        # Join with OR to find chunks containing any of the keywords
        fts_query = " OR ".join(words) if words else query

        # Perform hybrid search using RRF (Reciprocal Rank Fusion)
        cursor.execute(
            """
            WITH vector_search AS (
                SELECT
                    c.id,
                    c.document_id,
                    c.content,
                    c.metadata,
                    ROW_NUMBER() OVER (ORDER BY ce.distance) as vector_rank
                FROM chunk_embeddings ce
                JOIN chunks c ON c.id = ce.chunk_id
                WHERE ce.embedding MATCH :embedding AND k = :k_vector
                ORDER BY ce.distance
            ),
            fts_search AS (
                SELECT
                    c.id,
                    c.document_id,
                    c.content,
                    c.metadata,
                    ROW_NUMBER() OVER (ORDER BY chunks_fts.rank) as fts_rank
                FROM chunks_fts
                JOIN chunks c ON c.id = chunks_fts.rowid
                WHERE chunks_fts MATCH :fts_query
                ORDER BY chunks_fts.rank
            ),
            all_chunks AS (
                SELECT id, document_id, content, metadata FROM vector_search
                UNION
                SELECT id, document_id, content, metadata FROM fts_search
            ),
            rrf_scores AS (
                SELECT
                    a.id,
                    a.document_id,
                    a.content,
                    a.metadata,
                    COALESCE(1.0 / (:k + v.vector_rank), 0) + COALESCE(1.0 / (:k + f.fts_rank), 0) as rrf_score
                FROM all_chunks a
                LEFT JOIN vector_search v ON a.id = v.id
                LEFT JOIN fts_search f ON a.id = f.id
            )
            SELECT id, document_id, content, metadata, rrf_score
            FROM rrf_scores
            ORDER BY rrf_score DESC
            LIMIT :limit
            """,
            {
                "embedding": serialized_query_embedding,
                "k_vector": limit * 3,
                "fts_query": fts_query,
                "k": k,
                "limit": limit,
            },
        )

        results = cursor.fetchall()
        return [
            (
                Chunk(
                    id=chunk_id,
                    document_id=document_id,
                    content=content,
                    metadata=json.loads(metadata_json) if metadata_json else {},
                ),
                rrf_score,
            )
            for chunk_id, document_id, content, metadata_json, rrf_score in results
        ]

    async def get_by_document_id(self, document_id: int) -> list[Chunk]:
        """Get all chunks for a specific document."""
        if self.store._connection is None:
            raise ValueError("Store connection is not available")

        cursor = self.store._connection.cursor()
        cursor.execute(
            """
            SELECT id, document_id, content, metadata
            FROM chunks WHERE document_id = :document_id
            ORDER BY JSON_EXTRACT(metadata, '$.order')
            """,
            {"document_id": document_id},
        )

        rows = cursor.fetchall()
        return [
            Chunk(
                id=chunk_id,
                document_id=document_id,
                content=content,
                metadata=json.loads(metadata_json) if metadata_json else {},
            )
            for chunk_id, document_id, content, metadata_json in rows
        ]
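The three search methods above use different scoring conventions: search_chunks turns a sqlite-vec distance into a similarity with 1.0 / (1.0 + distance), search_chunks_fts negates the FTS5 rank (a negative BM25 score), and search_chunks_hybrid fuses both result lists with Reciprocal Rank Fusion, summing COALESCE(1.0 / (k + rank), 0) over the two rankers. As an illustration only, the short sketch below recomputes that RRF formula in plain Python; the function name and the sample rankings are hypothetical and not part of the package.

def rrf_fuse(vector_ids, fts_ids, k=60):
    """Illustrative RRF: score(id) = sum over rankers of 1 / (k + rank)."""
    scores = {}
    for ranking in (vector_ids, fts_ids):
        for rank, chunk_id in enumerate(ranking, start=1):
            scores[chunk_id] = scores.get(chunk_id, 0.0) + 1.0 / (k + rank)
    # Highest fused score first, mirroring ORDER BY rrf_score DESC in the SQL above
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)


# Chunk 7 ranks first in the vector list and second in the FTS list, so it leads overall.
print(rrf_fuse([7, 3, 5], [2, 7, 3]))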
haiku/rag/store/repositories/document.py
ADDED

@@ -0,0 +1,210 @@

import json

from haiku.rag.store.models.document import Document
from haiku.rag.store.repositories.base import BaseRepository


class DocumentRepository(BaseRepository[Document]):
    """Repository for Document database operations."""

    def __init__(self, store, chunk_repository=None):
        super().__init__(store)
        # Avoid circular import by using late import if not provided
        if chunk_repository is None:
            from haiku.rag.store.repositories.chunk import ChunkRepository

            chunk_repository = ChunkRepository(store)
        self.chunk_repository = chunk_repository

    async def create(self, entity: Document) -> Document:
        """Create a document with its chunks and embeddings."""
        if self.store._connection is None:
            raise ValueError("Store connection is not available")

        cursor = self.store._connection.cursor()

        # Start transaction
        cursor.execute("BEGIN TRANSACTION")

        try:
            # Insert the document
            cursor.execute(
                """
                INSERT INTO documents (content, uri, metadata, created_at, updated_at)
                VALUES (:content, :uri, :metadata, :created_at, :updated_at)
                """,
                {
                    "content": entity.content,
                    "uri": entity.uri,
                    "metadata": json.dumps(entity.metadata),
                    "created_at": entity.created_at,
                    "updated_at": entity.updated_at,
                },
            )

            document_id = cursor.lastrowid
            assert document_id is not None, "Failed to create document in database"
            entity.id = document_id

            # Create chunks and embeddings using ChunkRepository
            await self.chunk_repository.create_chunks_for_document(
                document_id, entity.content, commit=False
            )

            cursor.execute("COMMIT")
            return entity

        except Exception:
            cursor.execute("ROLLBACK")
            raise

    async def get_by_id(self, entity_id: int) -> Document | None:
        """Get a document by its ID."""
        if self.store._connection is None:
            raise ValueError("Store connection is not available")

        cursor = self.store._connection.cursor()
        cursor.execute(
            """
            SELECT id, content, uri, metadata, created_at, updated_at
            FROM documents WHERE id = :id
            """,
            {"id": entity_id},
        )

        row = cursor.fetchone()
        if row is None:
            return None

        document_id, content, uri, metadata_json, created_at, updated_at = row
        metadata = json.loads(metadata_json) if metadata_json else {}

        return Document(
            id=document_id,
            content=content,
            uri=uri,
            metadata=metadata,
            created_at=created_at,
            updated_at=updated_at,
        )

    async def get_by_uri(self, uri: str) -> Document | None:
        """Get a document by its URI."""
        if self.store._connection is None:
            raise ValueError("Store connection is not available")

        cursor = self.store._connection.cursor()
        cursor.execute(
            """
            SELECT id, content, uri, metadata, created_at, updated_at
            FROM documents WHERE uri = :uri
            """,
            {"uri": uri},
        )

        row = cursor.fetchone()
        if row is None:
            return None

        document_id, content, uri, metadata_json, created_at, updated_at = row
        metadata = json.loads(metadata_json) if metadata_json else {}

        return Document(
            id=document_id,
            content=content,
            uri=uri,
            metadata=metadata,
            created_at=created_at,
            updated_at=updated_at,
        )

    async def update(self, entity: Document) -> Document:
        """Update an existing document and regenerate its chunks and embeddings."""
        if self.store._connection is None:
            raise ValueError("Store connection is not available")
        if entity.id is None:
            raise ValueError("Document ID is required for update")

        cursor = self.store._connection.cursor()

        # Start transaction
        cursor.execute("BEGIN TRANSACTION")

        try:
            # Update the document
            cursor.execute(
                """
                UPDATE documents
                SET content = :content, uri = :uri, metadata = :metadata, updated_at = :updated_at
                WHERE id = :id
                """,
                {
                    "content": entity.content,
                    "uri": entity.uri,
                    "metadata": json.dumps(entity.metadata),
                    "updated_at": entity.updated_at,
                    "id": entity.id,
                },
            )

            # Delete existing chunks and regenerate using ChunkRepository
            await self.chunk_repository.delete_by_document_id(entity.id, commit=False)
            await self.chunk_repository.create_chunks_for_document(
                entity.id, entity.content, commit=False
            )

            cursor.execute("COMMIT")
            return entity

        except Exception:
            cursor.execute("ROLLBACK")
            raise

    async def delete(self, entity_id: int) -> bool:
        """Delete a document and all its associated chunks and embeddings."""
        # Delete chunks and embeddings first
        await self.chunk_repository.delete_by_document_id(entity_id)

        if self.store._connection is None:
            raise ValueError("Store connection is not available")

        cursor = self.store._connection.cursor()
        cursor.execute("DELETE FROM documents WHERE id = :id", {"id": entity_id})

        deleted = cursor.rowcount > 0
        self.store._connection.commit()
        return deleted

    async def list_all(
        self, limit: int | None = None, offset: int | None = None
    ) -> list[Document]:
        """List all documents with optional pagination."""
        if self.store._connection is None:
            raise ValueError("Store connection is not available")

        cursor = self.store._connection.cursor()
        query = "SELECT id, content, uri, metadata, created_at, updated_at FROM documents ORDER BY created_at DESC"
        params = {}

        if limit is not None:
            query += " LIMIT :limit"
            params["limit"] = limit

        if offset is not None:
            query += " OFFSET :offset"
            params["offset"] = offset

        cursor.execute(query, params)
        rows = cursor.fetchall()

        return [
            Document(
                id=document_id,
                content=content,
                uri=uri,
                metadata=json.loads(metadata_json) if metadata_json else {},
                created_at=created_at,
                updated_at=updated_at,
            )
            for document_id, content, uri, metadata_json, created_at, updated_at in rows
        ]
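DocumentRepository.create and DocumentRepository.update wrap the document row and all of its chunks in one explicit transaction: BEGIN TRANSACTION, the chunk repository is called with commit=False, and a single COMMIT (or ROLLBACK on failure) decides whether anything is persisted. Below is a minimal standalone sketch of that same pattern, using sqlite3 directly with hypothetical tables rather than the package's schema or Store class.

import sqlite3

# isolation_level=None gives explicit transaction control; this is an assumption of
# the sketch, not documented behaviour of haiku.rag's Store.
conn = sqlite3.connect(":memory:", isolation_level=None)
conn.executescript(
    "CREATE TABLE documents (id INTEGER PRIMARY KEY, content TEXT);"
    "CREATE TABLE chunks (id INTEGER PRIMARY KEY, document_id INTEGER, content TEXT);"
)
cursor = conn.cursor()
cursor.execute("BEGIN TRANSACTION")
try:
    cursor.execute("INSERT INTO documents (content) VALUES (?)", ("full text",))
    document_id = cursor.lastrowid
    # Stands in for create_chunks_for_document(..., commit=False)
    for piece in ("chunk one", "chunk two"):
        cursor.execute(
            "INSERT INTO chunks (document_id, content) VALUES (?, ?)",
            (document_id, piece),
        )
    cursor.execute("COMMIT")  # document and chunks become visible together
except Exception:
    cursor.execute("ROLLBACK")  # nothing is persisted if any step fails
    raise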
haiku/rag/utils.py
ADDED
@@ -0,0 +1,25 @@

import sys
from pathlib import Path


def get_default_data_dir() -> Path:
    """
    Get the user data directory for the current system platform.

    Linux: ~/.local/share/haiku.rag
    macOS: ~/Library/Application Support/haiku.rag
    Windows: C:/Users/<USER>/AppData/Roaming/haiku.rag

    :return: User Data Path
    :rtype: Path
    """
    home = Path.home()

    system_paths = {
        "win32": home / "AppData/Roaming/haiku.rag",
        "linux": home / ".local/share/haiku.rag",
        "darwin": home / "Library/Application Support/haiku.rag",
    }

    data_path = system_paths[sys.platform]
    return data_path
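get_default_data_dir resolves sys.platform through a fixed dictionary, so it raises KeyError on any platform other than win32, linux, or darwin. A small caller-side example follows; the fallback directory is an assumption made for illustration, not behaviour of the package.

from pathlib import Path

from haiku.rag.utils import get_default_data_dir

try:
    data_dir = get_default_data_dir()
except KeyError:
    # Hypothetical fallback for unsupported platforms (e.g. "cygwin")
    data_dir = Path.home() / ".haiku.rag"

data_dir.mkdir(parents=True, exist_ok=True)
print(data_dir)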