haiku.rag 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of haiku.rag might be problematic. Click here for more details.

haiku/rag/__init__.py ADDED
File without changes
haiku/rag/app.py ADDED
@@ -0,0 +1,107 @@
1
+ from pathlib import Path
2
+
3
+ from rich.console import Console
4
+ from rich.markdown import Markdown
5
+
6
+ from haiku.rag.client import HaikuRAG
7
+ from haiku.rag.store.models.chunk import Chunk
8
+ from haiku.rag.store.models.document import Document
9
+
10
+
11
class HaikuRAGApp:
    """Console front end that runs HaikuRAG operations and prints the results.

    Each coroutine opens its own ``HaikuRAG`` client on ``db_path`` and lets
    the ``async with`` block close it when the operation is done.
    """

    def __init__(self, db_path: Path):
        """Store the database path and create the rich console used for output."""
        self.db_path = db_path
        self.console = Console()

    async def list_documents(self):
        """Print a truncated view of every stored document."""
        async with HaikuRAG(db_path=self.db_path) as client:
            for doc in await client.list_documents():
                self._rich_print_document(doc, truncate=True)

    async def add_document_from_text(self, text: str):
        """Create a document from ``text`` and print it with a confirmation."""
        async with HaikuRAG(db_path=self.db_path) as client:
            doc = await client.create_document(text)
            self._print_added(doc)

    async def add_document_from_source(self, file_path: Path | str):
        """Create or update a document from a file path or URL and print it.

        ``file_path`` is handed to the client unchanged; a URL should be given
        as ``str`` because ``Path`` collapses the ``//`` after the scheme.
        """
        async with HaikuRAG(db_path=self.db_path) as client:
            doc = await client.create_document_from_source(file_path)
            self._print_added(doc)

    async def get_document(self, doc_id: int):
        """Print the full document with id ``doc_id``, or a not-found message."""
        async with HaikuRAG(db_path=self.db_path) as client:
            doc = await client.get_document_by_id(doc_id)
            if doc is None:
                self.console.print(f"[red]Document with id {doc_id} not found.[/red]")
                return
            self._rich_print_document(doc, truncate=False)

    async def delete_document(self, doc_id: int):
        """Delete the document with id ``doc_id`` and report the outcome.

        The client's ``delete_document`` returns a bool; success is only
        reported when it is true.
        """
        async with HaikuRAG(db_path=self.db_path) as client:
            deleted = await client.delete_document(doc_id)
            if deleted:
                self.console.print(f"[b]Document {doc_id} deleted successfully.[/b]")
            else:
                # NOTE(review): a false result is assumed to mean that no
                # document with this id existed - confirm in the repository.
                self.console.print(f"[red]Document with id {doc_id} not found.[/red]")

    async def search(self, query: str, limit: int = 5, k: int = 60):
        """Run a search for ``query`` and print each (chunk, score) result.

        ``limit`` and ``k`` are forwarded to the client's ``search`` method.
        """
        async with HaikuRAG(db_path=self.db_path) as client:
            results = await client.search(query, limit=limit, k=k)
            if not results:
                self.console.print("[red]No results found.[/red]")
                return
            for chunk, score in results:
                self._rich_print_search_result(chunk, score)

    def _print_added(self, doc: Document):
        """Print a truncated view of ``doc`` followed by the success message."""
        self._rich_print_document(doc, truncate=True)
        self.console.print(
            f"[b]Document with id [cyan]{doc.id}[/cyan] added successfully.[/b]"
        )

    def _rich_print_document(self, doc: Document, truncate: bool = False):
        """Format a document for display.

        With ``truncate`` set, only the first three lines of the content are
        shown, followed by an ellipsis marker.
        """
        text = doc.content
        if truncate:
            lines = text.splitlines()
            if len(lines) > 3:
                lines = lines[:3] + ["\n…"]
            text = "\n".join(lines)
        body = Markdown(text)

        self.console.print(
            f"[repr.attrib_name]id[/repr.attrib_name]: {doc.id} "
            f"[repr.attrib_name]uri[/repr.attrib_name]: {doc.uri} "
            f"[repr.attrib_name]meta[/repr.attrib_name]: {doc.metadata}"
        )
        self.console.print(
            f"[repr.attrib_name]created at[/repr.attrib_name]: {doc.created_at} "
            f"[repr.attrib_name]updated at[/repr.attrib_name]: {doc.updated_at}"
        )
        self.console.print("[repr.attrib_name]content[/repr.attrib_name]:")
        self.console.print(body)
        self.console.rule()

    def _rich_print_search_result(self, chunk: Chunk, score: float):
        """Format a search result chunk for display."""
        body = Markdown(chunk.content)
        self.console.print(
            f"[repr.attrib_name]document_id[/repr.attrib_name]: {chunk.document_id} "
            f"[repr.attrib_name]score[/repr.attrib_name]: {score:.4f}"
        )
        self.console.print("[repr.attrib_name]content[/repr.attrib_name]:")
        self.console.print(body)
        self.console.rule()

    def serve(self, transport: str | None = None):
        """Start the MCP server.

        ``transport`` may be ``"stdio"`` or ``"sse"``; any other value
        (including ``None``) runs the server with ``"streamable-http"``.
        """
        # Imported here rather than at module level, as in the original code.
        from haiku.rag.mcp import create_mcp_server

        server = create_mcp_server(self.db_path)

        if transport == "stdio":
            self.console.print("[green]Starting MCP server on stdio...[/green]")
            server.run("stdio")
        elif transport == "sse":
            self.console.print("[green]Starting MCP server with SSE...[/green]")
            server.run("sse")
        else:
            self.console.print(
                "[green]Starting MCP server with streamable HTTP...[/green]"
            )
            server.run("streamable-http")
