haiku.rag 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of haiku.rag might be problematic. Click here for more details.
- haiku/rag/__init__.py +0 -0
- haiku/rag/app.py +107 -0
- haiku/rag/chunker.py +76 -0
- haiku/rag/cli.py +153 -0
- haiku/rag/client.py +261 -0
- haiku/rag/config.py +28 -0
- haiku/rag/embeddings/__init__.py +24 -0
- haiku/rag/embeddings/base.py +12 -0
- haiku/rag/embeddings/ollama.py +14 -0
- haiku/rag/embeddings/voyageai.py +17 -0
- haiku/rag/mcp.py +141 -0
- haiku/rag/reader.py +52 -0
- haiku/rag/store/__init__.py +4 -0
- haiku/rag/store/engine.py +80 -0
- haiku/rag/store/models/__init__.py +4 -0
- haiku/rag/store/models/chunk.py +12 -0
- haiku/rag/store/models/document.py +16 -0
- haiku/rag/store/repositories/__init__.py +5 -0
- haiku/rag/store/repositories/base.py +40 -0
- haiku/rag/store/repositories/chunk.py +424 -0
- haiku/rag/store/repositories/document.py +210 -0
- haiku/rag/utils.py +25 -0
- haiku_rag-0.1.0.dist-info/METADATA +195 -0
- haiku_rag-0.1.0.dist-info/RECORD +27 -0
- haiku_rag-0.1.0.dist-info/WHEEL +4 -0
- haiku_rag-0.1.0.dist-info/entry_points.txt +2 -0
- haiku_rag-0.1.0.dist-info/licenses/LICENSE +7 -0
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
try:
    from voyageai.client import Client  # type: ignore

    from haiku.rag.config import Config
    from haiku.rag.embeddings.base import EmbedderBase

    class Embedder(EmbedderBase):
        """Embedder that delegates to the VoyageAI client.

        The class is only defined when the imports above succeed; on
        ``ImportError`` this module loads without an ``Embedder`` name.
        """

        # Model name taken from the package configuration at class creation.
        _model: str = Config.EMBEDDING_MODEL
        # Vector length declared for this embedder.
        _vector_dim: int = 1024

        async def embed(self, text: str) -> list[float]:
            """Return the embedding computed for ``text`` by ``self._model``."""
            # A fresh client is built for every call, and the request itself
            # is a plain (non-awaited) call.
            response = Client().embed(
                [text],
                model=self._model,
                output_dtype="float",
            )
            # Only one text was submitted, so only the first vector is used.
            first_vector = response.embeddings[0]
            return first_vector  # type: ignore[return-value]

except ImportError:
    # One of the imports above is unavailable: define nothing.
    pass
