haiku.rag 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of haiku.rag might be problematic; see the package's registry page for more details.

@@ -0,0 +1,17 @@
import asyncio

try:
    from voyageai.client import Client  # type: ignore

    from haiku.rag.config import Config
    from haiku.rag.embeddings.base import EmbedderBase

    class Embedder(EmbedderBase):
        """Embedder backed by the Voyage AI embeddings API.

        Only defined when the optional ``voyageai`` package (and the project
        imports above) can be imported; otherwise this module defines no
        ``Embedder``.
        """

        _model: str = Config.EMBEDDING_MODEL
        # NOTE(review): fixed at 1024 regardless of Config.EMBEDDING_MODEL, and
        # Store sizes its vec0 embedding column from ``_vector_dim`` — confirm
        # this matches the configured Voyage model's output dimension.
        _vector_dim: int = 1024

        async def embed(self, text: str) -> list[float]:
            """Return the embedding vector for ``text``.

            The Voyage ``Client`` is synchronous, so the request is run in a
            worker thread instead of blocking the event loop.
            """
            client = Client()
            res = await asyncio.to_thread(
                client.embed, [text], model=self._model, output_dtype="float"
            )
            return res.embeddings[0]  # type: ignore[return-value]

except ImportError:
    # voyageai is not installed: leave Embedder undefined.
    # NOTE(review): this also hides ImportErrors raised by the haiku.rag
    # imports inside the try block.
    pass
haiku/rag/mcp.py ADDED
@@ -0,0 +1,141 @@
1
+ from pathlib import Path
2
+ from typing import Any, Literal
3
+
4
+ from fastmcp import FastMCP
5
+ from pydantic import BaseModel
6
+
7
+ from haiku.rag.client import HaikuRAG
8
+
9
+
class SearchResult(BaseModel):
    """A single hit returned by the ``search_documents`` MCP tool."""

    # ID of the document that the matching chunk belongs to.
    document_id: int
    # Text of the matching chunk.
    content: str
    # Relevance score reported by ``HaikuRAG.search`` for this chunk.
    score: float
14
+
15
+
class DocumentResult(BaseModel):
    """A document as returned by the ``get_document`` / ``list_documents`` MCP tools."""

    # Database ID of the document (may be None).
    id: int | None
    # Full text content of the document.
    content: str
    # Optional source URI of the document.
    uri: str | None = None
    # Arbitrary user-supplied metadata.
    metadata: dict[str, Any] = {}
    # Timestamps are rendered with ``str()`` by the tools that build this model.
    created_at: str
    updated_at: str
23
+
24
+
def create_mcp_server(db_path: Path | Literal[":memory:"]) -> FastMCP:
    """Create an MCP server with the specified database path.

    Every tool opens its own ``HaikuRAG`` client on ``db_path`` for the
    duration of the call. Tools are best-effort: on failure they return
    ``None`` / ``[]`` / ``False`` rather than raising, and the exception is
    logged with its traceback so failures remain diagnosable.

    Args:
        db_path: Filesystem path of the SQLite database, or ``":memory:"``.

    Returns:
        The configured ``FastMCP`` server.
    """
    import logging

    logger = logging.getLogger(__name__)
    mcp = FastMCP("haiku-rag")

    def _to_document_result(document) -> DocumentResult:
        """Convert a document returned by ``HaikuRAG`` into a ``DocumentResult``."""
        return DocumentResult(
            id=document.id,
            content=document.content,
            uri=document.uri,
            metadata=document.metadata,
            created_at=str(document.created_at),
            updated_at=str(document.updated_at),
        )

    @mcp.tool()
    async def add_document_from_file(
        file_path: str, metadata: dict[str, Any] | None = None
    ) -> int | None:
        """Add a document to the RAG system from a file path."""
        try:
            async with HaikuRAG(db_path) as rag:
                document = await rag.create_document_from_source(
                    Path(file_path), metadata or {}
                )
                return document.id
        except Exception:
            logger.exception("Failed to add document from file %s", file_path)
            return None

    @mcp.tool()
    async def add_document_from_url(
        url: str, metadata: dict[str, Any] | None = None
    ) -> int | None:
        """Add a document to the RAG system from a URL."""
        try:
            async with HaikuRAG(db_path) as rag:
                document = await rag.create_document_from_source(url, metadata or {})
                return document.id
        except Exception:
            logger.exception("Failed to add document from URL %s", url)
            return None

    @mcp.tool()
    async def add_document_from_text(
        content: str, uri: str | None = None, metadata: dict[str, Any] | None = None
    ) -> int | None:
        """Add a document to the RAG system from text content."""
        try:
            async with HaikuRAG(db_path) as rag:
                document = await rag.create_document(content, uri, metadata or {})
                return document.id
        except Exception:
            logger.exception("Failed to add document from text (uri=%s)", uri)
            return None

    @mcp.tool()
    async def search_documents(query: str, limit: int = 5) -> list[SearchResult]:
        """Search the RAG system for documents using hybrid search (vector similarity + full-text search)."""
        try:
            async with HaikuRAG(db_path) as rag:
                results = await rag.search(query, limit)
                return [
                    SearchResult(
                        document_id=chunk.document_id,
                        content=chunk.content,
                        score=score,
                    )
                    for chunk, score in results
                ]
        except Exception:
            logger.exception("Search failed for query %r", query)
            return []

    @mcp.tool()
    async def get_document(document_id: int) -> DocumentResult | None:
        """Get a document by its ID."""
        try:
            async with HaikuRAG(db_path) as rag:
                document = await rag.get_document_by_id(document_id)
                if document is None:
                    return None
                return _to_document_result(document)
        except Exception:
            logger.exception("Failed to get document %s", document_id)
            return None

    @mcp.tool()
    async def list_documents(
        limit: int | None = None, offset: int | None = None
    ) -> list[DocumentResult]:
        """List all documents with optional pagination."""
        try:
            async with HaikuRAG(db_path) as rag:
                documents = await rag.list_documents(limit, offset)
                return [_to_document_result(doc) for doc in documents]
        except Exception:
            logger.exception(
                "Failed to list documents (limit=%s, offset=%s)", limit, offset
            )
            return []

    @mcp.tool()
    async def delete_document(document_id: int) -> bool:
        """Delete a document by its ID."""
        try:
            async with HaikuRAG(db_path) as rag:
                return await rag.delete_document(document_id)
        except Exception:
            logger.exception("Failed to delete document %s", document_id)
            return False

    return mcp
