haiku.rag 0.12.0__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff shows the changes between publicly available package versions as published to one of the supported registries. It is provided for informational purposes only.
Note: this version of haiku.rag was flagged as a potentially problematic release.
- haiku/rag/a2a/__init__.py +3 -3
- haiku/rag/a2a/client.py +52 -55
- haiku/rag/app.py +19 -10
- haiku/rag/chunker.py +1 -1
- haiku/rag/cli.py +74 -33
- haiku/rag/client.py +83 -14
- haiku/rag/config/__init__.py +54 -0
- haiku/rag/config/loader.py +151 -0
- haiku/rag/config/models.py +78 -0
- haiku/rag/embeddings/__init__.py +17 -11
- haiku/rag/embeddings/base.py +10 -2
- haiku/rag/embeddings/ollama.py +11 -1
- haiku/rag/embeddings/openai.py +8 -0
- haiku/rag/embeddings/vllm.py +9 -1
- haiku/rag/embeddings/voyageai.py +8 -0
- haiku/rag/graph/common.py +2 -2
- haiku/rag/mcp.py +14 -8
- haiku/rag/monitor.py +17 -4
- haiku/rag/qa/__init__.py +16 -3
- haiku/rag/qa/agent.py +4 -2
- haiku/rag/reranking/__init__.py +24 -16
- haiku/rag/reranking/base.py +1 -1
- haiku/rag/reranking/cohere.py +2 -2
- haiku/rag/reranking/mxbai.py +1 -1
- haiku/rag/reranking/vllm.py +1 -1
- haiku/rag/store/engine.py +19 -12
- haiku/rag/store/repositories/chunk.py +12 -8
- haiku/rag/store/repositories/document.py +4 -4
- haiku/rag/store/repositories/settings.py +19 -9
- haiku/rag/utils.py +9 -9
- {haiku_rag-0.12.0.dist-info → haiku_rag-0.13.0.dist-info}/METADATA +21 -11
- {haiku_rag-0.12.0.dist-info → haiku_rag-0.13.0.dist-info}/RECORD +35 -34
- haiku/rag/config.py +0 -90
- haiku/rag/migration.py +0 -316
- {haiku_rag-0.12.0.dist-info → haiku_rag-0.13.0.dist-info}/WHEEL +0 -0
- {haiku_rag-0.12.0.dist-info → haiku_rag-0.13.0.dist-info}/entry_points.txt +0 -0
- {haiku_rag-0.12.0.dist-info → haiku_rag-0.13.0.dist-info}/licenses/LICENSE +0 -0
haiku/rag/client.py
CHANGED
@@ -8,8 +8,7 @@ from urllib.parse import urlparse

 import httpx

-from haiku.rag.config import Config
-from haiku.rag.reader import FileReader
+from haiku.rag.config import AppConfig, Config
 from haiku.rag.reranking import get_reranker
 from haiku.rag.store.engine import Store
 from haiku.rag.store.models.chunk import Chunk
@@ -17,7 +16,6 @@ from haiku.rag.store.models.document import Document
 from haiku.rag.store.repositories.chunk import ChunkRepository
 from haiku.rag.store.repositories.document import DocumentRepository
 from haiku.rag.store.repositories.settings import SettingsRepository
-from haiku.rag.utils import text_to_docling_document

 logger = logging.getLogger(__name__)

@@ -27,16 +25,23 @@ class HaikuRAG:

     def __init__(
         self,
-        db_path: Path
+        db_path: Path | None = None,
+        config: AppConfig = Config,
         skip_validation: bool = False,
     ):
         """Initialize the RAG client with a database path.

         Args:
-            db_path: Path to the database file.
+            db_path: Path to the database file. If None, uses config.storage.data_dir.
+            config: Configuration to use. Defaults to global Config.
             skip_validation: Whether to skip configuration validation on database load.
         """
-        self.
+        self._config = config
+        if db_path is None:
+            db_path = self._config.storage.data_dir / "haiku.rag.lancedb"
+        self.store = Store(
+            db_path, config=self._config, skip_validation=skip_validation
+        )
         self.document_repository = DocumentRepository(self.store)
         self.chunk_repository = ChunkRepository(self.store)

@@ -91,6 +96,9 @@ class HaikuRAG:
         Returns:
             The created Document instance.
         """
+        # Lazy import to avoid loading docling
+        from haiku.rag.utils import text_to_docling_document
+
         # Convert content to DoclingDocument for processing
         docling_document = text_to_docling_document(content)

@@ -106,8 +114,8 @@ class HaikuRAG:

     async def create_document_from_source(
         self, source: str | Path, title: str | None = None, metadata: dict | None = None
-    ) -> Document:
-        """Create or update
+    ) -> Document | list[Document]:
+        """Create or update document(s) from a file path, directory, or URL.

         Checks if a document with the same URI already exists:
         - If MD5 is unchanged, returns existing document
@@ -115,16 +123,20 @@ class HaikuRAG:
         - If no document exists, creates a new one

         Args:
-            source: File path (as string or Path) or URL to parse
+            source: File path, directory (as string or Path), or URL to parse
+            title: Optional title (only used for single files, not directories)
             metadata: Optional metadata dictionary

         Returns:
-            Document instance (created, updated, or existing)
+            Document instance (created, updated, or existing) for single files/URLs
+            List of Document instances for directories

         Raises:
             ValueError: If the file/URL cannot be parsed or doesn't exist
             httpx.RequestError: If URL request fails
         """
+        # Lazy import to avoid loading docling
+        from haiku.rag.reader import FileReader

         # Normalize metadata
         metadata = metadata or {}
@@ -142,6 +154,48 @@ class HaikuRAG:
         else:
             # Handle as regular file path
             source_path = Path(source) if isinstance(source, str) else source
+
+            # Handle directories
+            if source_path.is_dir():
+                documents = []
+                supported_extensions = set(FileReader.extensions)
+                for file_path in source_path.rglob("*"):
+                    if (
+                        file_path.is_file()
+                        and file_path.suffix.lower() in supported_extensions
+                    ):
+                        doc = await self._create_document_from_file(
+                            file_path, title=None, metadata=metadata
+                        )
+                        documents.append(doc)
+                return documents
+
+            # Handle single file
+            return await self._create_document_from_file(
+                source_path, title=title, metadata=metadata
+            )
+
+    async def _create_document_from_file(
+        self, source_path: Path, title: str | None = None, metadata: dict | None = None
+    ) -> Document:
+        """Create or update a document from a single file path.
+
+        Args:
+            source_path: Path to the file
+            title: Optional title
+            metadata: Optional metadata dictionary
+
+        Returns:
+            Document instance (created, updated, or existing)
+
+        Raises:
+            ValueError: If the file cannot be parsed or doesn't exist
+        """
+        # Lazy import to avoid loading docling
+        from haiku.rag.reader import FileReader
+
+        metadata = metadata or {}
+
         if source_path.suffix.lower() not in FileReader.extensions:
             raise ValueError(f"Unsupported file extension: {source_path.suffix}")

@@ -215,6 +269,9 @@ class HaikuRAG:
             ValueError: If the content cannot be parsed
             httpx.RequestError: If URL request fails
         """
+        # Lazy import to avoid loading docling
+        from haiku.rag.reader import FileReader
+
         metadata = metadata or {}

         async with httpx.AsyncClient() as client:
@@ -338,6 +395,9 @@ class HaikuRAG:

     async def update_document(self, document: Document) -> Document:
         """Update an existing document."""
+        # Lazy import to avoid loading docling
+        from haiku.rag.utils import text_to_docling_document
+
         # Convert content to DoclingDocument
         docling_document = text_to_docling_document(document.content)

@@ -377,7 +437,7 @@ class HaikuRAG:
             List of (chunk, score) tuples ordered by relevance.
         """
         # Get reranker if available
-        reranker = get_reranker()
+        reranker = get_reranker(config=self._config)

         if reranker is None:
             # No reranking - return direct search results
@@ -399,18 +459,20 @@ class HaikuRAG:
     async def expand_context(
         self,
         search_results: list[tuple[Chunk, float]],
-        radius: int =
+        radius: int | None = None,
     ) -> list[tuple[Chunk, float]]:
         """Expand search results with adjacent chunks, merging overlapping chunks.

         Args:
             search_results: List of (chunk, score) tuples from search.
             radius: Number of adjacent chunks to include before/after each chunk.
-
+                If None, uses config.processing.context_chunk_radius.

         Returns:
             List of (chunk, score) tuples with expanded and merged context chunks.
         """
+        if radius is None:
+            radius = self._config.processing.context_chunk_radius
         if radius == 0:
             return search_results

@@ -540,7 +602,9 @@ class HaikuRAG:
         """
         from haiku.rag.qa import get_qa_agent

-        qa_agent = get_qa_agent(
+        qa_agent = get_qa_agent(
+            self, config=self._config, use_citations=cite, system_prompt=system_prompt
+        )
         return await qa_agent.answer(question)

     async def rebuild_database(self) -> AsyncGenerator[str, None]:
@@ -556,6 +620,9 @@ class HaikuRAG:
         Yields:
             int: The ID of the document currently being processed
         """
+        # Lazy import to avoid loading docling
+        from haiku.rag.utils import text_to_docling_document
+
         await self.chunk_repository.delete_all()
         self.store.recreate_embeddings_table()

@@ -592,6 +659,8 @@ class HaikuRAG:
             new_doc = await self.create_document_from_source(
                 source=doc.uri, metadata=doc.metadata or {}
             )
+            # URIs always point to single files/URLs, never directories
+            assert isinstance(new_doc, Document)
             assert new_doc.id is not None, (
                 "New document ID should not be None"
             )
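Taken together, these changes make every HaikuRAG constructor argument optional and config-driven, and widen create_document_from_source to handle directories. A minimal usage sketch of the new signatures, assuming direct construction as shown in the diff's __init__ (the file and directory paths are hypothetical examples):

import asyncio

from haiku.rag.client import HaikuRAG
from haiku.rag.config import AppConfig

async def main() -> None:
    # Override only what differs from the defaults; pydantic fills in the rest.
    config = AppConfig.model_validate({"processing": {"context_chunk_radius": 2}})

    # db_path may now be omitted: the client derives
    # <config.storage.data_dir>/haiku.rag.lancedb itself.
    client = HaikuRAG(config=config)

    # A single file (or URL) returns one Document...
    doc = await client.create_document_from_source("notes/meeting.md")

    # ...while a directory returns a list, one Document per supported file.
    docs = await client.create_document_from_source("notes/")

asyncio.run(main())

Because the return type is now Document | list[Document], callers that previously assumed a single Document need a check, which is exactly what the new assert in rebuild_database above does.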
haiku/rag/config/__init__.py
ADDED
@@ -0,0 +1,54 @@
+import os
+
+from haiku.rag.config.loader import (
+    check_for_deprecated_env,
+    find_config_file,
+    generate_default_config,
+    load_config_from_env,
+    load_yaml_config,
+)
+from haiku.rag.config.models import (
+    A2AConfig,
+    AppConfig,
+    EmbeddingsConfig,
+    LanceDBConfig,
+    OllamaConfig,
+    ProcessingConfig,
+    ProvidersConfig,
+    QAConfig,
+    RerankingConfig,
+    ResearchConfig,
+    StorageConfig,
+    VLLMConfig,
+)
+
+__all__ = [
+    "Config",
+    "AppConfig",
+    "StorageConfig",
+    "LanceDBConfig",
+    "EmbeddingsConfig",
+    "RerankingConfig",
+    "QAConfig",
+    "ResearchConfig",
+    "ProcessingConfig",
+    "OllamaConfig",
+    "VLLMConfig",
+    "ProvidersConfig",
+    "A2AConfig",
+    "find_config_file",
+    "load_yaml_config",
+    "generate_default_config",
+    "load_config_from_env",
+]
+
+# Load config from YAML file or use defaults
+config_path = find_config_file(None)
+if config_path:
+    yaml_data = load_yaml_config(config_path)
+    Config = AppConfig.model_validate(yaml_data)
+else:
+    Config = AppConfig()
+
+# Check for deprecated .env file
+check_for_deprecated_env()
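The global Config is now resolved once, at import time, via the discovery order implemented in loader.py (shown next). A small sketch of what that implies for callers, assuming the environment variable is set before the first import anywhere in the process (the path is a hypothetical example):

import os

# Must happen before haiku.rag.config is first imported.
os.environ["HAIKU_RAG_CONFIG_PATH"] = "/etc/haiku-rag/config.yaml"

from haiku.rag.config import Config  # triggers find_config_file() + validation

print(Config.embeddings.provider)  # values from the YAML file; defaults elsewhere

Note that find_config_file raises FileNotFoundError when the pointed-to file does not exist, so a typo in the variable fails loudly instead of silently falling back to defaults.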
haiku/rag/config/loader.py
ADDED
@@ -0,0 +1,151 @@
+import os
+import warnings
+from pathlib import Path
+
+import yaml
+
+
+def find_config_file(cli_path: Path | None = None) -> Path | None:
+    """Find the YAML config file using the search path.
+
+    Search order:
+    1. CLI-provided path (via HAIKU_RAG_CONFIG_PATH env var or parameter)
+    2. ./haiku.rag.yaml (current directory)
+    3. ~/.config/haiku.rag/config.yaml (user config)
+
+    Returns None if no config file is found.
+    """
+    # Check environment variable first (set by CLI --config flag)
+    if not cli_path:
+        env_path = os.getenv("HAIKU_RAG_CONFIG_PATH")
+        if env_path:
+            cli_path = Path(env_path)
+
+    if cli_path:
+        if cli_path.exists():
+            return cli_path
+        raise FileNotFoundError(f"Config file not found: {cli_path}")
+
+    cwd_config = Path.cwd() / "haiku.rag.yaml"
+    if cwd_config.exists():
+        return cwd_config
+
+    user_config_dir = Path.home() / ".config" / "haiku.rag"
+    user_config = user_config_dir / "config.yaml"
+    if user_config.exists():
+        return user_config
+
+    return None
+
+
+def load_yaml_config(path: Path) -> dict:
+    """Load and parse a YAML config file."""
+    with open(path) as f:
+        data = yaml.safe_load(f)
+    return data or {}
+
+
+def check_for_deprecated_env() -> None:
+    """Check for .env file and warn if found."""
+    env_file = Path.cwd() / ".env"
+    if env_file.exists():
+        warnings.warn(
+            ".env file detected but YAML configuration is now preferred. "
+            "Environment variable configuration is deprecated and will be removed in future versions."
+            "Run 'haiku-rag init-config' to generate a YAML config file.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+
+def generate_default_config() -> dict:
+    """Generate a default YAML config structure with documentation."""
+    return {
+        "environment": "production",
+        "storage": {
+            "data_dir": "",
+            "monitor_directories": [],
+            "disable_autocreate": False,
+            "vacuum_retention_seconds": 60,
+        },
+        "lancedb": {"uri": "", "api_key": "", "region": ""},
+        "embeddings": {
+            "provider": "ollama",
+            "model": "qwen3-embedding",
+            "vector_dim": 4096,
+        },
+        "reranking": {"provider": "", "model": ""},
+        "qa": {"provider": "ollama", "model": "gpt-oss"},
+        "research": {"provider": "", "model": ""},
+        "processing": {
+            "chunk_size": 256,
+            "context_chunk_radius": 0,
+            "markdown_preprocessor": "",
+        },
+        "providers": {
+            "ollama": {"base_url": "http://localhost:11434"},
+            "vllm": {
+                "embeddings_base_url": "",
+                "rerank_base_url": "",
+                "qa_base_url": "",
+                "research_base_url": "",
+            },
+        },
+        "a2a": {"max_contexts": 1000},
+    }
+
+
+def load_config_from_env() -> dict:
+    """Load current config from environment variables (for migration)."""
+    result = {}
+
+    env_mappings = {
+        "ENV": "environment",
+        "DEFAULT_DATA_DIR": ("storage", "data_dir"),
+        "MONITOR_DIRECTORIES": ("storage", "monitor_directories"),
+        "DISABLE_DB_AUTOCREATE": ("storage", "disable_autocreate"),
+        "VACUUM_RETENTION_SECONDS": ("storage", "vacuum_retention_seconds"),
+        "LANCEDB_URI": ("lancedb", "uri"),
+        "LANCEDB_API_KEY": ("lancedb", "api_key"),
+        "LANCEDB_REGION": ("lancedb", "region"),
+        "EMBEDDINGS_PROVIDER": ("embeddings", "provider"),
+        "EMBEDDINGS_MODEL": ("embeddings", "model"),
+        "EMBEDDINGS_VECTOR_DIM": ("embeddings", "vector_dim"),
+        "RERANK_PROVIDER": ("reranking", "provider"),
+        "RERANK_MODEL": ("reranking", "model"),
+        "QA_PROVIDER": ("qa", "provider"),
+        "QA_MODEL": ("qa", "model"),
+        "RESEARCH_PROVIDER": ("research", "provider"),
+        "RESEARCH_MODEL": ("research", "model"),
+        "CHUNK_SIZE": ("processing", "chunk_size"),
+        "CONTEXT_CHUNK_RADIUS": ("processing", "context_chunk_radius"),
+        "MARKDOWN_PREPROCESSOR": ("processing", "markdown_preprocessor"),
+        "OLLAMA_BASE_URL": ("providers", "ollama", "base_url"),
+        "VLLM_EMBEDDINGS_BASE_URL": ("providers", "vllm", "embeddings_base_url"),
+        "VLLM_RERANK_BASE_URL": ("providers", "vllm", "rerank_base_url"),
+        "VLLM_QA_BASE_URL": ("providers", "vllm", "qa_base_url"),
+        "VLLM_RESEARCH_BASE_URL": ("providers", "vllm", "research_base_url"),
+        "A2A_MAX_CONTEXTS": ("a2a", "max_contexts"),
+    }
+
+    for env_var, path in env_mappings.items():
+        value = os.getenv(env_var)
+        if value is not None:
+            # Special handling for MONITOR_DIRECTORIES - parse comma-separated list
+            if env_var == "MONITOR_DIRECTORIES":
+                if value.strip():
+                    value = [p.strip() for p in value.split(",") if p.strip()]
+                else:
+                    value = []
+
+            if isinstance(path, tuple):
+                current = result
+                for key in path[:-1]:
+                    if key not in current:
+                        current[key] = {}
+                    current = current[key]
+                current[path[-1]] = value
+            else:
+                result[path] = value
+
+    return result
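generate_default_config and load_config_from_env together support migrating off the removed .env configuration (note that config.py and migration.py are deleted in this release). A minimal sketch of writing a starter haiku.rag.yaml from these helpers, assuming PyYAML and using a shallow merge for brevity (nested overrides would need a recursive merge):

import yaml

from haiku.rag.config.loader import generate_default_config, load_config_from_env

config = generate_default_config()
config.update(load_config_from_env())  # shallow: whole sections from env win

with open("haiku.rag.yaml", "w") as f:
    yaml.safe_dump(config, f, sort_keys=False)

The 'haiku-rag init-config' command mentioned in the deprecation warning presumably wraps this same pair of helpers; the cli.py changes are not shown in this excerpt.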
haiku/rag/config/models.py
ADDED
@@ -0,0 +1,78 @@
+from pathlib import Path
+
+from pydantic import BaseModel, Field
+
+from haiku.rag.utils import get_default_data_dir
+
+
+class StorageConfig(BaseModel):
+    data_dir: Path = Field(default_factory=get_default_data_dir)
+    monitor_directories: list[Path] = []
+    disable_autocreate: bool = False
+    vacuum_retention_seconds: int = 60
+
+
+class LanceDBConfig(BaseModel):
+    uri: str = ""
+    api_key: str = ""
+    region: str = ""
+
+
+class EmbeddingsConfig(BaseModel):
+    provider: str = "ollama"
+    model: str = "qwen3-embedding"
+    vector_dim: int = 4096
+
+
+class RerankingConfig(BaseModel):
+    provider: str = ""
+    model: str = ""
+
+
+class QAConfig(BaseModel):
+    provider: str = "ollama"
+    model: str = "gpt-oss"
+
+
+class ResearchConfig(BaseModel):
+    provider: str = "ollama"
+    model: str = "gpt-oss"
+
+
+class ProcessingConfig(BaseModel):
+    chunk_size: int = 256
+    context_chunk_radius: int = 0
+    markdown_preprocessor: str = ""
+
+
+class OllamaConfig(BaseModel):
+    base_url: str = "http://localhost:11434"
+
+
+class VLLMConfig(BaseModel):
+    embeddings_base_url: str = ""
+    rerank_base_url: str = ""
+    qa_base_url: str = ""
+    research_base_url: str = ""
+
+
+class ProvidersConfig(BaseModel):
+    ollama: OllamaConfig = Field(default_factory=OllamaConfig)
+    vllm: VLLMConfig = Field(default_factory=VLLMConfig)
+
+
+class A2AConfig(BaseModel):
+    max_contexts: int = 1000
+
+
+class AppConfig(BaseModel):
+    environment: str = "production"
+    storage: StorageConfig = Field(default_factory=StorageConfig)
+    lancedb: LanceDBConfig = Field(default_factory=LanceDBConfig)
+    embeddings: EmbeddingsConfig = Field(default_factory=EmbeddingsConfig)
+    reranking: RerankingConfig = Field(default_factory=RerankingConfig)
+    qa: QAConfig = Field(default_factory=QAConfig)
+    research: ResearchConfig = Field(default_factory=ResearchConfig)
+    processing: ProcessingConfig = Field(default_factory=ProcessingConfig)
+    providers: ProvidersConfig = Field(default_factory=ProvidersConfig)
+    a2a: A2AConfig = Field(default_factory=A2AConfig)
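Every section in AppConfig carries a default_factory, which is why the loader can validate a sparse YAML mapping directly. A small sketch of that behavior (the override values are arbitrary examples):

from haiku.rag.config.models import AppConfig

cfg = AppConfig.model_validate(
    {
        "embeddings": {"provider": "openai", "model": "text-embedding-3-small"},
        "processing": {"context_chunk_radius": 1},
    }
)

assert cfg.qa.provider == "ollama"        # untouched section keeps its defaults
assert cfg.embeddings.vector_dim == 4096  # unset field inside an overridden section, too
assert cfg.processing.chunk_size == 256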
haiku/rag/embeddings/__init__.py
CHANGED
@@ -1,17 +1,23 @@
-from haiku.rag.config import Config
+from haiku.rag.config import AppConfig, Config
 from haiku.rag.embeddings.base import EmbedderBase
 from haiku.rag.embeddings.ollama import Embedder as OllamaEmbedder


-def get_embedder() -> EmbedderBase:
+def get_embedder(config: AppConfig = Config) -> EmbedderBase:
     """
     Factory function to get the appropriate embedder based on the configuration.
+
+    Args:
+        config: Configuration to use. Defaults to global Config.
+
+    Returns:
+        An embedder instance configured according to the config.
     """

-    if
-    return OllamaEmbedder(
+    if config.embeddings.provider == "ollama":
+        return OllamaEmbedder(config.embeddings.model, config.embeddings.vector_dim)

-    if
+    if config.embeddings.provider == "voyageai":
        try:
            from haiku.rag.embeddings.voyageai import Embedder as VoyageAIEmbedder
        except ImportError:
@@ -20,16 +26,16 @@ def get_embedder() -> EmbedderBase:
                "Please install haiku.rag with the 'voyageai' extra: "
                "uv pip install haiku.rag[voyageai]"
            )
-    return VoyageAIEmbedder(
+        return VoyageAIEmbedder(config.embeddings.model, config.embeddings.vector_dim)

-    if
+    if config.embeddings.provider == "openai":
        from haiku.rag.embeddings.openai import Embedder as OpenAIEmbedder

-    return OpenAIEmbedder(
+        return OpenAIEmbedder(config.embeddings.model, config.embeddings.vector_dim)

-    if
+    if config.embeddings.provider == "vllm":
        from haiku.rag.embeddings.vllm import Embedder as VllmEmbedder

-    return VllmEmbedder(
+        return VllmEmbedder(config.embeddings.model, config.embeddings.vector_dim)

-    raise ValueError(f"Unsupported embedding provider: {
+    raise ValueError(f"Unsupported embedding provider: {config.embeddings.provider}")
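The factory now threads an explicit AppConfig instead of reading module-level settings, which makes per-call provider switching straightforward. A sketch (the model name and dimension are example values):

from haiku.rag.config import AppConfig, Config
from haiku.rag.embeddings import get_embedder

default_embedder = get_embedder()  # falls back to the global Config

custom = AppConfig.model_validate(
    {"embeddings": {"provider": "vllm", "model": "BAAI/bge-m3", "vector_dim": 1024}}
)
vllm_embedder = get_embedder(config=custom)

One caveat visible in the signature: config: AppConfig = Config binds the global Config object when the function is defined, the usual Python default-argument behavior.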
haiku/rag/embeddings/base.py
CHANGED
@@ -1,14 +1,22 @@
+from typing import overload
+
 from haiku.rag.config import Config


 class EmbedderBase:
-    _model: str = Config.
-    _vector_dim: int = Config.
+    _model: str = Config.embeddings.model
+    _vector_dim: int = Config.embeddings.vector_dim

     def __init__(self, model: str, vector_dim: int):
         self._model = model
         self._vector_dim = vector_dim

+    @overload
+    async def embed(self, text: str) -> list[float]: ...
+
+    @overload
+    async def embed(self, text: list[str]) -> list[list[float]]: ...
+
     async def embed(self, text: str | list[str]) -> list[float] | list[list[float]]:
         raise NotImplementedError(
             "Embedder is an abstract class. Please implement the embed method in a subclass."
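The paired @overload declarations let type checkers tie the return shape to the argument shape, which the plain union signature could not express. A quick illustration of what a checker now infers:

from haiku.rag.embeddings.base import EmbedderBase

async def demo(embedder: EmbedderBase) -> None:
    single = await embedder.embed("hello")            # inferred: list[float]
    batch = await embedder.embed(["hello", "world"])  # inferred: list[list[float]]

Each concrete Embedder below restates the same two overloads because overriding embed with a plain signature would otherwise hide them from the checker.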
haiku/rag/embeddings/ollama.py
CHANGED
@@ -1,3 +1,5 @@
+from typing import overload
+
 from openai import AsyncOpenAI

 from haiku.rag.config import Config
@@ -5,8 +7,16 @@ from haiku.rag.embeddings.base import EmbedderBase


 class Embedder(EmbedderBase):
+    @overload
+    async def embed(self, text: str) -> list[float]: ...
+
+    @overload
+    async def embed(self, text: list[str]) -> list[list[float]]: ...
+
     async def embed(self, text: str | list[str]) -> list[float] | list[list[float]]:
-        client = AsyncOpenAI(
+        client = AsyncOpenAI(
+            base_url=f"{Config.providers.ollama.base_url}/v1", api_key="dummy"
+        )
         if not text:
             return []
         response = await client.embeddings.create(
haiku/rag/embeddings/openai.py
CHANGED
@@ -1,9 +1,17 @@
+from typing import overload
+
 from openai import AsyncOpenAI

 from haiku.rag.embeddings.base import EmbedderBase


 class Embedder(EmbedderBase):
+    @overload
+    async def embed(self, text: str) -> list[float]: ...
+
+    @overload
+    async def embed(self, text: list[str]) -> list[list[float]]: ...
+
     async def embed(self, text: str | list[str]) -> list[float] | list[list[float]]:
         client = AsyncOpenAI()
         if not text:
haiku/rag/embeddings/vllm.py
CHANGED
@@ -1,3 +1,5 @@
+from typing import overload
+
 from openai import AsyncOpenAI

 from haiku.rag.config import Config
@@ -5,9 +7,15 @@ from haiku.rag.embeddings.base import EmbedderBase


 class Embedder(EmbedderBase):
+    @overload
+    async def embed(self, text: str) -> list[float]: ...
+
+    @overload
+    async def embed(self, text: list[str]) -> list[list[float]]: ...
+
     async def embed(self, text: str | list[str]) -> list[float] | list[list[float]]:
         client = AsyncOpenAI(
-            base_url=f"{Config.
+            base_url=f"{Config.providers.vllm.embeddings_base_url}/v1", api_key="dummy"
         )
         if not text:
             return []
haiku/rag/embeddings/voyageai.py
CHANGED
@@ -1,9 +1,17 @@
 try:
+    from typing import overload
+
     from voyageai.client import Client  # type: ignore

     from haiku.rag.embeddings.base import EmbedderBase

     class Embedder(EmbedderBase):
+        @overload
+        async def embed(self, text: str) -> list[float]: ...
+
+        @overload
+        async def embed(self, text: list[str]) -> list[list[float]]: ...
+
         async def embed(self, text: str | list[str]) -> list[float] | list[list[float]]:
             client = Client()
             if not text:
haiku/rag/graph/common.py
CHANGED
@@ -15,13 +15,13 @@ def get_model(provider: str, model: str) -> Any:
     if provider == "ollama":
         return OpenAIChatModel(
             model_name=model,
-            provider=OllamaProvider(base_url=f"{Config.
+            provider=OllamaProvider(base_url=f"{Config.providers.ollama.base_url}/v1"),
         )
     elif provider == "vllm":
         return OpenAIChatModel(
             model_name=model,
             provider=OpenAIProvider(
-                base_url=f"{Config.
+                base_url=f"{Config.providers.vllm.research_base_url or Config.providers.vllm.qa_base_url}/v1",
                 api_key="none",
             ),
         )
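The vllm branch now prefers the research endpoint and falls back to the QA endpoint, relying on empty strings being falsy. A toy illustration of the selection logic (the URL is an example value):

from haiku.rag.config.models import VLLMConfig

vllm = VLLMConfig(qa_base_url="http://vllm-qa:8000")  # research_base_url left ""
base_url = vllm.research_base_url or vllm.qa_base_url
assert base_url == "http://vllm-qa:8000"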