|
haiku/rag/mcp.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from typing import Any, Literal
|
|
3
|
+
|
|
4
|
+
from fastmcp import FastMCP
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
|
|
7
|
+
from haiku.rag.client import HaikuRAG
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# One hit returned by the `search_documents` tool.
class SearchResult(BaseModel):
    # Identifier of the document the matching chunk belongs to.
    document_id: int
    # Text of the matching chunk.
    content: str
    # Score reported by the search for this chunk.
    score: float
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Serializable view of a document, returned by the `get_document` and
# `list_documents` tools.
class DocumentResult(BaseModel):
    # Document identifier; has no default, so it must be passed explicitly
    # (None is accepted).
    id: int | None
    # Document text.
    content: str
    # Optional URI associated with the document.
    uri: str | None = None
    # Free-form metadata attached to the document.
    metadata: dict[str, Any] = {}
    # Timestamps as strings; the tools build them with str() on the
    # document's created_at / updated_at values.
    created_at: str
    updated_at: str
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def create_mcp_server(db_path: Path | Literal[":memory:"]) -> FastMCP:
    """Create an MCP server with the specified database path.

    Each registered tool opens its own ``HaikuRAG(db_path)`` context for the
    duration of a single call. Tools never propagate exceptions: a failure is
    logged with its traceback and a neutral value (``None``, ``[]`` or
    ``False``) is returned instead, so callers see the same results as
    before while the cause remains diagnosable.

    Args:
        db_path: Database location handed to ``HaikuRAG``; either a ``Path``
            or the literal ``":memory:"``.

    Returns:
        The ``FastMCP`` instance named ``"haiku-rag"`` with all tools
        registered.
    """
    # Imported locally so this function is self-contained with respect to
    # the module's import block.
    import logging

    logger = logging.getLogger(__name__)
    mcp = FastMCP("haiku-rag")

    def _to_document_result(document: Any) -> DocumentResult:
        """Convert a document object into its serializable tool result."""
        return DocumentResult(
            id=document.id,
            content=document.content,
            uri=document.uri,
            metadata=document.metadata,
            created_at=str(document.created_at),
            updated_at=str(document.updated_at),
        )

    @mcp.tool()
    async def add_document_from_file(
        file_path: str, metadata: dict[str, Any] | None = None
    ) -> int | None:
        """Add a document to the RAG system from a file path."""
        try:
            async with HaikuRAG(db_path) as rag:
                document = await rag.create_document_from_source(
                    Path(file_path), metadata or {}
                )
                return document.id
        except Exception:
            logger.exception("add_document_from_file failed for %r", file_path)
            return None

    @mcp.tool()
    async def add_document_from_url(
        url: str, metadata: dict[str, Any] | None = None
    ) -> int | None:
        """Add a document to the RAG system from a URL."""
        try:
            async with HaikuRAG(db_path) as rag:
                document = await rag.create_document_from_source(url, metadata or {})
                return document.id
        except Exception:
            logger.exception("add_document_from_url failed for %r", url)
            return None

    @mcp.tool()
    async def add_document_from_text(
        content: str, uri: str | None = None, metadata: dict[str, Any] | None = None
    ) -> int | None:
        """Add a document to the RAG system from text content."""
        try:
            async with HaikuRAG(db_path) as rag:
                document = await rag.create_document(content, uri, metadata or {})
                return document.id
        except Exception:
            logger.exception("add_document_from_text failed (uri=%r)", uri)
            return None

    @mcp.tool()
    async def search_documents(query: str, limit: int = 5) -> list[SearchResult]:
        """Search the RAG system for documents using hybrid search (vector similarity + full-text search)."""
        try:
            async with HaikuRAG(db_path) as rag:
                results = await rag.search(query, limit)
                # Each result is a (chunk, score) pair.
                return [
                    SearchResult(
                        document_id=chunk.document_id,
                        content=chunk.content,
                        score=score,
                    )
                    for chunk, score in results
                ]
        except Exception:
            logger.exception("search_documents failed for query %r", query)
            return []

    @mcp.tool()
    async def get_document(document_id: int) -> DocumentResult | None:
        """Get a document by its ID."""
        try:
            async with HaikuRAG(db_path) as rag:
                document = await rag.get_document_by_id(document_id)
                if document is None:
                    # Not an error: the lookup simply found nothing.
                    return None
                return _to_document_result(document)
        except Exception:
            logger.exception("get_document failed for id %r", document_id)
            return None

    @mcp.tool()
    async def list_documents(
        limit: int | None = None, offset: int | None = None
    ) -> list[DocumentResult]:
        """List all documents with optional pagination."""
        try:
            async with HaikuRAG(db_path) as rag:
                documents = await rag.list_documents(limit, offset)
                return [_to_document_result(doc) for doc in documents]
        except Exception:
            logger.exception(
                "list_documents failed (limit=%r, offset=%r)", limit, offset
            )
            return []

    @mcp.tool()
    async def delete_document(document_id: int) -> bool:
        """Delete a document by its ID."""
        try:
            async with HaikuRAG(db_path) as rag:
                return await rag.delete_document(document_id)
        except Exception:
            logger.exception("delete_document failed for id %r", document_id)
            return False

    return mcp