haiku/rag/chunker.py ADDED
@@ -0,0 +1,76 @@
1
+ from typing import ClassVar
2
+
3
+ import tiktoken
4
+
5
+ from haiku.rag.config import Config
6
+
7
+
8
class Chunker:
    """
    A class that chunks text into smaller pieces for embedding and retrieval.

    Sizes are counted in tokens of the tiktoken encoding used for "gpt-4o",
    not in characters.

    Parameters
    ----------
    chunk_size : int
        The maximum size of a chunk in tokens.
    chunk_overlap : int
        The number of tokens shared by two consecutive chunks.

    Raises
    ------
    ValueError
        If ``chunk_size`` is not positive or ``chunk_overlap`` is not in the
        range ``0 <= chunk_overlap < chunk_size``.
    """

    encoder: ClassVar[tiktoken.Encoding] = tiktoken.encoding_for_model("gpt-4o")

    def __init__(
        self,
        chunk_size: int = Config.CHUNK_SIZE,
        chunk_overlap: int = Config.CHUNK_OVERLAP,
    ):
        self._validate(chunk_size, chunk_overlap)
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    @staticmethod
    def _validate(chunk_size: int, chunk_overlap: int) -> None:
        """Reject settings for which the chunking window could never advance."""
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        if not 0 <= chunk_overlap < chunk_size:
            raise ValueError(
                "chunk_overlap must satisfy 0 <= chunk_overlap < chunk_size, "
                f"got chunk_overlap={chunk_overlap}, chunk_size={chunk_size}"
            )

    async def chunk(self, text: str) -> list[str]:
        """
        Split the text into chunks.

        Parameters
        ----------
        text : str
            The text to be split into chunks.

        Returns
        -------
        list
            A list of text chunks. Empty for empty input; the unchanged text
            as a single chunk when it has fewer tokens than ``chunk_size``.

        Raises
        ------
        ValueError
            If the size settings were changed to invalid values after
            construction.
        """
        if not text:
            return []

        # The attributes are public, so re-check them: a non-positive step
        # would otherwise keep the window from ever reaching the end.
        self._validate(self.chunk_size, self.chunk_overlap)

        tokens = self.encoder.encode(text, disallowed_special=())
        total = len(tokens)

        if self.chunk_size > total:
            return [text]

        step = self.chunk_size - self.chunk_overlap
        chunks = []
        for start in range(0, total, step):
            end = min(start + self.chunk_size, total)
            chunks.append(self.encoder.decode(tokens[start:end]))
            # Stop once a window reaches the end; later windows would only
            # repeat the tail of this one.
            if end == total:
                break
        return chunks


