haiku_rag-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of haiku.rag might be problematic.
- haiku/rag/__init__.py +0 -0
- haiku/rag/app.py +107 -0
- haiku/rag/chunker.py +76 -0
- haiku/rag/cli.py +153 -0
- haiku/rag/client.py +261 -0
- haiku/rag/config.py +28 -0
- haiku/rag/embeddings/__init__.py +24 -0
- haiku/rag/embeddings/base.py +12 -0
- haiku/rag/embeddings/ollama.py +14 -0
- haiku/rag/embeddings/voyageai.py +17 -0
- haiku/rag/mcp.py +141 -0
- haiku/rag/reader.py +52 -0
- haiku/rag/store/__init__.py +4 -0
- haiku/rag/store/engine.py +80 -0
- haiku/rag/store/models/__init__.py +4 -0
- haiku/rag/store/models/chunk.py +12 -0
- haiku/rag/store/models/document.py +16 -0
- haiku/rag/store/repositories/__init__.py +5 -0
- haiku/rag/store/repositories/base.py +40 -0
- haiku/rag/store/repositories/chunk.py +424 -0
- haiku/rag/store/repositories/document.py +210 -0
- haiku/rag/utils.py +25 -0
- haiku_rag-0.1.0.dist-info/METADATA +195 -0
- haiku_rag-0.1.0.dist-info/RECORD +27 -0
- haiku_rag-0.1.0.dist-info/WHEEL +4 -0
- haiku_rag-0.1.0.dist-info/entry_points.txt +2 -0
- haiku_rag-0.1.0.dist-info/licenses/LICENSE +7 -0
haiku/rag/__init__.py
ADDED
File without changes
haiku/rag/app.py
ADDED
@@ -0,0 +1,107 @@
from pathlib import Path

from rich.console import Console
from rich.markdown import Markdown

from haiku.rag.client import HaikuRAG
from haiku.rag.store.models.chunk import Chunk
from haiku.rag.store.models.document import Document


class HaikuRAGApp:
    def __init__(self, db_path: Path):
        self.db_path = db_path
        self.console = Console()

    async def list_documents(self):
        async with HaikuRAG(db_path=self.db_path) as self.client:
            documents = await self.client.list_documents()
            for doc in documents:
                self._rich_print_document(doc, truncate=True)

    async def add_document_from_text(self, text: str):
        async with HaikuRAG(db_path=self.db_path) as self.client:
            doc = await self.client.create_document(text)
            self._rich_print_document(doc, truncate=True)
            self.console.print(
                f"[b]Document with id [cyan]{doc.id}[/cyan] added successfully.[/b]"
            )

    async def add_document_from_source(self, file_path: Path):
        async with HaikuRAG(db_path=self.db_path) as self.client:
            doc = await self.client.create_document_from_source(file_path)
            self._rich_print_document(doc, truncate=True)
            self.console.print(
                f"[b]Document with id [cyan]{doc.id}[/cyan] added successfully.[/b]"
            )

    async def get_document(self, doc_id: int):
        async with HaikuRAG(db_path=self.db_path) as self.client:
            doc = await self.client.get_document_by_id(doc_id)
            if doc is None:
                self.console.print(f"[red]Document with id {doc_id} not found.[/red]")
                return
            self._rich_print_document(doc, truncate=False)

    async def delete_document(self, doc_id: int):
        async with HaikuRAG(db_path=self.db_path) as self.client:
            await self.client.delete_document(doc_id)
            self.console.print(f"[b]Document {doc_id} deleted successfully.[/b]")

    async def search(self, query: str, limit: int = 5, k: int = 60):
        async with HaikuRAG(db_path=self.db_path) as self.client:
            results = await self.client.search(query, limit=limit, k=k)
            if not results:
                self.console.print("[red]No results found.[/red]")
                return
            for chunk, score in results:
                self._rich_print_search_result(chunk, score)

    def _rich_print_document(self, doc: Document, truncate: bool = False):
        """Format a document for display."""
        if truncate:
            content = doc.content.splitlines()
            if len(content) > 3:
                content = content[:3] + ["\n…"]
            content = "\n".join(content)
            content = Markdown(content)
        else:
            content = Markdown(doc.content)
        self.console.print(
            f"[repr.attrib_name]id[/repr.attrib_name]: {doc.id} [repr.attrib_name]uri[/repr.attrib_name]: {doc.uri} [repr.attrib_name]meta[/repr.attrib_name]: {doc.metadata}"
        )
        self.console.print(
            f"[repr.attrib_name]created at[/repr.attrib_name]: {doc.created_at} [repr.attrib_name]updated at[/repr.attrib_name]: {doc.updated_at}"
        )
        self.console.print("[repr.attrib_name]content[/repr.attrib_name]:")
        self.console.print(content)
        self.console.rule()

    def _rich_print_search_result(self, chunk: Chunk, score: float):
        """Format a search result chunk for display."""
        content = Markdown(chunk.content)
        self.console.print(
            f"[repr.attrib_name]document_id[/repr.attrib_name]: {chunk.document_id} "
            f"[repr.attrib_name]score[/repr.attrib_name]: {score:.4f}"
        )
        self.console.print("[repr.attrib_name]content[/repr.attrib_name]:")
        self.console.print(content)
        self.console.rule()

    def serve(self, transport: str | None = None):
        """Start the MCP server."""
        from haiku.rag.mcp import create_mcp_server

        server = create_mcp_server(self.db_path)

        if transport == "stdio":
            self.console.print("[green]Starting MCP server on stdio...[/green]")
            server.run("stdio")
        elif transport == "sse":
            self.console.print("[green]Starting MCP server with SSE transport...[/green]")
            server.run("sse")
        else:
            self.console.print("[green]Starting MCP server with streamable HTTP...[/green]")
            server.run("streamable-http")
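
The app layer is a thin presentation wrapper: every command opens a HaikuRAG client as an async context manager, performs one operation, and renders the result with rich. A minimal sketch of driving it programmatically rather than through the CLI (the database path is illustrative, and indexing assumes the configured embedder is reachable):

# Hypothetical driver script, not part of the package: runs two
# HaikuRAGApp commands by supplying the event loop directly.
import asyncio
from pathlib import Path

from haiku.rag.app import HaikuRAGApp

app = HaikuRAGApp(db_path=Path("/tmp/demo.sqlite"))  # illustrative path
asyncio.run(app.add_document_from_text("haiku.rag stores documents in SQLite."))
asyncio.run(app.list_documents())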
haiku/rag/chunker.py
ADDED
@@ -0,0 +1,76 @@
from typing import ClassVar

import tiktoken

from haiku.rag.config import Config


class Chunker:
    """
    A class that chunks text into smaller pieces for embedding and retrieval.

    Parameters
    ----------
    chunk_size : int
        The maximum size of a chunk in tokens.
    chunk_overlap : int
        The number of tokens of overlap between chunks.
    """

    encoder: ClassVar[tiktoken.Encoding] = tiktoken.encoding_for_model("gpt-4o")

    def __init__(
        self,
        chunk_size: int = Config.CHUNK_SIZE,
        chunk_overlap: int = Config.CHUNK_OVERLAP,
    ):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    async def chunk(self, text: str) -> list[str]:
        """
        Split the text into chunks.

        Parameters
        ----------
        text : str
            The text to be split into chunks.

        Returns
        -------
        list
            A list of text chunks.
        """
        if not text:
            return []

        encoded_tokens = self.encoder.encode(text, disallowed_special=())

        if self.chunk_size > len(encoded_tokens):
            return [text]

        chunks = []
        i = 0
        split_id_counter = 0
        while i < len(encoded_tokens):
            # Window bounds for this chunk
            start_i = i
            end_i = min(i + self.chunk_size, len(encoded_tokens))

            chunk_tokens = encoded_tokens[start_i:end_i]
            chunk_text = self.encoder.decode(chunk_tokens)

            chunks.append(chunk_text)
            split_id_counter += 1

            # Exit loop if this was the last possible chunk
            if end_i == len(encoded_tokens):
                break

            i += self.chunk_size - self.chunk_overlap  # Step forward, considering overlap

        return chunks


chunker = Chunker()
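
The loop above is a fixed-stride sliding window over the token stream: each chunk covers up to chunk_size tokens and the window advances by chunk_size - chunk_overlap, so with the defaults (256/32) consecutive chunks share 32 tokens. A minimal sketch of the same arithmetic, using tiktoken directly (the sample text is illustrative):

# Sketch of the Chunker's sliding-window arithmetic over token indices.
import tiktoken

enc = tiktoken.encoding_for_model("gpt-4o")
tokens = enc.encode("a long document body " * 120, disallowed_special=())

chunk_size, chunk_overlap = 256, 32
stride = chunk_size - chunk_overlap  # window advances 224 tokens per step

windows: list[tuple[int, int]] = []
i = 0
while i < len(tokens):
    end = min(i + chunk_size, len(tokens))
    windows.append((i, end))
    if end == len(tokens):  # last window reached the end of the text
        break
    i += stride

# Each pair is (start_token, end_token); adjacent windows overlap by 32 tokens.
print(windows)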
haiku/rag/cli.py
ADDED
@@ -0,0 +1,153 @@
import asyncio
from pathlib import Path

import typer
from rich.console import Console

from haiku.rag.app import HaikuRAGApp
from haiku.rag.utils import get_default_data_dir

cli = typer.Typer(
    context_settings={"help_option_names": ["-h", "--help"]}, no_args_is_help=True
)

console = Console()
event_loop = asyncio.get_event_loop()


@cli.command("list", help="List all stored documents")
def list_documents(
    db: Path = typer.Option(
        get_default_data_dir() / "haiku.rag.sqlite",
        "--db",
        help="Path to the SQLite database file",
    ),
):
    app = HaikuRAGApp(db_path=db)
    event_loop.run_until_complete(app.list_documents())


@cli.command("add", help="Add a document from text input")
def add_document_text(
    text: str = typer.Argument(
        help="The text content of the document to add",
    ),
    db: Path = typer.Option(
        get_default_data_dir() / "haiku.rag.sqlite",
        "--db",
        help="Path to the SQLite database file",
    ),
):
    app = HaikuRAGApp(db_path=db)
    event_loop.run_until_complete(app.add_document_from_text(text=text))


@cli.command("add-src", help="Add a document from a file path or URL")
def add_document_src(
    file_path: Path = typer.Argument(
        help="The file path or URL of the document to add",
    ),
    db: Path = typer.Option(
        get_default_data_dir() / "haiku.rag.sqlite",
        "--db",
        help="Path to the SQLite database file",
    ),
):
    app = HaikuRAGApp(db_path=db)
    event_loop.run_until_complete(app.add_document_from_source(file_path=file_path))


@cli.command("get", help="Get and display a document by its ID")
def get_document(
    doc_id: int = typer.Argument(
        help="The ID of the document to get",
    ),
    db: Path = typer.Option(
        get_default_data_dir() / "haiku.rag.sqlite",
        "--db",
        help="Path to the SQLite database file",
    ),
):
    app = HaikuRAGApp(db_path=db)
    event_loop.run_until_complete(app.get_document(doc_id=doc_id))


@cli.command("delete", help="Delete a document by its ID")
def delete_document(
    doc_id: int = typer.Argument(
        help="The ID of the document to delete",
    ),
    db: Path = typer.Option(
        get_default_data_dir() / "haiku.rag.sqlite",
        "--db",
        help="Path to the SQLite database file",
    ),
):
    app = HaikuRAGApp(db_path=db)
    event_loop.run_until_complete(app.delete_document(doc_id=doc_id))


@cli.command("search", help="Search for documents by a query")
def search(
    query: str = typer.Argument(
        help="The search query to use",
    ),
    limit: int = typer.Option(
        5,
        "--limit",
        "-l",
        help="Maximum number of results to return",
    ),
    k: int = typer.Option(
        60,
        "--k",
        help="Reciprocal Rank Fusion k parameter",
    ),
    db: Path = typer.Option(
        get_default_data_dir() / "haiku.rag.sqlite",
        "--db",
        help="Path to the SQLite database file",
    ),
):
    app = HaikuRAGApp(db_path=db)
    event_loop.run_until_complete(app.search(query=query, limit=limit, k=k))


@cli.command(
    "serve", help="Start the haiku.rag MCP server (by default in streamable HTTP mode)"
)
def serve(
    db: Path = typer.Option(
        get_default_data_dir() / "haiku.rag.sqlite",
        "--db",
        help="Path to the SQLite database file",
    ),
    stdio: bool = typer.Option(
        False,
        "--stdio",
        help="Run MCP server on stdio transport",
    ),
    sse: bool = typer.Option(
        False,
        "--sse",
        help="Run MCP server on SSE transport",
    ),
) -> None:
    """Start the MCP server."""
    if stdio and sse:
        console.print("[red]Error: Cannot use both --stdio and --sse options[/red]")
        raise typer.Exit(1)

    app = HaikuRAGApp(db_path=db)

    transport = None
    if stdio:
        transport = "stdio"
    elif sse:
        transport = "sse"

    app.serve(transport=transport)


if __name__ == "__main__":
    cli()
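
Since the CLI is a plain typer application, it can also be exercised in-process with typer's test runner instead of a shell. A minimal sketch (the database path is illustrative, and the add/search commands assume the configured embedder is reachable so indexing can complete):

# Sketch: invoking the typer CLI in-process via typer.testing.
from typer.testing import CliRunner

from haiku.rag.cli import cli

runner = CliRunner()
result = runner.invoke(cli, ["add", "First document.", "--db", "/tmp/demo.sqlite"])
print(result.output)
result = runner.invoke(cli, ["search", "document", "--limit", "3", "--db", "/tmp/demo.sqlite"])
print(result.output)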
haiku/rag/client.py
ADDED
@@ -0,0 +1,261 @@
import hashlib
import mimetypes
import tempfile
from pathlib import Path
from typing import Literal
from urllib.parse import urlparse

import httpx

from haiku.rag.config import Config
from haiku.rag.reader import FileReader
from haiku.rag.store.engine import Store
from haiku.rag.store.models.chunk import Chunk
from haiku.rag.store.models.document import Document
from haiku.rag.store.repositories.chunk import ChunkRepository
from haiku.rag.store.repositories.document import DocumentRepository


class HaikuRAG:
    """High-level haiku-rag client."""

    def __init__(
        self,
        db_path: Path | Literal[":memory:"] = Config.DEFAULT_DATA_DIR / "haiku.rag.sqlite",
    ):
        """Initialize the RAG client with a database path."""
        if isinstance(db_path, Path):
            if not db_path.parent.exists():
                Path.mkdir(db_path.parent, parents=True)
        self.store = Store(db_path)
        self.document_repository = DocumentRepository(self.store)
        self.chunk_repository = ChunkRepository(self.store)

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        self.close()
        return False

    async def create_document(
        self, content: str, uri: str | None = None, metadata: dict | None = None
    ) -> Document:
        """Create a new document with optional URI and metadata."""
        document = Document(
            content=content,
            uri=uri,
            metadata=metadata or {},
        )
        return await self.document_repository.create(document)

    async def create_document_from_source(
        self, source: str | Path, metadata: dict = {}
    ) -> Document:
        """Create or update a document from a file path or URL.

        Checks if a document with the same URI already exists:
        - If MD5 is unchanged, returns existing document
        - If MD5 changed, updates the document
        - If no document exists, creates a new one

        Args:
            source: File path (as string or Path) or URL to parse
            metadata: Optional metadata dictionary

        Returns:
            Document instance (created, updated, or existing)

        Raises:
            ValueError: If the file/URL cannot be parsed or doesn't exist
            httpx.RequestError: If URL request fails
        """

        # Check if it's a URL
        source_str = str(source)
        parsed_url = urlparse(source_str)
        if parsed_url.scheme in ("http", "https"):
            return await self._create_or_update_document_from_url(source_str, metadata)

        # Handle as file path
        source_path = Path(source) if isinstance(source, str) else source
        if source_path.suffix.lower() not in FileReader.extensions:
            raise ValueError(f"Unsupported file extension: {source_path.suffix}")

        if not source_path.exists():
            raise ValueError(f"File does not exist: {source_path}")

        uri = str(source_path.resolve())
        md5_hash = hashlib.md5(source_path.read_bytes()).hexdigest()

        # Check if document already exists
        existing_doc = await self.get_document_by_uri(uri)
        if existing_doc and existing_doc.metadata.get("md5") == md5_hash:
            # MD5 unchanged, return existing document
            return existing_doc

        content = FileReader.parse_file(source_path)

        # Get content type from file extension
        content_type, _ = mimetypes.guess_type(str(source_path))
        if not content_type:
            content_type = "application/octet-stream"

        # Merge metadata with contentType and md5
        metadata.update({"contentType": content_type, "md5": md5_hash})

        if existing_doc:
            # Update existing document
            existing_doc.content = content
            existing_doc.metadata = metadata
            return await self.update_document(existing_doc)
        else:
            # Create new document
            return await self.create_document(
                content=content, uri=uri, metadata=metadata
            )

    async def _create_or_update_document_from_url(
        self, url: str, metadata: dict = {}
    ) -> Document:
        """Create or update a document from a URL by downloading and parsing the content.

        Checks if a document with the same URI already exists:
        - If MD5 is unchanged, returns existing document
        - If MD5 changed, updates the document
        - If no document exists, creates a new one

        Args:
            url: URL to download and parse
            metadata: Optional metadata dictionary

        Returns:
            Document instance (created, updated, or existing)

        Raises:
            ValueError: If the content cannot be parsed
            httpx.RequestError: If URL request fails
        """
        async with httpx.AsyncClient() as client:
            response = await client.get(url)
            response.raise_for_status()

            md5_hash = hashlib.md5(response.content).hexdigest()

            # Check if document already exists
            existing_doc = await self.get_document_by_uri(url)
            if existing_doc and existing_doc.metadata.get("md5") == md5_hash:
                # MD5 unchanged, return existing document
                return existing_doc

            # Get content type to determine file extension
            content_type = response.headers.get("content-type", "").lower()
            file_extension = self._get_extension_from_content_type_or_url(
                url, content_type
            )

            if file_extension not in FileReader.extensions:
                raise ValueError(
                    f"Unsupported content type/extension: {content_type}/{file_extension}"
                )

            # Create a temporary file with the appropriate extension
            with tempfile.NamedTemporaryFile(
                mode="wb", suffix=file_extension, delete=False
            ) as temp_file:
                temp_file.write(response.content)
                temp_path = Path(temp_file.name)

            try:
                # Parse the content using FileReader
                content = FileReader.parse_file(temp_path)

                # Merge metadata with contentType and md5
                metadata.update({"contentType": content_type, "md5": md5_hash})

                if existing_doc:
                    existing_doc.content = content
                    existing_doc.metadata = metadata
                    return await self.update_document(existing_doc)
                else:
                    return await self.create_document(
                        content=content, uri=url, metadata=metadata
                    )
            finally:
                # Clean up temporary file
                temp_path.unlink(missing_ok=True)

    def _get_extension_from_content_type_or_url(
        self, url: str, content_type: str
    ) -> str:
        """Determine file extension from content type or URL."""
        # Common content type mappings
        content_type_map = {
            "text/html": ".html",
            "text/plain": ".txt",
            "text/markdown": ".md",
            "application/pdf": ".pdf",
            "application/json": ".json",
            "text/csv": ".csv",
            "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
            "application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
            "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
        }

        # Try content type first
        for ct, ext in content_type_map.items():
            if ct in content_type:
                return ext

        # Try URL extension
        parsed_url = urlparse(url)
        path = Path(parsed_url.path)
        if path.suffix:
            return path.suffix.lower()

        # Default to .html for web content
        return ".html"

    async def get_document_by_id(self, document_id: int) -> Document | None:
        """Get a document by its ID."""
        return await self.document_repository.get_by_id(document_id)

    async def get_document_by_uri(self, uri: str) -> Document | None:
        """Get a document by its URI."""
        return await self.document_repository.get_by_uri(uri)

    async def update_document(self, document: Document) -> Document:
        """Update an existing document."""
        return await self.document_repository.update(document)

    async def delete_document(self, document_id: int) -> bool:
        """Delete a document by its ID."""
        return await self.document_repository.delete(document_id)

    async def list_documents(
        self, limit: int | None = None, offset: int | None = None
    ) -> list[Document]:
        """List all documents with optional pagination."""
        return await self.document_repository.list_all(limit=limit, offset=offset)

    async def search(
        self, query: str, limit: int = 5, k: int = 60
    ) -> list[tuple[Chunk, float]]:
        """Search for relevant chunks using hybrid search (vector similarity + full-text search).

        Args:
            query: The search query string
            limit: Maximum number of results to return
            k: Parameter for Reciprocal Rank Fusion (default: 60)

        Returns:
            List of (chunk, score) tuples ordered by relevance
        """
        return await self.chunk_repository.search_chunks_hybrid(query, limit, k)

    def close(self):
        """Close the underlying store connection."""
        self.store.close()
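
Putting the client API together, a minimal end-to-end sketch, assuming the default Ollama embedder is running locally; ":memory:" keeps the store ephemeral:

# Sketch: using the HaikuRAG client directly as an async context manager.
import asyncio

from haiku.rag.client import HaikuRAG


async def main() -> None:
    async with HaikuRAG(db_path=":memory:") as rag:
        # Creating a document indexes it (chunking + embeddings) via the
        # document repository.
        await rag.create_document(
            "Haiku are three-line poems.", metadata={"source": "demo"}
        )
        # Hybrid search returns (Chunk, score) pairs, best first.
        for chunk, score in await rag.search("poems", limit=3):
            print(f"{score:.4f} {chunk.content[:60]!r}")


asyncio.run(main())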
haiku/rag/config.py
ADDED
@@ -0,0 +1,28 @@
import os
from pathlib import Path

from dotenv import load_dotenv
from pydantic import BaseModel

from haiku.rag.utils import get_default_data_dir

load_dotenv()


class AppConfig(BaseModel):
    ENV: str = "development"

    DEFAULT_DATA_DIR: Path = get_default_data_dir()

    EMBEDDING_PROVIDER: str = "ollama"
    EMBEDDING_MODEL: str = "mxbai-embed-large"
    EMBEDDING_VECTOR_DIM: int = 1024

    CHUNK_SIZE: int = 256
    CHUNK_OVERLAP: int = 32

    OLLAMA_BASE_URL: str = "http://localhost:11434"


# Expose Config object for app to import
Config = AppConfig.model_validate(os.environ)
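
Because Config is built once at import time from os.environ (after load_dotenv), any field can be overridden with an environment variable of the same name, and pydantic coerces the string to the declared field type. A minimal sketch:

# Sketch: overriding configuration through the environment. Variables must
# be set before haiku.rag.config is first imported.
import os

os.environ["EMBEDDING_PROVIDER"] = "ollama"
os.environ["CHUNK_SIZE"] = "512"      # coerced to int by pydantic
os.environ["CHUNK_OVERLAP"] = "64"

from haiku.rag.config import Config

print(Config.CHUNK_SIZE, Config.EMBEDDING_MODEL)  # 512 mxbai-embed-large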
haiku/rag/embeddings/__init__.py
ADDED
@@ -0,0 +1,24 @@
from haiku.rag.config import Config
from haiku.rag.embeddings.base import EmbedderBase
from haiku.rag.embeddings.ollama import Embedder as OllamaEmbedder


def get_embedder() -> EmbedderBase:
    """
    Factory function to get the appropriate embedder based on the configuration.
    """

    if Config.EMBEDDING_PROVIDER == "ollama":
        return OllamaEmbedder(Config.EMBEDDING_MODEL, Config.EMBEDDING_VECTOR_DIM)

    if Config.EMBEDDING_PROVIDER == "voyageai":
        try:
            from haiku.rag.embeddings.voyageai import Embedder as VoyageAIEmbedder
        except ImportError:
            raise ImportError(
                "VoyageAI embedder requires the 'voyageai' package. "
                "Please install haiku.rag with the 'voyageai' extra: "
                "uv pip install haiku.rag --extra voyageai"
            )
        return VoyageAIEmbedder(Config.EMBEDDING_MODEL, Config.EMBEDDING_VECTOR_DIM)
    raise ValueError(f"Unsupported embedding provider: {Config.EMBEDDING_PROVIDER}")
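
A minimal sketch of resolving the configured embedder through this factory and embedding a string (assumes the default Ollama provider is reachable):

# Sketch: factory lookup plus one embedding call.
import asyncio

from haiku.rag.embeddings import get_embedder


async def main() -> None:
    embedder = get_embedder()
    vector = await embedder.embed("hello world")
    print(len(vector))  # matches Config.EMBEDDING_VECTOR_DIM (1024 by default)


asyncio.run(main())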
haiku/rag/embeddings/base.py
ADDED
@@ -0,0 +1,12 @@
class EmbedderBase:
    _model: str = ""
    _vector_dim: int = 0

    def __init__(self, model: str, vector_dim: int):
        self._model = model
        self._vector_dim = vector_dim

    async def embed(self, text: str) -> list[float]:
        raise NotImplementedError(
            "Embedder is an abstract class. Please implement the embed method in a subclass."
        )
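
EmbedderBase is a plain base class rather than an abc, so a provider only needs to honor the async embed contract. A hypothetical subclass, not part of the package, to illustrate that contract:

# Hypothetical subclass for illustration only.
from haiku.rag.embeddings.base import EmbedderBase


class FakeEmbedder(EmbedderBase):
    async def embed(self, text: str) -> list[float]:
        # Toy deterministic vector, padded or truncated to the declared dim.
        values = [float(ord(c) % 10) for c in text]
        return (values + [0.0] * self._vector_dim)[: self._vector_dim]


# FakeEmbedder("toy-model", 8) yields 8-dimensional vectors.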
haiku/rag/embeddings/ollama.py
ADDED
@@ -0,0 +1,14 @@
from ollama import AsyncClient

from haiku.rag.config import Config
from haiku.rag.embeddings.base import EmbedderBase


class Embedder(EmbedderBase):
    _model: str = Config.EMBEDDING_MODEL
    _vector_dim: int = 1024

    async def embed(self, text: str) -> list[float]:
        client = AsyncClient(host=Config.OLLAMA_BASE_URL)
        res = await client.embeddings(model=self._model, prompt=text)
        return list(res["embedding"])