haiku.rag 0.12.1__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of haiku.rag might be problematic.

@@ -0,0 +1,78 @@
+ from pathlib import Path
+
+ from pydantic import BaseModel, Field
+
+ from haiku.rag.utils import get_default_data_dir
+
+
+ class StorageConfig(BaseModel):
+     data_dir: Path = Field(default_factory=get_default_data_dir)
+     monitor_directories: list[Path] = []
+     disable_autocreate: bool = False
+     vacuum_retention_seconds: int = 60
+
+
+ class LanceDBConfig(BaseModel):
+     uri: str = ""
+     api_key: str = ""
+     region: str = ""
+
+
+ class EmbeddingsConfig(BaseModel):
+     provider: str = "ollama"
+     model: str = "qwen3-embedding"
+     vector_dim: int = 4096
+
+
+ class RerankingConfig(BaseModel):
+     provider: str = ""
+     model: str = ""
+
+
+ class QAConfig(BaseModel):
+     provider: str = "ollama"
+     model: str = "gpt-oss"
+
+
+ class ResearchConfig(BaseModel):
+     provider: str = "ollama"
+     model: str = "gpt-oss"
+
+
+ class ProcessingConfig(BaseModel):
+     chunk_size: int = 256
+     context_chunk_radius: int = 0
+     markdown_preprocessor: str = ""
+
+
+ class OllamaConfig(BaseModel):
+     base_url: str = "http://localhost:11434"
+
+
+ class VLLMConfig(BaseModel):
+     embeddings_base_url: str = ""
+     rerank_base_url: str = ""
+     qa_base_url: str = ""
+     research_base_url: str = ""
+
+
+ class ProvidersConfig(BaseModel):
+     ollama: OllamaConfig = Field(default_factory=OllamaConfig)
+     vllm: VLLMConfig = Field(default_factory=VLLMConfig)
+
+
+ class A2AConfig(BaseModel):
+     max_contexts: int = 1000
+
+
+ class AppConfig(BaseModel):
+     environment: str = "production"
+     storage: StorageConfig = Field(default_factory=StorageConfig)
+     lancedb: LanceDBConfig = Field(default_factory=LanceDBConfig)
+     embeddings: EmbeddingsConfig = Field(default_factory=EmbeddingsConfig)
+     reranking: RerankingConfig = Field(default_factory=RerankingConfig)
+     qa: QAConfig = Field(default_factory=QAConfig)
+     research: ResearchConfig = Field(default_factory=ResearchConfig)
+     processing: ProcessingConfig = Field(default_factory=ProcessingConfig)
+     providers: ProvidersConfig = Field(default_factory=ProvidersConfig)
+     a2a: A2AConfig = Field(default_factory=A2AConfig)
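
This new module replaces the flat, upper-case settings of 0.12.x (Config.EMBEDDINGS_PROVIDER, Config.QA_MODEL, ...) with nested Pydantic models. A minimal usage sketch, assuming AppConfig and its sub-models are importable from haiku.rag.config (AppConfig is imported from there in later hunks; the sub-model import path, and the provider/model values below, are illustrative assumptions):

```python
from haiku.rag.config import AppConfig, EmbeddingsConfig

# Defaults match the field definitions above: Ollama embeddings, 4096-dim vectors.
config = AppConfig()
assert config.embeddings.provider == "ollama"
assert config.qa.model == "gpt-oss"

# Override only the section you care about; every other section keeps its defaults.
custom = AppConfig(
    embeddings=EmbeddingsConfig(provider="openai", model="text-embedding-3-small", vector_dim=1536),
)
custom.providers.ollama.base_url = "http://gpu-box:11434"  # nested fields stay mutable
```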
@@ -1,17 +1,23 @@
- from haiku.rag.config import Config
+ from haiku.rag.config import AppConfig, Config
  from haiku.rag.embeddings.base import EmbedderBase
  from haiku.rag.embeddings.ollama import Embedder as OllamaEmbedder
 
 
- def get_embedder() -> EmbedderBase:
+ def get_embedder(config: AppConfig = Config) -> EmbedderBase:
      """
      Factory function to get the appropriate embedder based on the configuration.
+
+     Args:
+         config: Configuration to use. Defaults to global Config.
+
+     Returns:
+         An embedder instance configured according to the config.
      """
 
-     if Config.EMBEDDINGS_PROVIDER == "ollama":
-         return OllamaEmbedder(Config.EMBEDDINGS_MODEL, Config.EMBEDDINGS_VECTOR_DIM)
+     if config.embeddings.provider == "ollama":
+         return OllamaEmbedder(config.embeddings.model, config.embeddings.vector_dim)
 
-     if Config.EMBEDDINGS_PROVIDER == "voyageai":
+     if config.embeddings.provider == "voyageai":
          try:
              from haiku.rag.embeddings.voyageai import Embedder as VoyageAIEmbedder
          except ImportError:
@@ -20,16 +26,16 @@ def get_embedder() -> EmbedderBase:
                  "Please install haiku.rag with the 'voyageai' extra: "
                  "uv pip install haiku.rag[voyageai]"
              )
-         return VoyageAIEmbedder(Config.EMBEDDINGS_MODEL, Config.EMBEDDINGS_VECTOR_DIM)
+         return VoyageAIEmbedder(config.embeddings.model, config.embeddings.vector_dim)
 
-     if Config.EMBEDDINGS_PROVIDER == "openai":
+     if config.embeddings.provider == "openai":
          from haiku.rag.embeddings.openai import Embedder as OpenAIEmbedder
 
-         return OpenAIEmbedder(Config.EMBEDDINGS_MODEL, Config.EMBEDDINGS_VECTOR_DIM)
+         return OpenAIEmbedder(config.embeddings.model, config.embeddings.vector_dim)
 
-     if Config.EMBEDDINGS_PROVIDER == "vllm":
+     if config.embeddings.provider == "vllm":
          from haiku.rag.embeddings.vllm import Embedder as VllmEmbedder
 
-         return VllmEmbedder(Config.EMBEDDINGS_MODEL, Config.EMBEDDINGS_VECTOR_DIM)
+         return VllmEmbedder(config.embeddings.model, config.embeddings.vector_dim)
 
-     raise ValueError(f"Unsupported embedding provider: {Config.EMBEDDINGS_PROVIDER}")
+     raise ValueError(f"Unsupported embedding provider: {config.embeddings.provider}")
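
Because the provider lookup now reads from the passed-in config rather than the module-level Config, embedders can be built against an explicit configuration (useful for tests or for running several setups in one process). A rough sketch, assuming the AppConfig defaults shown in the new config module:

```python
from haiku.rag.config import AppConfig
from haiku.rag.embeddings import get_embedder

# No argument: falls back to the global Config, matching the 0.12.x behaviour.
default_embedder = get_embedder()

# Explicit configuration object.
cfg = AppConfig()
cfg.embeddings.provider = "ollama"        # also: "voyageai", "openai", "vllm"
cfg.embeddings.model = "qwen3-embedding"
cfg.embeddings.vector_dim = 4096
embedder = get_embedder(config=cfg)
```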
@@ -4,8 +4,8 @@ from haiku.rag.config import Config
 
 
  class EmbedderBase:
-     _model: str = Config.EMBEDDINGS_MODEL
-     _vector_dim: int = Config.EMBEDDINGS_VECTOR_DIM
+     _model: str = Config.embeddings.model
+     _vector_dim: int = Config.embeddings.vector_dim
 
      def __init__(self, model: str, vector_dim: int):
          self._model = model
@@ -14,7 +14,9 @@ class Embedder(EmbedderBase):
      async def embed(self, text: list[str]) -> list[list[float]]: ...
 
      async def embed(self, text: str | list[str]) -> list[float] | list[list[float]]:
-         client = AsyncOpenAI(base_url=f"{Config.OLLAMA_BASE_URL}/v1", api_key="dummy")
+         client = AsyncOpenAI(
+             base_url=f"{Config.providers.ollama.base_url}/v1", api_key="dummy"
+         )
          if not text:
              return []
          response = await client.embeddings.create(
@@ -15,7 +15,7 @@ class Embedder(EmbedderBase):
 
      async def embed(self, text: str | list[str]) -> list[float] | list[list[float]]:
          client = AsyncOpenAI(
-             base_url=f"{Config.VLLM_EMBEDDINGS_BASE_URL}/v1", api_key="dummy"
+             base_url=f"{Config.providers.vllm.embeddings_base_url}/v1", api_key="dummy"
          )
          if not text:
              return []
haiku/rag/graph/common.py CHANGED
@@ -15,13 +15,13 @@ def get_model(provider: str, model: str) -> Any:
      if provider == "ollama":
          return OpenAIChatModel(
              model_name=model,
-             provider=OllamaProvider(base_url=f"{Config.OLLAMA_BASE_URL}/v1"),
+             provider=OllamaProvider(base_url=f"{Config.providers.ollama.base_url}/v1"),
          )
      elif provider == "vllm":
          return OpenAIChatModel(
              model_name=model,
              provider=OpenAIProvider(
-                 base_url=f"{Config.VLLM_RESEARCH_BASE_URL or Config.VLLM_QA_BASE_URL}/v1",
+                 base_url=f"{Config.providers.vllm.research_base_url or Config.providers.vllm.qa_base_url}/v1",
                  api_key="none",
              ),
          )
haiku/rag/mcp.py CHANGED
@@ -38,10 +38,13 @@ def create_mcp_server(db_path: Path) -> FastMCP:
          """Add a document to the RAG system from a file path."""
          try:
              async with HaikuRAG(db_path) as rag:
-                 document = await rag.create_document_from_source(
+                 result = await rag.create_document_from_source(
                      Path(file_path), title=title, metadata=metadata or {}
                  )
-                 return document.id
+                 # Handle both single document and list of documents (directories)
+                 if isinstance(result, list):
+                     return result[0].id if result else None
+                 return result.id
          except Exception:
              return None
 
@@ -52,10 +55,13 @@ def create_mcp_server(db_path: Path) -> FastMCP:
          """Add a document to the RAG system from a URL."""
          try:
              async with HaikuRAG(db_path) as rag:
-                 document = await rag.create_document_from_source(
+                 result = await rag.create_document_from_source(
                      url, title=title, metadata=metadata or {}
                  )
-                 return document.id
+                 # Handle both single document and list of documents
+                 if isinstance(result, list):
+                     return result[0].id if result else None
+                 return result.id
          except Exception:
              return None
 
@@ -188,8 +194,8 @@ def create_mcp_server(db_path: Path) -> FastMCP:
              deps = DeepQADeps(client=rag)
 
              start_node = DeepQAPlanNode(
-                 provider=Config.QA_PROVIDER,
-                 model=Config.QA_MODEL,
+                 provider=Config.qa.provider,
+                 model=Config.qa.model,
              )
 
              result = await graph.run(
@@ -241,8 +247,8 @@ def create_mcp_server(db_path: Path) -> FastMCP:
 
              result = await graph.run(
                  PlanNode(
-                     provider=Config.RESEARCH_PROVIDER or Config.QA_PROVIDER,
-                     model=Config.RESEARCH_MODEL or Config.QA_MODEL,
+                     provider=Config.research.provider or Config.qa.provider,
+                     model=Config.research.model or Config.qa.model,
                  ),
                  state=state,
                  deps=deps,
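
A recurring pattern in this release: create_document_from_source now returns either a single Document or a list of Documents (when pointed at a directory), so every call site normalizes the result. A hypothetical helper capturing the same check; haiku.rag itself inlines it at each call site, as in the hunks above:

```python
from haiku.rag.store.models.document import Document


def first_document(result: Document | list[Document]) -> Document | None:
    """Collapse a create_document_from_source() result to a single Document (or None)."""
    if isinstance(result, list):
        return result[0] if result else None
    return result
```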
haiku/rag/monitor.py CHANGED
@@ -1,21 +1,27 @@
  import logging
  from pathlib import Path
+ from typing import TYPE_CHECKING
 
  from watchfiles import Change, DefaultFilter, awatch
 
  from haiku.rag.client import HaikuRAG
- from haiku.rag.reader import FileReader
  from haiku.rag.store.models.document import Document
 
+ if TYPE_CHECKING:
+     pass
+
  logger = logging.getLogger(__name__)
 
 
  class FileFilter(DefaultFilter):
      def __init__(self, *, ignore_paths: list[Path] | None = None) -> None:
+         # Lazy import to avoid loading docling
+         from haiku.rag.reader import FileReader
+
          self.extensions = tuple(FileReader.extensions)
          super().__init__(ignore_paths=ignore_paths)
 
-     def __call__(self, change: "Change", path: str) -> bool:
+     def __call__(self, change: Change, path: str) -> bool:
          return path.endswith(self.extensions) and super().__call__(change, path)
 
 
@@ -40,6 +46,9 @@ class FileWatcher:
                  await self._delete_document(Path(path))
 
      async def refresh(self):
+         # Lazy import to avoid loading docling
+         from haiku.rag.reader import FileReader
+
          for path in self.paths:
              for f in Path(path).rglob("**/*"):
                  if f.is_file() and f.suffix in FileReader.extensions:
@@ -50,11 +59,15 @@ class FileWatcher:
              uri = file.as_uri()
              existing_doc = await self.client.get_document_by_uri(uri)
              if existing_doc:
-                 doc = await self.client.create_document_from_source(str(file))
+                 result = await self.client.create_document_from_source(str(file))
+                 # Since we're passing a file (not directory), result should be a single Document
+                 doc = result if isinstance(result, Document) else result[0]
                  logger.info(f"Updated document {existing_doc.id} from {file}")
                  return doc
              else:
-                 doc = await self.client.create_document_from_source(str(file))
+                 result = await self.client.create_document_from_source(str(file))
+                 # Since we're passing a file (not directory), result should be a single Document
+                 doc = result if isinstance(result, Document) else result[0]
                  logger.info(f"Created new document {doc.id} from {file}")
                  return doc
          except Exception as e:
haiku/rag/qa/__init__.py CHANGED
@@ -1,15 +1,28 @@
  from haiku.rag.client import HaikuRAG
- from haiku.rag.config import Config
+ from haiku.rag.config import AppConfig, Config
  from haiku.rag.qa.agent import QuestionAnswerAgent
 
 
  def get_qa_agent(
      client: HaikuRAG,
+     config: AppConfig = Config,
      use_citations: bool = False,
      system_prompt: str | None = None,
  ) -> QuestionAnswerAgent:
-     provider = Config.QA_PROVIDER
-     model_name = Config.QA_MODEL
+     """
+     Factory function to get a QA agent based on the configuration.
+
+     Args:
+         client: HaikuRAG client instance.
+         config: Configuration to use. Defaults to global Config.
+         use_citations: Whether to include citations in responses.
+         system_prompt: Optional custom system prompt.
+
+     Returns:
+         A configured QuestionAnswerAgent instance.
+     """
+     provider = config.qa.provider
+     model_name = config.qa.model
 
      return QuestionAnswerAgent(
          client=client,
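
get_qa_agent follows the same pattern as get_embedder: the QA provider and model are read from the config argument, defaulting to the global Config. A brief sketch; the HaikuRAG context-manager usage mirrors mcp.py above, and querying the agent is not shown in this diff:

```python
from haiku.rag.client import HaikuRAG
from haiku.rag.config import AppConfig
from haiku.rag.qa import get_qa_agent


async def build_agent(db_path):
    cfg = AppConfig()  # defaults: qa.provider="ollama", qa.model="gpt-oss"
    async with HaikuRAG(db_path) as rag:
        agent = get_qa_agent(rag, config=cfg, use_citations=True)
        ...  # asking the agent questions is outside the scope of this diff
```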
haiku/rag/qa/agent.py CHANGED
@@ -71,13 +71,15 @@ class QuestionAnswerAgent:
          if provider == "ollama":
              return OpenAIChatModel(
                  model_name=model,
-                 provider=OllamaProvider(base_url=f"{Config.OLLAMA_BASE_URL}/v1"),
+                 provider=OllamaProvider(
+                     base_url=f"{Config.providers.ollama.base_url}/v1"
+                 ),
              )
          elif provider == "vllm":
              return OpenAIChatModel(
                  model_name=model,
                  provider=OpenAIProvider(
-                     base_url=f"{Config.VLLM_QA_BASE_URL}/v1", api_key="none"
+                     base_url=f"{Config.providers.vllm.qa_base_url}/v1", api_key="none"
                  ),
              )
          else:
@@ -1,37 +1,45 @@
  import os
 
- from haiku.rag.config import Config
+ from haiku.rag.config import AppConfig, Config
  from haiku.rag.reranking.base import RerankerBase
 
- _reranker: RerankerBase | None = None
+ _reranker_cache: dict[int, RerankerBase | None] = {}
 
 
- def get_reranker() -> RerankerBase | None:
+ def get_reranker(config: AppConfig = Config) -> RerankerBase | None:
      """
      Factory function to get the appropriate reranker based on the configuration.
-     Returns None if if reranking is disabled.
+     Returns None if reranking is disabled.
+
+     Args:
+         config: Configuration to use. Defaults to global Config.
+
+     Returns:
+         A reranker instance if configured, None otherwise.
      """
-     global _reranker
-     if _reranker is not None:
-         return _reranker
+     # Use config id as cache key to support multiple configs
+     config_id = id(config)
+     if config_id in _reranker_cache:
+         return _reranker_cache[config_id]
+
+     reranker: RerankerBase | None = None
 
-     if Config.RERANK_PROVIDER == "mxbai":
+     if config.reranking.provider == "mxbai":
          try:
              from haiku.rag.reranking.mxbai import MxBAIReranker
 
              os.environ["TOKENIZERS_PARALLELISM"] = "true"
-             _reranker = MxBAIReranker()
-             return _reranker
+             reranker = MxBAIReranker()
          except ImportError:
-             return None
+             reranker = None
 
-     if Config.RERANK_PROVIDER == "cohere":
+     elif config.reranking.provider == "cohere":
          try:
              from haiku.rag.reranking.cohere import CohereReranker
 
-             _reranker = CohereReranker()
-             return _reranker
+             reranker = CohereReranker()
          except ImportError:
-             return None
+             reranker = None
 
-     return None
+     _reranker_cache[config_id] = reranker
+     return reranker
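
The module-level singleton is gone; the reranker is now cached per configuration object, keyed on id(config). A rough illustration of the resulting behaviour (with the default empty reranking provider this simply caches and returns None):

```python
from haiku.rag.config import AppConfig
from haiku.rag.reranking import get_reranker

cfg_a = AppConfig()
cfg_b = AppConfig()

r1 = get_reranker(cfg_a)  # built (or None) on the first call for this config object
r2 = get_reranker(cfg_a)  # same cache entry: r2 is r1
r3 = get_reranker(cfg_b)  # a separate entry, keyed on id(cfg_b)
```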
@@ -3,7 +3,7 @@ from haiku.rag.store.models.chunk import Chunk
 
 
  class RerankerBase:
-     _model: str = Config.RERANK_MODEL
+     _model: str = Config.reranking.model
 
      async def rerank(
          self, query: str, chunks: list[Chunk], top_n: int = 10
@@ -1,4 +1,3 @@
- from haiku.rag.config import Config
  from haiku.rag.reranking.base import RerankerBase
  from haiku.rag.store.models.chunk import Chunk
 
@@ -12,7 +11,8 @@ except ImportError as e:
 
  class CohereReranker(RerankerBase):
      def __init__(self):
-         self._client = cohere.ClientV2(api_key=Config.COHERE_API_KEY)
+         # Cohere SDK reads CO_API_KEY from environment by default
+         self._client = cohere.ClientV2()
 
      async def rerank(
          self, query: str, chunks: list[Chunk], top_n: int = 10
@@ -8,7 +8,7 @@ from haiku.rag.store.models.chunk import Chunk
  class MxBAIReranker(RerankerBase):
      def __init__(self):
          self._client = MxbaiRerankV2(
-             Config.RERANK_MODEL, disable_transformers_warnings=True
+             Config.reranking.model, disable_transformers_warnings=True
          )
 
      async def rerank(
@@ -8,7 +8,7 @@ from haiku.rag.store.models.chunk import Chunk
  class VLLMReranker(RerankerBase):
      def __init__(self, model: str):
          self._model = model
-         self._base_url = Config.VLLM_RERANK_BASE_URL
+         self._base_url = Config.providers.vllm.rerank_base_url
 
      async def rerank(
          self, query: str, chunks: list[Chunk], top_n: int = 10
haiku/rag/store/engine.py CHANGED
@@ -10,7 +10,7 @@ import lancedb
  from lancedb.pydantic import LanceModel, Vector
  from pydantic import Field
 
- from haiku.rag.config import Config
+ from haiku.rag.config import AppConfig, Config
  from haiku.rag.embeddings import get_embedder
 
  logger = logging.getLogger(__name__)
@@ -49,9 +49,12 @@ class SettingsRecord(LanceModel):
 
 
  class Store:
-     def __init__(self, db_path: Path, skip_validation: bool = False):
+     def __init__(
+         self, db_path: Path, config: AppConfig = Config, skip_validation: bool = False
+     ):
          self.db_path: Path = db_path
-         self.embedder = get_embedder()
+         self._config = config
+         self.embedder = get_embedder(config=self._config)
          self._vacuum_lock = asyncio.Lock()
 
          # Create the ChunkRecord model with the correct vector dimension
@@ -59,7 +62,7 @@ class Store:
 
          # Local filesystem handling for DB directory
          if not self._has_cloud_config():
-             if Config.DISABLE_DB_AUTOCREATE:
+             if self._config.storage.disable_autocreate:
                  # LanceDB uses a directory path for local databases; enforce presence
                  if not db_path.exists():
                      raise FileNotFoundError(
@@ -85,13 +88,15 @@ class Store:
 
          Args:
              retention_seconds: Retention threshold in seconds. Only versions older
-                 than this will be removed. If None, uses Config.VACUUM_RETENTION_SECONDS.
+                 than this will be removed. If None, uses config.storage.vacuum_retention_seconds.
 
          Note:
              If vacuum is already running, this method returns immediately without blocking.
              Use asyncio.create_task(store.vacuum()) for non-blocking background execution.
          """
-         if self._has_cloud_config() and str(Config.LANCEDB_URI).startswith("db://"):
+         if self._has_cloud_config() and str(self._config.lancedb.uri).startswith(
+             "db://"
+         ):
              return
 
          # Skip if already running (non-blocking)
@@ -102,7 +107,7 @@ class Store:
          try:
              # Evaluate config at runtime to allow dynamic changes
              if retention_seconds is None:
-                 retention_seconds = Config.VACUUM_RETENTION_SECONDS
+                 retention_seconds = self._config.storage.vacuum_retention_seconds
              # Perform maintenance per table using optimize() with configurable retention
              retention = timedelta(seconds=retention_seconds)
              for table in [
@@ -120,9 +125,9 @@ class Store:
          # Check if we have cloud configuration
          if self._has_cloud_config():
              return lancedb.connect(
-                 uri=Config.LANCEDB_URI,
-                 api_key=Config.LANCEDB_API_KEY,
-                 region=Config.LANCEDB_REGION,
+                 uri=self._config.lancedb.uri,
+                 api_key=self._config.lancedb.api_key,
+                 region=self._config.lancedb.region,
              )
          else:
              # Local file system connection
@@ -131,7 +136,9 @@ class Store:
      def _has_cloud_config(self) -> bool:
          """Check if cloud configuration is complete."""
          return bool(
-             Config.LANCEDB_URI and Config.LANCEDB_API_KEY and Config.LANCEDB_REGION
+             self._config.lancedb.uri
+             and self._config.lancedb.api_key
+             and self._config.lancedb.region
          )
 
      def _validate_configuration(self) -> None:
@@ -173,7 +180,7 @@ class Store:
              "settings", schema=SettingsRecord
          )
          # Save current settings to the new database
-         settings_data = Config.model_dump(mode="json")
+         settings_data = self._config.model_dump(mode="json")
          self.settings_table.add(
              [SettingsRecord(id="settings", settings=json.dumps(settings_data))]
          )
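
Store now carries the configuration it was constructed with, threading it into get_embedder, the LanceDB connection, vacuum retention, and the settings snapshot. A minimal sketch of constructing a store against a non-default configuration (the path and values are illustrative):

```python
from pathlib import Path

from haiku.rag.config import AppConfig
from haiku.rag.store.engine import Store

cfg = AppConfig()
cfg.storage.vacuum_retention_seconds = 300  # keep old table versions for 5 minutes
cfg.storage.disable_autocreate = False      # allow creating the local DB directory

store = Store(Path("/tmp/haiku-rag.lancedb"), config=cfg)
# store.embedder was built via get_embedder(config=cfg)
```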
@@ -1,17 +1,17 @@
  import inspect
  import json
  import logging
+ from typing import TYPE_CHECKING
  from uuid import uuid4
 
- from docling_core.types.doc.document import DoclingDocument
  from lancedb.rerankers import RRFReranker
 
- from haiku.rag.chunker import chunker
- from haiku.rag.config import Config
- from haiku.rag.embeddings import get_embedder
  from haiku.rag.store.engine import DocumentRecord, Store
  from haiku.rag.store.models.chunk import Chunk
- from haiku.rag.utils import load_callable, text_to_docling_document
+ from haiku.rag.utils import load_callable
+
+ if TYPE_CHECKING:
+     from docling_core.types.doc.document import DoclingDocument
 
  logger = logging.getLogger(__name__)
 
@@ -21,7 +21,7 @@ class ChunkRepository:
 
      def __init__(self, store: Store) -> None:
          self.store = store
-         self.embedder = get_embedder()
+         self.embedder = store.embedder
 
      def _ensure_fts_index(self) -> None:
          """Ensure FTS index exists on the content column."""
@@ -142,12 +142,16 @@ class ChunkRepository:
          return chunks
 
      async def create_chunks_for_document(
-         self, document_id: str, document: DoclingDocument
+         self, document_id: str, document: "DoclingDocument"
      ) -> list[Chunk]:
          """Create chunks and embeddings for a document from DoclingDocument."""
+         # Lazy imports to avoid loading docling during module import
+         from haiku.rag.chunker import chunker
+         from haiku.rag.utils import text_to_docling_document
+
          # Optionally preprocess markdown before chunking
          processed_document = document
-         preprocessor_path = Config.MARKDOWN_PREPROCESSOR
+         preprocessor_path = self.store._config.processing.markdown_preprocessor
          if preprocessor_path:
              try:
                  pre_fn = load_callable(preprocessor_path)
@@ -4,12 +4,12 @@ from datetime import datetime
  from typing import TYPE_CHECKING
  from uuid import uuid4
 
- from docling_core.types.doc.document import DoclingDocument
-
  from haiku.rag.store.engine import DocumentRecord, Store
  from haiku.rag.store.models.document import Document
 
  if TYPE_CHECKING:
+     from docling_core.types.doc.document import DoclingDocument
+
      from haiku.rag.store.models.chunk import Chunk
 
 
@@ -171,7 +171,7 @@ class DocumentRepository:
      async def _create_with_docling(
          self,
          entity: Document,
-         docling_document: DoclingDocument,
+         docling_document: "DoclingDocument",
          chunks: list["Chunk"] | None = None,
      ) -> Document:
          """Create a document with its chunks and embeddings."""
@@ -211,7 +211,7 @@ class DocumentRepository:
              raise
 
      async def _update_with_docling(
-         self, entity: Document, docling_document: DoclingDocument
+         self, entity: Document, docling_document: "DoclingDocument"
      ) -> Document:
          """Update a document and regenerate its chunks."""
          assert entity.id is not None, "Document ID is required for update"