chunker = Chunker()
haiku/rag/cli.py ADDED
@@ -0,0 +1,153 @@
1
+ import asyncio
2
+ from pathlib import Path
3
+
4
+ import typer
5
+ from rich.console import Console
6
+
7
+ from haiku.rag.app import HaikuRAGApp
8
+ from haiku.rag.utils import get_default_data_dir
9
+
10
# Root command group: "-h" works as an alias of "--help", and invoking the
# program without arguments shows the help text.
cli = typer.Typer(
    context_settings={"help_option_names": ["-h", "--help"]}, no_args_is_help=True
)

console = Console()

# Create the loop explicitly. asyncio.get_event_loop() is deprecated when no
# loop is running and raises RuntimeError on the newest Python releases.
event_loop = asyncio.new_event_loop()
16
+
17
+
18
@cli.command("list", help="List all stored documents")
def list_documents(
    db: Path = typer.Option(
        get_default_data_dir() / "haiku.rag.sqlite",
        "--db",
        help="Path to the SQLite database file",
    ),
):
    """Run the app's document listing for the database at ``db``."""
    rag_app = HaikuRAGApp(db_path=db)
    listing = rag_app.list_documents()
    event_loop.run_until_complete(listing)
28
+
29
+
30
@cli.command("add", help="Add a document from text input")
def add_document_text(
    text: str = typer.Argument(
        help="The text content of the document to add",
    ),
    db: Path = typer.Option(
        get_default_data_dir() / "haiku.rag.sqlite",
        "--db",
        help="Path to the SQLite database file",
    ),
):
    """Add ``text`` as a new document to the database at ``db``."""
    rag_app = HaikuRAGApp(db_path=db)
    adding = rag_app.add_document_from_text(text=text)
    event_loop.run_until_complete(adding)
43
+
44
+
45
@cli.command("add-src", help="Add a document from a file path or URL")
def add_document_src(
    # Kept as str on purpose: converting a URL to Path collapses the "//"
    # after the scheme ("https://host/x" becomes "https:/host/x").
    file_path: str = typer.Argument(
        help="The file path or URL of the document to add",
    ),
    db: Path = typer.Option(
        get_default_data_dir() / "haiku.rag.sqlite",
        "--db",
        help="Path to the SQLite database file",
    ),
):
    """Add a document from a file path or URL to the database at ``db``.

    The source string is passed on unchanged to
    ``HaikuRAGApp.add_document_from_source``.
    """
    app = HaikuRAGApp(db_path=db)
    event_loop.run_until_complete(app.add_document_from_source(file_path=file_path))
58
+
59
+
60
@cli.command("get", help="Get and display a document by its ID")
def get_document(
    doc_id: int = typer.Argument(
        help="The ID of the document to get",
    ),
    db: Path = typer.Option(
        get_default_data_dir() / "haiku.rag.sqlite",
        "--db",
        help="Path to the SQLite database file",
    ),
):
    """Look up document ``doc_id`` in the database at ``db`` and display it."""
    rag_app = HaikuRAGApp(db_path=db)
    lookup = rag_app.get_document(doc_id=doc_id)
    event_loop.run_until_complete(lookup)
73
+
74
+
75
@cli.command("delete", help="Delete a document by its ID")
def delete_document(
    doc_id: int = typer.Argument(
        help="The ID of the document to delete",
    ),
    db: Path = typer.Option(
        get_default_data_dir() / "haiku.rag.sqlite",
        "--db",
        help="Path to the SQLite database file",
    ),
):
    """Ask the app to delete document ``doc_id`` from the database at ``db``."""
    rag_app = HaikuRAGApp(db_path=db)
    deletion = rag_app.delete_document(doc_id=doc_id)
    event_loop.run_until_complete(deletion)
