haiku.rag 0.12.0__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

haiku/rag/client.py CHANGED
@@ -8,8 +8,7 @@ from urllib.parse import urlparse
 
 import httpx
 
-from haiku.rag.config import Config
-from haiku.rag.reader import FileReader
+from haiku.rag.config import AppConfig, Config
 from haiku.rag.reranking import get_reranker
 from haiku.rag.store.engine import Store
 from haiku.rag.store.models.chunk import Chunk
@@ -17,7 +16,6 @@ from haiku.rag.store.models.document import Document
 from haiku.rag.store.repositories.chunk import ChunkRepository
 from haiku.rag.store.repositories.document import DocumentRepository
 from haiku.rag.store.repositories.settings import SettingsRepository
-from haiku.rag.utils import text_to_docling_document
 
 logger = logging.getLogger(__name__)
 
@@ -27,16 +25,23 @@ class HaikuRAG:
 
     def __init__(
         self,
-        db_path: Path = Config.DEFAULT_DATA_DIR / "haiku.rag.lancedb",
+        db_path: Path | None = None,
+        config: AppConfig = Config,
         skip_validation: bool = False,
    ):
         """Initialize the RAG client with a database path.
 
         Args:
-            db_path: Path to the database file.
+            db_path: Path to the database file. If None, uses config.storage.data_dir.
+            config: Configuration to use. Defaults to global Config.
             skip_validation: Whether to skip configuration validation on database load.
         """
-        self.store = Store(db_path, skip_validation=skip_validation)
+        self._config = config
+        if db_path is None:
+            db_path = self._config.storage.data_dir / "haiku.rag.lancedb"
+        self.store = Store(
+            db_path, config=self._config, skip_validation=skip_validation
+        )
         self.document_repository = DocumentRepository(self.store)
         self.chunk_repository = ChunkRepository(self.store)
 
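In practice this turns configuration into an injectable dependency rather than a module-level constant. A minimal sketch of the new constructor (assumed usage; the path and override value are illustrative):

```python
from pathlib import Path

from haiku.rag.client import HaikuRAG
from haiku.rag.config import AppConfig

# Build a config in code; omitted sections keep their pydantic defaults.
config = AppConfig.model_validate({"processing": {"context_chunk_radius": 2}})

# db_path=None would resolve to config.storage.data_dir / "haiku.rag.lancedb".
client = HaikuRAG(db_path=Path("/tmp/haiku.rag.lancedb"), config=config)
```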
@@ -91,6 +96,9 @@ class HaikuRAG:
         Returns:
             The created Document instance.
         """
+        # Lazy import to avoid loading docling
+        from haiku.rag.utils import text_to_docling_document
+
         # Convert content to DoclingDocument for processing
         docling_document = text_to_docling_document(content)
 
@@ -106,8 +114,8 @@
 
     async def create_document_from_source(
         self, source: str | Path, title: str | None = None, metadata: dict | None = None
-    ) -> Document:
-        """Create or update a document from a file path or URL.
+    ) -> Document | list[Document]:
+        """Create or update document(s) from a file path, directory, or URL.
 
         Checks if a document with the same URI already exists:
         - If MD5 is unchanged, returns existing document
@@ -115,16 +123,20 @@
         - If no document exists, creates a new one
 
         Args:
-            source: File path (as string or Path) or URL to parse
+            source: File path, directory (as string or Path), or URL to parse
+            title: Optional title (only used for single files, not directories)
             metadata: Optional metadata dictionary
 
         Returns:
-            Document instance (created, updated, or existing)
+            Document instance (created, updated, or existing) for single files/URLs
+            List of Document instances for directories
 
         Raises:
             ValueError: If the file/URL cannot be parsed or doesn't exist
             httpx.RequestError: If URL request fails
         """
+        # Lazy import to avoid loading docling
+        from haiku.rag.reader import FileReader
 
         # Normalize metadata
         metadata = metadata or {}
@@ -142,6 +154,48 @@ class HaikuRAG:
         else:
             # Handle as regular file path
             source_path = Path(source) if isinstance(source, str) else source
+
+            # Handle directories
+            if source_path.is_dir():
+                documents = []
+                supported_extensions = set(FileReader.extensions)
+                for file_path in source_path.rglob("*"):
+                    if (
+                        file_path.is_file()
+                        and file_path.suffix.lower() in supported_extensions
+                    ):
+                        doc = await self._create_document_from_file(
+                            file_path, title=None, metadata=metadata
+                        )
+                        documents.append(doc)
+                return documents
+
+            # Handle single file
+            return await self._create_document_from_file(
+                source_path, title=title, metadata=metadata
+            )
+
+    async def _create_document_from_file(
+        self, source_path: Path, title: str | None = None, metadata: dict | None = None
+    ) -> Document:
+        """Create or update a document from a single file path.
+
+        Args:
+            source_path: Path to the file
+            title: Optional title
+            metadata: Optional metadata dictionary
+
+        Returns:
+            Document instance (created, updated, or existing)
+
+        Raises:
+            ValueError: If the file cannot be parsed or doesn't exist
+        """
+        # Lazy import to avoid loading docling
+        from haiku.rag.reader import FileReader
+
+        metadata = metadata or {}
+
         if source_path.suffix.lower() not in FileReader.extensions:
             raise ValueError(f"Unsupported file extension: {source_path.suffix}")
 
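A short sketch of the new directory handling (assumed usage; paths are illustrative): a directory is walked recursively, only FileReader-supported extensions are ingested, and a list comes back, while a single file or URL still yields one Document.

```python
# Inside an async context, with `client` from the earlier sketch:
docs = await client.create_document_from_source("./notes/")  # illustrative directory
if isinstance(docs, list):
    print(f"Ingested {len(docs)} documents")

doc = await client.create_document_from_source("./notes/intro.md", title="Intro")
```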
@@ -215,6 +269,9 @@ class HaikuRAG:
             ValueError: If the content cannot be parsed
             httpx.RequestError: If URL request fails
         """
+        # Lazy import to avoid loading docling
+        from haiku.rag.reader import FileReader
+
         metadata = metadata or {}
 
         async with httpx.AsyncClient() as client:
@@ -338,6 +395,9 @@ class HaikuRAG:
 
     async def update_document(self, document: Document) -> Document:
         """Update an existing document."""
+        # Lazy import to avoid loading docling
+        from haiku.rag.utils import text_to_docling_document
+
         # Convert content to DoclingDocument
         docling_document = text_to_docling_document(document.content)
 
@@ -377,7 +437,7 @@ class HaikuRAG:
             List of (chunk, score) tuples ordered by relevance.
         """
         # Get reranker if available
-        reranker = get_reranker()
+        reranker = get_reranker(config=self._config)
 
         if reranker is None:
             # No reranking - return direct search results
@@ -399,18 +459,20 @@ class HaikuRAG:
     async def expand_context(
         self,
         search_results: list[tuple[Chunk, float]],
-        radius: int = Config.CONTEXT_CHUNK_RADIUS,
+        radius: int | None = None,
     ) -> list[tuple[Chunk, float]]:
         """Expand search results with adjacent chunks, merging overlapping chunks.
 
         Args:
             search_results: List of (chunk, score) tuples from search.
             radius: Number of adjacent chunks to include before/after each chunk.
-                Defaults to CONTEXT_CHUNK_RADIUS config setting.
+                If None, uses config.processing.context_chunk_radius.
 
         Returns:
             List of (chunk, score) tuples with expanded and merged context chunks.
         """
+        if radius is None:
+            radius = self._config.processing.context_chunk_radius
         if radius == 0:
             return search_results
 
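A sketch of the new radius resolution; the `search` call here is assumed (it is not part of this hunk) and only illustrates where the tuples come from:

```python
# Inside an async context; `search` is assumed to return list[tuple[Chunk, float]].
results = await client.search("how are chunks merged?")
expanded = await client.expand_context(results)          # radius=None -> config value
wider = await client.expand_context(results, radius=2)   # explicit radius wins
```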
@@ -540,7 +602,9 @@ class HaikuRAG:
         """
         from haiku.rag.qa import get_qa_agent
 
-        qa_agent = get_qa_agent(self, use_citations=cite, system_prompt=system_prompt)
+        qa_agent = get_qa_agent(
+            self, config=self._config, use_citations=cite, system_prompt=system_prompt
+        )
         return await qa_agent.answer(question)
 
     async def rebuild_database(self) -> AsyncGenerator[str, None]:
@@ -556,6 +620,9 @@ class HaikuRAG:
         Yields:
             int: The ID of the document currently being processed
         """
+        # Lazy import to avoid loading docling
+        from haiku.rag.utils import text_to_docling_document
+
         await self.chunk_repository.delete_all()
         self.store.recreate_embeddings_table()
 
@@ -592,6 +659,8 @@ class HaikuRAG:
                 new_doc = await self.create_document_from_source(
                     source=doc.uri, metadata=doc.metadata or {}
                 )
+                # URIs always point to single files/URLs, never directories
+                assert isinstance(new_doc, Document)
                 assert new_doc.id is not None, (
                     "New document ID should not be None"
                 )
haiku/rag/config/__init__.py ADDED
@@ -0,0 +1,54 @@
+import os
+
+from haiku.rag.config.loader import (
+    check_for_deprecated_env,
+    find_config_file,
+    generate_default_config,
+    load_config_from_env,
+    load_yaml_config,
+)
+from haiku.rag.config.models import (
+    A2AConfig,
+    AppConfig,
+    EmbeddingsConfig,
+    LanceDBConfig,
+    OllamaConfig,
+    ProcessingConfig,
+    ProvidersConfig,
+    QAConfig,
+    RerankingConfig,
+    ResearchConfig,
+    StorageConfig,
+    VLLMConfig,
+)
+
+__all__ = [
+    "Config",
+    "AppConfig",
+    "StorageConfig",
+    "LanceDBConfig",
+    "EmbeddingsConfig",
+    "RerankingConfig",
+    "QAConfig",
+    "ResearchConfig",
+    "ProcessingConfig",
+    "OllamaConfig",
+    "VLLMConfig",
+    "ProvidersConfig",
+    "A2AConfig",
+    "find_config_file",
+    "load_yaml_config",
+    "generate_default_config",
+    "load_config_from_env",
+]
+
+# Load config from YAML file or use defaults
+config_path = find_config_file(None)
+if config_path:
+    yaml_data = load_yaml_config(config_path)
+    Config = AppConfig.model_validate(yaml_data)
+else:
+    Config = AppConfig()
+
+# Check for deprecated .env file
+check_for_deprecated_env()
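Since the global Config is built once at import time via find_config_file(), a specific file can be selected through the HAIKU_RAG_CONFIG_PATH environment variable before the first import. A minimal sketch (the path is illustrative):

```python
import os

# Must run before haiku.rag.config is first imported, since Config is
# resolved at import time by find_config_file().
os.environ["HAIKU_RAG_CONFIG_PATH"] = "/srv/haiku/config.yaml"  # illustrative path

from haiku.rag.config import Config

print(Config.embeddings.model)  # "qwen3-embedding" unless the YAML overrides it
```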
haiku/rag/config/loader.py ADDED
@@ -0,0 +1,151 @@
+import os
+import warnings
+from pathlib import Path
+
+import yaml
+
+
+def find_config_file(cli_path: Path | None = None) -> Path | None:
+    """Find the YAML config file using the search path.
+
+    Search order:
+    1. CLI-provided path (via HAIKU_RAG_CONFIG_PATH env var or parameter)
+    2. ./haiku.rag.yaml (current directory)
+    3. ~/.config/haiku.rag/config.yaml (user config)
+
+    Returns None if no config file is found.
+    """
+    # Check environment variable first (set by CLI --config flag)
+    if not cli_path:
+        env_path = os.getenv("HAIKU_RAG_CONFIG_PATH")
+        if env_path:
+            cli_path = Path(env_path)
+
+    if cli_path:
+        if cli_path.exists():
+            return cli_path
+        raise FileNotFoundError(f"Config file not found: {cli_path}")
+
+    cwd_config = Path.cwd() / "haiku.rag.yaml"
+    if cwd_config.exists():
+        return cwd_config
+
+    user_config_dir = Path.home() / ".config" / "haiku.rag"
+    user_config = user_config_dir / "config.yaml"
+    if user_config.exists():
+        return user_config
+
+    return None
+
+
+def load_yaml_config(path: Path) -> dict:
+    """Load and parse a YAML config file."""
+    with open(path) as f:
+        data = yaml.safe_load(f)
+    return data or {}
+
+
+def check_for_deprecated_env() -> None:
+    """Check for .env file and warn if found."""
+    env_file = Path.cwd() / ".env"
+    if env_file.exists():
+        warnings.warn(
+            ".env file detected but YAML configuration is now preferred. "
+            "Environment variable configuration is deprecated and will be removed in future versions. "
+            "Run 'haiku-rag init-config' to generate a YAML config file.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+
+def generate_default_config() -> dict:
+    """Generate a default YAML config structure with documentation."""
+    return {
+        "environment": "production",
+        "storage": {
+            "data_dir": "",
+            "monitor_directories": [],
+            "disable_autocreate": False,
+            "vacuum_retention_seconds": 60,
+        },
+        "lancedb": {"uri": "", "api_key": "", "region": ""},
+        "embeddings": {
+            "provider": "ollama",
+            "model": "qwen3-embedding",
+            "vector_dim": 4096,
+        },
+        "reranking": {"provider": "", "model": ""},
+        "qa": {"provider": "ollama", "model": "gpt-oss"},
+        "research": {"provider": "", "model": ""},
+        "processing": {
+            "chunk_size": 256,
+            "context_chunk_radius": 0,
+            "markdown_preprocessor": "",
+        },
+        "providers": {
+            "ollama": {"base_url": "http://localhost:11434"},
+            "vllm": {
+                "embeddings_base_url": "",
+                "rerank_base_url": "",
+                "qa_base_url": "",
+                "research_base_url": "",
+            },
+        },
+        "a2a": {"max_contexts": 1000},
+    }
+
+
+def load_config_from_env() -> dict:
+    """Load current config from environment variables (for migration)."""
+    result = {}
+
+    env_mappings = {
+        "ENV": "environment",
+        "DEFAULT_DATA_DIR": ("storage", "data_dir"),
+        "MONITOR_DIRECTORIES": ("storage", "monitor_directories"),
+        "DISABLE_DB_AUTOCREATE": ("storage", "disable_autocreate"),
+        "VACUUM_RETENTION_SECONDS": ("storage", "vacuum_retention_seconds"),
+        "LANCEDB_URI": ("lancedb", "uri"),
+        "LANCEDB_API_KEY": ("lancedb", "api_key"),
+        "LANCEDB_REGION": ("lancedb", "region"),
+        "EMBEDDINGS_PROVIDER": ("embeddings", "provider"),
+        "EMBEDDINGS_MODEL": ("embeddings", "model"),
+        "EMBEDDINGS_VECTOR_DIM": ("embeddings", "vector_dim"),
+        "RERANK_PROVIDER": ("reranking", "provider"),
+        "RERANK_MODEL": ("reranking", "model"),
+        "QA_PROVIDER": ("qa", "provider"),
+        "QA_MODEL": ("qa", "model"),
+        "RESEARCH_PROVIDER": ("research", "provider"),
+        "RESEARCH_MODEL": ("research", "model"),
+        "CHUNK_SIZE": ("processing", "chunk_size"),
+        "CONTEXT_CHUNK_RADIUS": ("processing", "context_chunk_radius"),
+        "MARKDOWN_PREPROCESSOR": ("processing", "markdown_preprocessor"),
+        "OLLAMA_BASE_URL": ("providers", "ollama", "base_url"),
+        "VLLM_EMBEDDINGS_BASE_URL": ("providers", "vllm", "embeddings_base_url"),
+        "VLLM_RERANK_BASE_URL": ("providers", "vllm", "rerank_base_url"),
+        "VLLM_QA_BASE_URL": ("providers", "vllm", "qa_base_url"),
+        "VLLM_RESEARCH_BASE_URL": ("providers", "vllm", "research_base_url"),
+        "A2A_MAX_CONTEXTS": ("a2a", "max_contexts"),
+    }
+
+    for env_var, path in env_mappings.items():
+        value = os.getenv(env_var)
+        if value is not None:
+            # Special handling for MONITOR_DIRECTORIES - parse comma-separated list
+            if env_var == "MONITOR_DIRECTORIES":
+                if value.strip():
+                    value = [p.strip() for p in value.split(",") if p.strip()]
+                else:
+                    value = []
+
+            if isinstance(path, tuple):
+                current = result
+                for key in path[:-1]:
+                    if key not in current:
+                        current[key] = {}
+                    current = current[key]
+                current[path[-1]] = value
+            else:
+                result[path] = value
+
+    return result
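For reference, a haiku.rag.yaml mirroring the structure of generate_default_config() looks roughly like this sketch; omitted keys fall back to the pydantic defaults:

```yaml
environment: production
storage:
  data_dir: ""
  monitor_directories: []
  disable_autocreate: false
  vacuum_retention_seconds: 60
embeddings:
  provider: ollama
  model: qwen3-embedding
  vector_dim: 4096
qa:
  provider: ollama
  model: gpt-oss
processing:
  chunk_size: 256
  context_chunk_radius: 0
  markdown_preprocessor: ""
providers:
  ollama:
    base_url: http://localhost:11434
```

Pairing load_config_from_env() with AppConfig.model_validate() gives a one-shot migration path away from the deprecated environment variables.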
haiku/rag/config/models.py ADDED
@@ -0,0 +1,78 @@
+from pathlib import Path
+
+from pydantic import BaseModel, Field
+
+from haiku.rag.utils import get_default_data_dir
+
+
+class StorageConfig(BaseModel):
+    data_dir: Path = Field(default_factory=get_default_data_dir)
+    monitor_directories: list[Path] = []
+    disable_autocreate: bool = False
+    vacuum_retention_seconds: int = 60
+
+
+class LanceDBConfig(BaseModel):
+    uri: str = ""
+    api_key: str = ""
+    region: str = ""
+
+
+class EmbeddingsConfig(BaseModel):
+    provider: str = "ollama"
+    model: str = "qwen3-embedding"
+    vector_dim: int = 4096
+
+
+class RerankingConfig(BaseModel):
+    provider: str = ""
+    model: str = ""
+
+
+class QAConfig(BaseModel):
+    provider: str = "ollama"
+    model: str = "gpt-oss"
+
+
+class ResearchConfig(BaseModel):
+    provider: str = "ollama"
+    model: str = "gpt-oss"
+
+
+class ProcessingConfig(BaseModel):
+    chunk_size: int = 256
+    context_chunk_radius: int = 0
+    markdown_preprocessor: str = ""
+
+
+class OllamaConfig(BaseModel):
+    base_url: str = "http://localhost:11434"
+
+
+class VLLMConfig(BaseModel):
+    embeddings_base_url: str = ""
+    rerank_base_url: str = ""
+    qa_base_url: str = ""
+    research_base_url: str = ""
+
+
+class ProvidersConfig(BaseModel):
+    ollama: OllamaConfig = Field(default_factory=OllamaConfig)
+    vllm: VLLMConfig = Field(default_factory=VLLMConfig)
+
+
+class A2AConfig(BaseModel):
+    max_contexts: int = 1000
+
+
+class AppConfig(BaseModel):
+    environment: str = "production"
+    storage: StorageConfig = Field(default_factory=StorageConfig)
+    lancedb: LanceDBConfig = Field(default_factory=LanceDBConfig)
+    embeddings: EmbeddingsConfig = Field(default_factory=EmbeddingsConfig)
+    reranking: RerankingConfig = Field(default_factory=RerankingConfig)
+    qa: QAConfig = Field(default_factory=QAConfig)
+    research: ResearchConfig = Field(default_factory=ResearchConfig)
+    processing: ProcessingConfig = Field(default_factory=ProcessingConfig)
+    providers: ProvidersConfig = Field(default_factory=ProvidersConfig)
+    a2a: A2AConfig = Field(default_factory=A2AConfig)
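Because every section has a default factory, a partial dict validates into a fully populated AppConfig. A quick sketch (the override value is illustrative):

```python
from haiku.rag.config.models import AppConfig

cfg = AppConfig.model_validate({"embeddings": {"vector_dim": 768}})
assert cfg.embeddings.provider == "ollama"                        # section default kept
assert cfg.qa.model == "gpt-oss"                                  # untouched section
assert cfg.providers.ollama.base_url == "http://localhost:11434"
```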
haiku/rag/embeddings/__init__.py CHANGED
@@ -1,17 +1,23 @@
-from haiku.rag.config import Config
+from haiku.rag.config import AppConfig, Config
 from haiku.rag.embeddings.base import EmbedderBase
 from haiku.rag.embeddings.ollama import Embedder as OllamaEmbedder
 
 
-def get_embedder() -> EmbedderBase:
+def get_embedder(config: AppConfig = Config) -> EmbedderBase:
     """
     Factory function to get the appropriate embedder based on the configuration.
+
+    Args:
+        config: Configuration to use. Defaults to global Config.
+
+    Returns:
+        An embedder instance configured according to the config.
     """
 
-    if Config.EMBEDDINGS_PROVIDER == "ollama":
-        return OllamaEmbedder(Config.EMBEDDINGS_MODEL, Config.EMBEDDINGS_VECTOR_DIM)
+    if config.embeddings.provider == "ollama":
+        return OllamaEmbedder(config.embeddings.model, config.embeddings.vector_dim)
 
-    if Config.EMBEDDINGS_PROVIDER == "voyageai":
+    if config.embeddings.provider == "voyageai":
         try:
             from haiku.rag.embeddings.voyageai import Embedder as VoyageAIEmbedder
         except ImportError:
@@ -20,16 +26,16 @@ def get_embedder() -> EmbedderBase:
             "Please install haiku.rag with the 'voyageai' extra: "
             "uv pip install haiku.rag[voyageai]"
         )
-        return VoyageAIEmbedder(Config.EMBEDDINGS_MODEL, Config.EMBEDDINGS_VECTOR_DIM)
+        return VoyageAIEmbedder(config.embeddings.model, config.embeddings.vector_dim)
 
-    if Config.EMBEDDINGS_PROVIDER == "openai":
+    if config.embeddings.provider == "openai":
         from haiku.rag.embeddings.openai import Embedder as OpenAIEmbedder
 
-        return OpenAIEmbedder(Config.EMBEDDINGS_MODEL, Config.EMBEDDINGS_VECTOR_DIM)
+        return OpenAIEmbedder(config.embeddings.model, config.embeddings.vector_dim)
 
-    if Config.EMBEDDINGS_PROVIDER == "vllm":
+    if config.embeddings.provider == "vllm":
         from haiku.rag.embeddings.vllm import Embedder as VllmEmbedder
 
-        return VllmEmbedder(Config.EMBEDDINGS_MODEL, Config.EMBEDDINGS_VECTOR_DIM)
+        return VllmEmbedder(config.embeddings.model, config.embeddings.vector_dim)
 
-    raise ValueError(f"Unsupported embedding provider: {Config.EMBEDDINGS_PROVIDER}")
+    raise ValueError(f"Unsupported embedding provider: {config.embeddings.provider}")
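A sketch of threading a specific config through the factory (the OpenAI model name and dimension are illustrative):

```python
from haiku.rag.config import AppConfig
from haiku.rag.embeddings import get_embedder

cfg = AppConfig.model_validate(
    {"embeddings": {"provider": "openai", "model": "text-embedding-3-small", "vector_dim": 1536}}
)
embedder = get_embedder(config=cfg)  # omit config to fall back to the global Config
```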
haiku/rag/embeddings/base.py CHANGED
@@ -1,14 +1,22 @@
+from typing import overload
+
 from haiku.rag.config import Config
 
 
 class EmbedderBase:
-    _model: str = Config.EMBEDDINGS_MODEL
-    _vector_dim: int = Config.EMBEDDINGS_VECTOR_DIM
+    _model: str = Config.embeddings.model
+    _vector_dim: int = Config.embeddings.vector_dim
 
     def __init__(self, model: str, vector_dim: int):
         self._model = model
         self._vector_dim = vector_dim
 
+    @overload
+    async def embed(self, text: str) -> list[float]: ...
+
+    @overload
+    async def embed(self, text: list[str]) -> list[list[float]]: ...
+
     async def embed(self, text: str | list[str]) -> list[float] | list[list[float]]:
         raise NotImplementedError(
             "Embedder is an abstract class. Please implement the embed method in a subclass."
haiku/rag/embeddings/ollama.py CHANGED
@@ -1,3 +1,5 @@
+from typing import overload
+
 from openai import AsyncOpenAI
 
 from haiku.rag.config import Config
@@ -5,8 +7,16 @@ from haiku.rag.embeddings.base import EmbedderBase
 
 
 class Embedder(EmbedderBase):
+    @overload
+    async def embed(self, text: str) -> list[float]: ...
+
+    @overload
+    async def embed(self, text: list[str]) -> list[list[float]]: ...
+
     async def embed(self, text: str | list[str]) -> list[float] | list[list[float]]:
-        client = AsyncOpenAI(base_url=f"{Config.OLLAMA_BASE_URL}/v1", api_key="dummy")
+        client = AsyncOpenAI(
+            base_url=f"{Config.providers.ollama.base_url}/v1", api_key="dummy"
+        )
         if not text:
             return []
         response = await client.embeddings.create(
haiku/rag/embeddings/openai.py CHANGED
@@ -1,9 +1,17 @@
+from typing import overload
+
 from openai import AsyncOpenAI
 
 from haiku.rag.embeddings.base import EmbedderBase
 
 
 class Embedder(EmbedderBase):
+    @overload
+    async def embed(self, text: str) -> list[float]: ...
+
+    @overload
+    async def embed(self, text: list[str]) -> list[list[float]]: ...
+
     async def embed(self, text: str | list[str]) -> list[float] | list[list[float]]:
         client = AsyncOpenAI()
         if not text:
haiku/rag/embeddings/vllm.py CHANGED
@@ -1,3 +1,5 @@
+from typing import overload
+
 from openai import AsyncOpenAI
 
 from haiku.rag.config import Config
@@ -5,9 +7,15 @@ from haiku.rag.embeddings.base import EmbedderBase
 
 
 class Embedder(EmbedderBase):
+    @overload
+    async def embed(self, text: str) -> list[float]: ...
+
+    @overload
+    async def embed(self, text: list[str]) -> list[list[float]]: ...
+
     async def embed(self, text: str | list[str]) -> list[float] | list[list[float]]:
         client = AsyncOpenAI(
-            base_url=f"{Config.VLLM_EMBEDDINGS_BASE_URL}/v1", api_key="dummy"
+            base_url=f"{Config.providers.vllm.embeddings_base_url}/v1", api_key="dummy"
         )
         if not text:
             return []
haiku/rag/embeddings/voyageai.py CHANGED
@@ -1,9 +1,17 @@
 try:
+    from typing import overload
+
     from voyageai.client import Client  # type: ignore
 
     from haiku.rag.embeddings.base import EmbedderBase
 
     class Embedder(EmbedderBase):
+        @overload
+        async def embed(self, text: str) -> list[float]: ...
+
+        @overload
+        async def embed(self, text: list[str]) -> list[list[float]]: ...
+
         async def embed(self, text: str | list[str]) -> list[float] | list[list[float]]:
             client = Client()
             if not text:
haiku/rag/graph/common.py CHANGED
@@ -15,13 +15,13 @@ def get_model(provider: str, model: str) -> Any:
     if provider == "ollama":
         return OpenAIChatModel(
             model_name=model,
-            provider=OllamaProvider(base_url=f"{Config.OLLAMA_BASE_URL}/v1"),
+            provider=OllamaProvider(base_url=f"{Config.providers.ollama.base_url}/v1"),
         )
     elif provider == "vllm":
         return OpenAIChatModel(
             model_name=model,
             provider=OpenAIProvider(
-                base_url=f"{Config.VLLM_RESEARCH_BASE_URL or Config.VLLM_QA_BASE_URL}/v1",
+                base_url=f"{Config.providers.vllm.research_base_url or Config.providers.vllm.qa_base_url}/v1",
                 api_key="none",
             ),
         )
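A sketch of the resolved endpoints (usage assumed from the signature in the hunk context; the vLLM model id is illustrative):

```python
from haiku.rag.graph.common import get_model

# Ollama models are served from Config.providers.ollama.base_url + "/v1".
chat = get_model("ollama", "gpt-oss")

# For vLLM, research_base_url takes precedence over qa_base_url when both are set.
chat = get_model("vllm", "Qwen/Qwen2.5-7B-Instruct")  # illustrative model id
```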