|
haiku/rag/reader.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from typing import ClassVar
|
|
3
|
+
|
|
4
|
+
from markitdown import MarkItDown
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class FileReader:
    """Convert files to text with MarkItDown."""

    # File extensions listed for this reader. How callers use the list is
    # not visible in this module.
    extensions: ClassVar[list[str]] = [
        ".astro",
        ".c",
        ".cpp",
        ".css",
        ".csv",
        ".docx",
        ".go",
        ".h",
        ".hpp",
        ".html",
        ".java",
        ".js",
        ".json",
        ".kt",
        ".md",
        ".mdx",
        ".mjs",
        ".mp3",
        ".pdf",
        ".php",
        ".pptx",
        ".py",
        ".rb",
        ".rs",
        ".svelte",
        ".swift",
        ".ts",
        ".tsx",
        ".txt",
        ".vue",
        ".wav",
        ".xml",
        ".xlsx",
        ".yaml",
        ".yml",
    ]

    @staticmethod
    def parse_file(path: Path) -> str:
        """Return the text content MarkItDown extracts from ``path``.

        Args:
            path: File to convert.

        Returns:
            The ``text_content`` of the MarkItDown conversion result.

        Raises:
            ValueError: If the conversion fails for any reason. The
                underlying exception is chained as ``__cause__`` so the real
                failure is not lost.
        """
        try:
            return MarkItDown().convert(path).text_content
        except Exception as exc:
            # Keep the public error type and message, but chain the cause.
            raise ValueError(f"Failed to parse file: {path}") from exc
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
import sqlite3
|
|
2
|
+
import struct
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Literal
|
|
5
|
+
|
|
6
|
+
import sqlite_vec
|
|
7
|
+
|
|
8
|
+
from haiku.rag.embeddings import get_embedder
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Store:
|
|
12
|
+
def __init__(self, db_path: Path | Literal[":memory:"]):
|
|
13
|
+
self.db_path: Path | Literal[":memory:"] = db_path
|
|
14
|
+
self._connection = self.create_db()
|
|
15
|
+
|
|
16
|
+
def create_db(self) -> sqlite3.Connection:
|
|
17
|
+
"""Create the database and tables with sqlite-vec support for embeddings."""
|
|
18
|
+
db = sqlite3.connect(self.db_path)
|
|
19
|
+
db.enable_load_extension(True)
|
|
20
|
+
sqlite_vec.load(db)
|
|
21
|
+
|
|
22
|
+
# Create documents table
|
|
23
|
+
db.execute("""
|
|
24
|
+
CREATE TABLE IF NOT EXISTS documents (
|
|
25
|
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
26
|
+
content TEXT NOT NULL,
|
|
27
|
+
uri TEXT,
|
|
28
|
+
metadata TEXT DEFAULT '{}',
|
|
29
|
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
30
|
+
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
|
31
|
+
)
|
|
32
|
+
""")
|
|
33
|
+
|
|
34
|
+
# Create chunks table
|
|
35
|
+
db.execute("""
|
|
36
|
+
CREATE TABLE IF NOT EXISTS chunks (
|
|
37
|
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
38
|
+
document_id INTEGER NOT NULL,
|
|
39
|
+
content TEXT NOT NULL,
|
|
40
|
+
metadata TEXT DEFAULT '{}',
|
|
41
|
+
FOREIGN KEY (document_id) REFERENCES documents (id) ON DELETE CASCADE
|
|
42
|
+
)
|
|
43
|
+
""")
|
|
44
|
+
|
|
45
|
+
# Create vector table for chunk embeddings
|
|
46
|
+
embedder = get_embedder()
|
|
47
|
+
db.execute(f"""
|
|
48
|
+
CREATE VIRTUAL TABLE IF NOT EXISTS chunk_embeddings USING vec0(
|
|
49
|
+
chunk_id INTEGER PRIMARY KEY,
|
|
50
|
+
embedding FLOAT[{embedder._vector_dim}]
|
|
51
|
+
)
|
|
52
|
+
""")
|
|
53
|
+
|
|
54
|
+
# Create FTS5 table for full-text search
|
|
55
|
+
db.execute("""
|
|
56
|
+
CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
|
|
57
|
+
content,
|
|
58
|
+
content='chunks',
|
|
59
|
+
content_rowid='id'
|
|
60
|
+
)
|
|
61
|
+
""")
|
|
62
|
+
|
|
63
|
+
# Create indexes for better performance
|
|
64
|
+
db.execute(
|
|
65
|
+
"CREATE INDEX IF NOT EXISTS idx_chunks_document_id ON chunks(document_id)"
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
db.commit()
|
|
69
|
+
return db
|
|
70
|
+
|
|
71
|
+
@staticmethod
|
|
72
|
+
def serialize_embedding(embedding: list[float]) -> bytes:
|
|
73
|
+
"""Serialize a list of floats to bytes for sqlite-vec storage."""
|
|
74
|
+
return struct.pack(f"{len(embedding)}f", *embedding)
|
|
75
|
+
|
|
76
|
+
def close(self):
|
|
77
|
+
"""Close the database connection if it's an in-memory database."""
|
|
78
|
+
if self._connection is not None:
|
|
79
|
+
self._connection.close()
|
|
80
|
+
self._connection = None
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, Field
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Document(BaseModel):
    """
    Represents a document with an ID, content, and metadata.
    """

    # Identifier; defaults to None when not supplied.
    id: int | None = None
    # Document text (required).
    content: str
    # Optional URI associated with the document.
    uri: str | None = None
    # Free-form metadata; defaults to an empty dict.
    metadata: dict = {}
    # Both timestamps default to datetime.now() evaluated when the model is
    # instantiated (a naive datetime, since no timezone is passed).
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Generic, TypeVar
|
|
3
|
+
|
|
4
|
+
from haiku.rag.store.engine import Store
|
|
5
|
+
|
|
6
|
+
# Type variable standing for the entity type a repository works with.
T = TypeVar("T")
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class BaseRepository(ABC, Generic[T]):
    """Abstract CRUD contract for repositories handling entities of type ``T``."""

    def __init__(self, store: Store):
        # Keep the store available to concrete subclasses.
        self.store = store

    @abstractmethod
    async def create(self, entity: T) -> T:
        """Add ``entity`` to the database as a new record."""
        ...

    @abstractmethod
    async def get_by_id(self, entity_id: int) -> T | None:
        """Look up the entity identified by ``entity_id``."""
        ...

    @abstractmethod
    async def update(self, entity: T) -> T:
        """Apply changes to an entity that already exists."""
        ...

    @abstractmethod
    async def delete(self, entity_id: int) -> bool:
        """Remove the entity identified by ``entity_id``."""
        ...

    @abstractmethod
    async def list_all(
        self,
        limit: int | None = None,
        offset: int | None = None,
    ) -> list[T]:
        """Return all entities, optionally paginated via ``limit`` and ``offset``."""
        ...
|