haiku/rag/reader.py ADDED
@@ -0,0 +1,52 @@
1
+ from pathlib import Path
2
+ from typing import ClassVar
3
+
4
+ from markitdown import MarkItDown
5
+
6
+
class FileReader:
    """Convert files to text using MarkItDown."""

    # File extensions (lower-case, with leading dot) this reader is intended
    # for; presumably checked by callers before calling parse_file.
    extensions: ClassVar[list[str]] = [
        ".astro",
        ".c",
        ".cpp",
        ".css",
        ".csv",
        ".docx",
        ".go",
        ".h",
        ".hpp",
        ".html",
        ".java",
        ".js",
        ".json",
        ".kt",
        ".md",
        ".mdx",
        ".mjs",
        ".mp3",
        ".pdf",
        ".php",
        ".pptx",
        ".py",
        ".rb",
        ".rs",
        ".svelte",
        ".swift",
        ".ts",
        ".tsx",
        ".txt",
        ".vue",
        ".wav",
        ".xml",
        ".xlsx",
        ".yaml",
        ".yml",
    ]

    @staticmethod
    def parse_file(path: Path) -> str:
        """Return the text content MarkItDown extracts from ``path``.

        Raises:
            ValueError: If conversion fails for any reason. The underlying
                exception is chained as ``__cause__`` so the real error is
                not lost.
        """
        try:
            reader = MarkItDown()
            return reader.convert(path).text_content
        except Exception as e:
            raise ValueError(f"Failed to parse file: {path}") from e