88
+
89
+
90
@cli.command("search", help="Search for documents by a query")
def search(
    query: str = typer.Argument(
        help="The search query to use",
    ),
    limit: int = typer.Option(
        5,
        "--limit",
        "-l",
        help="Maximum number of results to return",
    ),
    k: int = typer.Option(
        60,
        "--k",
        help="Reciprocal Rank Fusion k parameter",
    ),
    db: Path = typer.Option(
        get_default_data_dir() / "haiku.rag.sqlite",
        "--db",
        help="Path to the SQLite database file",
    ),
):
    """Search the database at ``db`` for ``query``.

    ``limit`` and ``k`` are forwarded unchanged to ``HaikuRAGApp.search``.
    """
    rag_app = HaikuRAGApp(db_path=db)
    searching = rag_app.search(query=query, limit=limit, k=k)
    event_loop.run_until_complete(searching)
114
+
115
+
116
@cli.command(
    "serve", help="Start the haiku.rag MCP server (by default in streamable HTTP mode)"
)
def serve(
    db: Path = typer.Option(
        get_default_data_dir() / "haiku.rag.sqlite",
        "--db",
        help="Path to the SQLite database file",
    ),
    stdio: bool = typer.Option(
        False,
        "--stdio",
        help="Run MCP server on stdio Transport",
    ),
    sse: bool = typer.Option(
        False,
        "--sse",
        help="Run MCP server on SSE transport",
    ),
) -> None:
    """Start the MCP server.

    ``--stdio`` and ``--sse`` are mutually exclusive; giving both prints an
    error and exits with status 1. With neither flag, ``transport=None`` is
    passed to ``HaikuRAGApp.serve``.
    """
    if stdio and sse:
        # Name the two flags that actually conflict.
        console.print("[red]Error: Cannot use both --stdio and --sse options[/red]")
        raise typer.Exit(1)

    app = HaikuRAGApp(db_path=db)

    if stdio:
        transport = "stdio"
    elif sse:
        transport = "sse"
    else:
        transport = None

    app.serve(transport=transport)
150
+
151
+
152
# Run the command group when this module is executed as a script.
if __name__ == "__main__":
    cli()