@@ -0,0 +1,4 @@
1
+ from .engine import Store
2
+ from .models import Chunk, Document
3
+
4
+ __all__ = ["Store", "Chunk", "Document"]
@@ -0,0 +1,80 @@
1
+ import sqlite3
2
+ import struct
3
+ from pathlib import Path
4
+ from typing import Literal
5
+
6
+ import sqlite_vec
7
+
8
+ from haiku.rag.embeddings import get_embedder
9
+
10
+
11
+ class Store:
12
+ def __init__(self, db_path: Path | Literal[":memory:"]):
13
+ self.db_path: Path | Literal[":memory:"] = db_path
14
+ self._connection = self.create_db()
15
+
16
+ def create_db(self) -> sqlite3.Connection:
17
+ """Create the database and tables with sqlite-vec support for embeddings."""
18
+ db = sqlite3.connect(self.db_path)
19
+ db.enable_load_extension(True)
20
+ sqlite_vec.load(db)
21
+
22
+ # Create documents table
23
+ db.execute("""
24
+ CREATE TABLE IF NOT EXISTS documents (
25
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
26
+ content TEXT NOT NULL,
27
+ uri TEXT,
28
+ metadata TEXT DEFAULT '{}',
29
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
30
+ updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
31
+ )
32
+ """)
33
+
34
+ # Create chunks table
35
+ db.execute("""
36
+ CREATE TABLE IF NOT EXISTS chunks (
37
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
38
+ document_id INTEGER NOT NULL,
39
+ content TEXT NOT NULL,
40
+ metadata TEXT DEFAULT '{}',
41
+ FOREIGN KEY (document_id) REFERENCES documents (id) ON DELETE CASCADE
42
+ )
43
+ """)
44
+
45
+ # Create vector table for chunk embeddings
46
+ embedder = get_embedder()
47
+ db.execute(f"""
48
+ CREATE VIRTUAL TABLE IF NOT EXISTS chunk_embeddings USING vec0(
49
+ chunk_id INTEGER PRIMARY KEY,
50
+ embedding FLOAT[{embedder._vector_dim}]
51
+ )
52
+ """)
53
+
54
+ # Create FTS5 table for full-text search
55
+ db.execute("""
56
+ CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
57
+ content,
58
+ content='chunks',
59
+ content_rowid='id'
60
+ )
61
+ """)
62
+
63
+ # Create indexes for better performance
64
+ db.execute(
65
+ "CREATE INDEX IF NOT EXISTS idx_chunks_document_id ON chunks(document_id)"
66
+ )
67
+
68
+ db.commit()
69
+ return db
70
+
71
+ @staticmethod
72
+ def serialize_embedding(embedding: list[float]) -> bytes:
73
+ """Serialize a list of floats to bytes for sqlite-vec storage."""
74
+ return struct.pack(f"{len(embedding)}f", *embedding)
75
+
76
+ def close(self):
77
+ """Close the database connection if it's an in-memory database."""
78
+ if self._connection is not None:
79
+ self._connection.close()
80
+ self._connection = None
@@ -0,0 +1,4 @@
1
+ from .chunk import Chunk
2
+ from .document import Document
3
+
4
+ __all__ = ["Chunk", "Document"]
@@ -0,0 +1,12 @@
1
+ from pydantic import BaseModel
2
+
3
+
class Chunk(BaseModel):
    """
    Represents a chunk of a document: its ID, the parent document's ID, content, and metadata.
    """

    # Database ID; None for a chunk that has not been stored yet (presumably
    # assigned by the chunks table's AUTOINCREMENT key).
    id: int | None = None
    # ID of the document this chunk was taken from.
    document_id: int
    # Text of the chunk.
    content: str
    # Free-form metadata. Pydantic copies mutable defaults per instance, so the
    # {} literal is not shared between chunks.
    metadata: dict = {}
@@ -0,0 +1,16 @@
1
+ from datetime import datetime
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
+
class Document(BaseModel):
    """
    A source document: its ID, text content, optional URI, metadata, and timestamps.

    Both timestamps default to the moment the model instance is created.
    """

    # Database ID; None for a document that has not been stored yet (presumably).
    id: int | None = None
    content: str
    uri: str | None = None
    metadata: dict = {}
    # NOTE(review): datetime.now yields naive local time, whereas the SQL
    # schema's CURRENT_TIMESTAMP default is UTC — confirm the intended convention.
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)
@@ -0,0 +1,5 @@
1
+ from haiku.rag.store.repositories.base import BaseRepository
2
+ from haiku.rag.store.repositories.chunk import ChunkRepository
3
+ from haiku.rag.store.repositories.document import DocumentRepository
4
+
5
+ __all__ = ["BaseRepository", "DocumentRepository", "ChunkRepository"]
@@ -0,0 +1,40 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Generic, TypeVar
3
+
4
+ from haiku.rag.store.engine import Store
5
+
6
+ T = TypeVar("T")
7
+
8
+
class BaseRepository(ABC, Generic[T]):
    """Abstract async CRUD interface over a ``Store``, generic in the entity type ``T``."""

    def __init__(self, store: Store):
        self.store = store

    @abstractmethod
    async def create(self, entity: T) -> T:
        """Insert ``entity`` into the database and return the created entity."""
        ...

    @abstractmethod
    async def get_by_id(self, entity_id: int) -> T | None:
        """Return the entity with ID ``entity_id``, or None if there is none."""
        ...

    @abstractmethod
    async def update(self, entity: T) -> T:
        """Persist changes to an existing ``entity`` and return it."""
        ...

    @abstractmethod
    async def delete(self, entity_id: int) -> bool:
        """Delete the entity with ID ``entity_id``; report whether it succeeded."""
        ...

    @abstractmethod
    async def list_all(
        self, limit: int | None = None, offset: int | None = None
    ) -> list[T]:
        """Return entities, optionally paginated with ``limit`` / ``offset``."""
        ...