haiku/rag/client.py ADDED
@@ -0,0 +1,261 @@
1
+ import hashlib
2
+ import mimetypes
3
+ import tempfile
4
+ from pathlib import Path
5
+ from typing import Literal
6
+ from urllib.parse import urlparse
7
+
8
+ import httpx
9
+
10
+ from haiku.rag.config import Config
11
+ from haiku.rag.reader import FileReader
12
+ from haiku.rag.store.engine import Store
13
+ from haiku.rag.store.models.chunk import Chunk
14
+ from haiku.rag.store.models.document import Document
15
+ from haiku.rag.store.repositories.chunk import ChunkRepository
16
+ from haiku.rag.store.repositories.document import DocumentRepository
17
+
18
+
19
class HaikuRAG:
    """High-level haiku-rag client.

    Owns a ``Store`` together with a document repository and a chunk
    repository built on it. Can be used as an async context manager, which
    closes the store on exit.
    """

    # Content-type fragments and the file extension used to parse them.
    _CONTENT_TYPE_EXTENSIONS = {
        "text/html": ".html",
        "text/plain": ".txt",
        "text/markdown": ".md",
        "application/pdf": ".pdf",
        "application/json": ".json",
        "text/csv": ".csv",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
    }

    def __init__(
        self,
        db_path: Path | Literal[":memory:"] = Config.DEFAULT_DATA_DIR
        / "haiku.rag.sqlite",
    ):
        """Initialize the RAG client with a database path.

        Args:
            db_path: Path of the database file, or ``":memory:"``. For a
                ``Path``, missing parent directories are created.
        """
        if isinstance(db_path, Path):
            # exist_ok avoids failing when the directory appears between an
            # existence check and the creation.
            db_path.parent.mkdir(parents=True, exist_ok=True)
        self.store = Store(db_path)
        self.document_repository = DocumentRepository(self.store)
        self.chunk_repository = ChunkRepository(self.store)

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: close the store, never suppress errors."""
        self.close()
        return False

    async def create_document(
        self, content: str, uri: str | None = None, metadata: dict | None = None
    ) -> Document:
        """Create a new document with optional URI and metadata."""
        document = Document(
            content=content,
            uri=uri,
            metadata=metadata or {},
        )
        return await self.document_repository.create(document)

    async def create_document_from_source(
        self, source: str | Path, metadata: dict | None = None
    ) -> Document:
        """Create or update a document from a file path or URL.

        Checks if a document with the same URI already exists:
        - If MD5 is unchanged, returns existing document
        - If MD5 changed, updates the document
        - If no document exists, creates a new one

        Args:
            source: File path (as string or Path) or URL to parse. Pass URLs
                as ``str``: ``Path`` collapses the ``//`` after the scheme.
            metadata: Optional metadata dictionary. It is copied, so the
                caller's dictionary is never modified.

        Returns:
            Document instance (created, updated, or existing)

        Raises:
            ValueError: If the file/URL cannot be parsed or doesn't exist
            httpx.RequestError: If URL request fails
            httpx.HTTPStatusError: If the URL responds with an error status
        """
        # Work on a copy: the metadata is extended below and must not leak
        # into the caller's dict or into later calls.
        metadata = dict(metadata) if metadata else {}

        # Check if it's a URL
        source_str = str(source)
        if urlparse(source_str).scheme in ("http", "https"):
            return await self._create_or_update_document_from_url(source_str, metadata)

        # Handle as file path
        source_path = Path(source) if isinstance(source, str) else source
        if source_path.suffix.lower() not in FileReader.extensions:
            raise ValueError(f"Unsupported file extension: {source_path.suffix}")

        if not source_path.exists():
            raise ValueError(f"File does not exist: {source_path}")

        uri = str(source_path.resolve())
        md5_hash = hashlib.md5(source_path.read_bytes()).hexdigest()

        # Check if document already exists
        existing_doc = await self.get_document_by_uri(uri)
        if existing_doc and existing_doc.metadata.get("md5") == md5_hash:
            # MD5 unchanged, return existing document
            return existing_doc

        content = FileReader.parse_file(source_path)

        # Get content type from file extension
        content_type, _ = mimetypes.guess_type(str(source_path))
        if not content_type:
            content_type = "application/octet-stream"

        # Merge metadata with contentType and md5
        metadata.update({"contentType": content_type, "md5": md5_hash})

        return await self._save_parsed_document(existing_doc, content, uri, metadata)

    async def _create_or_update_document_from_url(
        self, url: str, metadata: dict | None = None
    ) -> Document:
        """Create or update a document from a URL by downloading and parsing the content.

        Checks if a document with the same URI already exists:
        - If MD5 is unchanged, returns existing document
        - If MD5 changed, updates the document
        - If no document exists, creates a new one

        Args:
            url: URL to download and parse
            metadata: Optional metadata dictionary. It is copied, so the
                caller's dictionary is never modified.

        Returns:
            Document instance (created, updated, or existing)

        Raises:
            ValueError: If the content cannot be parsed
            httpx.RequestError: If URL request fails
            httpx.HTTPStatusError: If the URL responds with an error status
        """
        metadata = dict(metadata) if metadata else {}

        async with httpx.AsyncClient() as client:
            response = await client.get(url)
            response.raise_for_status()

        md5_hash = hashlib.md5(response.content).hexdigest()

        # Check if document already exists
        existing_doc = await self.get_document_by_uri(url)
        if existing_doc and existing_doc.metadata.get("md5") == md5_hash:
            # MD5 unchanged, return existing document
            return existing_doc

        # Get content type to determine file extension
        content_type = response.headers.get("content-type", "").lower()
        file_extension = self._get_extension_from_content_type_or_url(
            url, content_type
        )

        if file_extension not in FileReader.extensions:
            raise ValueError(
                f"Unsupported content type/extension: {content_type}/{file_extension}"
            )

        # FileReader works on files, so spill the download to a temporary
        # file that carries the matching extension.
        with tempfile.NamedTemporaryFile(
            mode="wb", suffix=file_extension, delete=False
        ) as temp_file:
            temp_file.write(response.content)
            temp_path = Path(temp_file.name)

        try:
            # Parse the content using FileReader
            content = FileReader.parse_file(temp_path)

            # Merge metadata with contentType and md5
            metadata.update({"contentType": content_type, "md5": md5_hash})

            return await self._save_parsed_document(
                existing_doc, content, url, metadata
            )
        finally:
            # Clean up temporary file
            temp_path.unlink(missing_ok=True)

    async def _save_parsed_document(
        self,
        existing_doc: Document | None,
        content: str,
        uri: str,
        metadata: dict,
    ) -> Document:
        """Update ``existing_doc`` with new content and metadata, or create a document."""
        if existing_doc:
            existing_doc.content = content
            existing_doc.metadata = metadata
            return await self.update_document(existing_doc)
        return await self.create_document(content=content, uri=uri, metadata=metadata)

    def _get_extension_from_content_type_or_url(
        self, url: str, content_type: str
    ) -> str:
        """Determine file extension from content type or URL.

        The content type wins when it contains one of the known types;
        otherwise the lower-cased suffix of the URL path is used, and
        ``".html"`` is the final fallback.
        """
        # Try content type first (substring match, so parameters such as
        # "; charset=utf-8" do not get in the way)
        for known_type, extension in self._CONTENT_TYPE_EXTENSIONS.items():
            if known_type in content_type:
                return extension

        # Try URL extension
        url_suffix = Path(urlparse(url).path).suffix
        if url_suffix:
            return url_suffix.lower()

        # Default to .html for web content
        return ".html"

    async def get_document_by_id(self, document_id: int) -> Document | None:
        """Get a document by its ID."""
        return await self.document_repository.get_by_id(document_id)

    async def get_document_by_uri(self, uri: str) -> Document | None:
        """Get a document by its URI."""
        return await self.document_repository.get_by_uri(uri)

    async def update_document(self, document: Document) -> Document:
        """Update an existing document."""
        return await self.document_repository.update(document)

    async def delete_document(self, document_id: int) -> bool:
        """Delete a document by its ID."""
        return await self.document_repository.delete(document_id)

    async def list_documents(
        self, limit: int | None = None, offset: int | None = None
    ) -> list[Document]:
        """List all documents with optional pagination."""
        return await self.document_repository.list_all(limit=limit, offset=offset)

    async def search(
        self, query: str, limit: int = 5, k: int = 60
    ) -> list[tuple[Chunk, float]]:
        """Search for relevant chunks using hybrid search (vector similarity + full-text search).

        Args:
            query: The search query string
            limit: Maximum number of results to return
            k: Parameter for Reciprocal Rank Fusion (default: 60)

        Returns:
            List of (chunk, score) tuples ordered by relevance
        """
        return await self.chunk_repository.search_chunks_hybrid(query, limit, k)

    def close(self):
        """Close the underlying store connection."""
        self.store.close()
haiku/rag/config.py ADDED
@@ -0,0 +1,28 @@
1
+ import os
2
+ from pathlib import Path
3
+
4
+ from dotenv import load_dotenv
5
+ from pydantic import BaseModel
6
+
7
+ from haiku.rag.utils import get_default_data_dir
8
+
9
# Load variables from a .env file into the environment before the
# configuration below is built from os.environ.
load_dotenv()


class AppConfig(BaseModel):
    """Application settings and their default values."""

    ENV: str = "development"

    # Directory computed once, at import time, by get_default_data_dir().
    DEFAULT_DATA_DIR: Path = get_default_data_dir()

    # Embedding backend, model name and vector dimension.
    EMBEDDING_PROVIDER: str = "ollama"
    EMBEDDING_MODEL: str = "mxbai-embed-large"
    EMBEDDING_VECTOR_DIM: int = 1024

    # Defaults used by the text chunker.
    CHUNK_SIZE: int = 256
    CHUNK_OVERLAP: int = 32

    OLLAMA_BASE_URL: str = "http://localhost:11434"


# Expose Config object for app to import; values come from the environment,
# falling back to the defaults declared above.
Config = AppConfig.model_validate(os.environ)
@@ -0,0 +1,24 @@
1
+ from haiku.rag.config import Config
2
+ from haiku.rag.embeddings.base import EmbedderBase
3
+ from haiku.rag.embeddings.ollama import Embedder as OllamaEmbedder
4
+
5
+
6
def get_embedder() -> EmbedderBase:
    """
    Factory function to get the appropriate embedder based on the configuration.

    Returns an embedder built from ``Config.EMBEDDING_MODEL`` and
    ``Config.EMBEDDING_VECTOR_DIM`` for the provider named in
    ``Config.EMBEDDING_PROVIDER`` ("ollama" or "voyageai").

    Raises:
        ImportError: If the "voyageai" provider is selected but its embedder
            module cannot be imported.
        ValueError: If the configured provider is not supported.
    """
    provider = Config.EMBEDDING_PROVIDER

    if provider == "ollama":
        return OllamaEmbedder(Config.EMBEDDING_MODEL, Config.EMBEDDING_VECTOR_DIM)

    if provider == "voyageai":
        # Imported lazily so the optional dependency is only needed when
        # this provider is actually selected.
        try:
            from haiku.rag.embeddings.voyageai import Embedder as VoyageAIEmbedder
        except ImportError as exc:
            raise ImportError(
                "VoyageAI embedder requires the 'voyageai' package. "
                "Please install haiku.rag with the 'voyageai' extra: "
                "uv pip install haiku.rag --extra voyageai"
            ) from exc
        return VoyageAIEmbedder(Config.EMBEDDING_MODEL, Config.EMBEDDING_VECTOR_DIM)

    raise ValueError(f"Unsupported embedding provider: {Config.EMBEDDING_PROVIDER}")
@@ -0,0 +1,12 @@
1
class EmbedderBase:
    """Common base for embedders: holds the model name and vector dimension.

    Subclasses are expected to override ``embed``.
    """

    # Class-level fallbacks; __init__ replaces them on each instance.
    _model: str = ""
    _vector_dim: int = 0

    def __init__(self, model: str, vector_dim: int):
        """Record the model name and the embedding vector dimension."""
        self._model, self._vector_dim = model, vector_dim

    async def embed(self, text: str) -> list[float]:
        """Return the embedding of ``text``; the base class always raises.

        Raises:
            NotImplementedError: Always, because this method must be
                overridden by a subclass.
        """
        message = (
            "Embedder is an abstract class. "
            "Please implement the embed method in a subclass."
        )
        raise NotImplementedError(message)
@@ -0,0 +1,14 @@
1
+ from ollama import AsyncClient
2
+
3
+ from haiku.rag.config import Config
4
+ from haiku.rag.embeddings.base import EmbedderBase
5
+
6
+
7
class Embedder(EmbedderBase):
    """Embedder that requests embeddings through the ollama ``AsyncClient``."""

    # Class-level defaults; EmbedderBase.__init__ sets both per instance.
    _model: str = Config.EMBEDDING_MODEL
    _vector_dim: int = 1024

    async def embed(self, text: str) -> list[float]:
        """Return the embedding of ``text`` as a list.

        A new client pointed at ``Config.OLLAMA_BASE_URL`` is created for
        every call.
        """
        ollama_client = AsyncClient(host=Config.OLLAMA_BASE_URL)
        response = await ollama_client.embeddings(model=self._model, prompt=text)
        embedding = response["embedding"]
        return list(embedding)