haiku.rag-slim 0.16.0__py3-none-any.whl → 0.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of haiku.rag-slim might be problematic; see the advisory details below for more information.

Files changed (94)
  1. haiku/rag/app.py +430 -72
  2. haiku/rag/chunkers/__init__.py +31 -0
  3. haiku/rag/chunkers/base.py +31 -0
  4. haiku/rag/chunkers/docling_local.py +164 -0
  5. haiku/rag/chunkers/docling_serve.py +179 -0
  6. haiku/rag/cli.py +207 -24
  7. haiku/rag/cli_chat.py +489 -0
  8. haiku/rag/client.py +1251 -266
  9. haiku/rag/config/__init__.py +16 -10
  10. haiku/rag/config/loader.py +5 -44
  11. haiku/rag/config/models.py +126 -17
  12. haiku/rag/converters/__init__.py +31 -0
  13. haiku/rag/converters/base.py +63 -0
  14. haiku/rag/converters/docling_local.py +193 -0
  15. haiku/rag/converters/docling_serve.py +229 -0
  16. haiku/rag/converters/text_utils.py +237 -0
  17. haiku/rag/embeddings/__init__.py +123 -24
  18. haiku/rag/embeddings/voyageai.py +175 -20
  19. haiku/rag/graph/__init__.py +0 -11
  20. haiku/rag/graph/agui/__init__.py +8 -2
  21. haiku/rag/graph/agui/cli_renderer.py +1 -1
  22. haiku/rag/graph/agui/emitter.py +219 -31
  23. haiku/rag/graph/agui/server.py +20 -62
  24. haiku/rag/graph/agui/stream.py +1 -2
  25. haiku/rag/graph/research/__init__.py +5 -2
  26. haiku/rag/graph/research/dependencies.py +12 -126
  27. haiku/rag/graph/research/graph.py +390 -135
  28. haiku/rag/graph/research/models.py +91 -112
  29. haiku/rag/graph/research/prompts.py +99 -91
  30. haiku/rag/graph/research/state.py +35 -27
  31. haiku/rag/inspector/__init__.py +8 -0
  32. haiku/rag/inspector/app.py +259 -0
  33. haiku/rag/inspector/widgets/__init__.py +6 -0
  34. haiku/rag/inspector/widgets/chunk_list.py +100 -0
  35. haiku/rag/inspector/widgets/context_modal.py +89 -0
  36. haiku/rag/inspector/widgets/detail_view.py +130 -0
  37. haiku/rag/inspector/widgets/document_list.py +75 -0
  38. haiku/rag/inspector/widgets/info_modal.py +209 -0
  39. haiku/rag/inspector/widgets/search_modal.py +183 -0
  40. haiku/rag/inspector/widgets/visual_modal.py +126 -0
  41. haiku/rag/mcp.py +106 -102
  42. haiku/rag/monitor.py +33 -9
  43. haiku/rag/providers/__init__.py +5 -0
  44. haiku/rag/providers/docling_serve.py +108 -0
  45. haiku/rag/qa/__init__.py +12 -10
  46. haiku/rag/qa/agent.py +43 -61
  47. haiku/rag/qa/prompts.py +35 -57
  48. haiku/rag/reranking/__init__.py +9 -6
  49. haiku/rag/reranking/base.py +1 -1
  50. haiku/rag/reranking/cohere.py +5 -4
  51. haiku/rag/reranking/mxbai.py +5 -2
  52. haiku/rag/reranking/vllm.py +3 -4
  53. haiku/rag/reranking/zeroentropy.py +6 -5
  54. haiku/rag/store/__init__.py +2 -1
  55. haiku/rag/store/engine.py +242 -42
  56. haiku/rag/store/exceptions.py +4 -0
  57. haiku/rag/store/models/__init__.py +8 -2
  58. haiku/rag/store/models/chunk.py +190 -0
  59. haiku/rag/store/models/document.py +46 -0
  60. haiku/rag/store/repositories/chunk.py +141 -121
  61. haiku/rag/store/repositories/document.py +25 -84
  62. haiku/rag/store/repositories/settings.py +11 -14
  63. haiku/rag/store/upgrades/__init__.py +19 -3
  64. haiku/rag/store/upgrades/v0_10_1.py +1 -1
  65. haiku/rag/store/upgrades/v0_19_6.py +65 -0
  66. haiku/rag/store/upgrades/v0_20_0.py +68 -0
  67. haiku/rag/store/upgrades/v0_23_1.py +100 -0
  68. haiku/rag/store/upgrades/v0_9_3.py +3 -3
  69. haiku/rag/utils.py +371 -146
  70. {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/METADATA +15 -12
  71. haiku_rag_slim-0.24.0.dist-info/RECORD +78 -0
  72. {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/WHEEL +1 -1
  73. haiku/rag/chunker.py +0 -65
  74. haiku/rag/embeddings/base.py +0 -25
  75. haiku/rag/embeddings/ollama.py +0 -28
  76. haiku/rag/embeddings/openai.py +0 -26
  77. haiku/rag/embeddings/vllm.py +0 -29
  78. haiku/rag/graph/agui/events.py +0 -254
  79. haiku/rag/graph/common/__init__.py +0 -5
  80. haiku/rag/graph/common/models.py +0 -42
  81. haiku/rag/graph/common/nodes.py +0 -265
  82. haiku/rag/graph/common/prompts.py +0 -46
  83. haiku/rag/graph/common/utils.py +0 -44
  84. haiku/rag/graph/deep_qa/__init__.py +0 -1
  85. haiku/rag/graph/deep_qa/dependencies.py +0 -27
  86. haiku/rag/graph/deep_qa/graph.py +0 -243
  87. haiku/rag/graph/deep_qa/models.py +0 -20
  88. haiku/rag/graph/deep_qa/prompts.py +0 -59
  89. haiku/rag/graph/deep_qa/state.py +0 -56
  90. haiku/rag/graph/research/common.py +0 -87
  91. haiku/rag/reader.py +0 -135
  92. haiku_rag_slim-0.16.0.dist-info/RECORD +0 -71
  93. {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/entry_points.txt +0 -0
  94. {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,229 @@
1
+ """docling-serve remote converter implementation."""
2
+
3
+ import asyncio
4
+ import json
5
+ from pathlib import Path
6
+ from typing import TYPE_CHECKING, ClassVar
7
+
8
+ from haiku.rag.config import AppConfig
9
+ from haiku.rag.converters.base import DocumentConverter
10
+ from haiku.rag.converters.text_utils import TextFileHandler
11
+ from haiku.rag.providers.docling_serve import DoclingServeClient
12
+
13
+ if TYPE_CHECKING:
14
+ from docling_core.types.doc.document import DoclingDocument
15
+
16
+ from haiku.rag.config.models import ModelConfig
17
+
18
+
19
class DoclingServeConverter(DocumentConverter):
    """Converter that uses docling-serve for document conversion.

    This converter offloads document processing to a docling-serve instance,
    which handles heavy operations like PDF parsing, OCR, and table extraction.

    For plain text files, it reads them locally and converts to markdown format
    before sending to docling-serve for DoclingDocument conversion.
    """

    # Extensions that docling-serve can handle
    docling_serve_extensions: ClassVar[list[str]] = [
        ".adoc",
        ".asc",
        ".asciidoc",
        ".bmp",
        ".csv",
        ".docx",
        ".html",
        ".xhtml",
        ".jpeg",
        ".jpg",
        ".md",
        ".pdf",
        ".png",
        ".pptx",
        ".tiff",
        ".xlsx",
        ".xml",
        ".webp",
    ]

    # Text formats accepted by convert_text().
    SUPPORTED_FORMATS: ClassVar[tuple[str, ...]] = ("md", "html", "plain")

    def __init__(self, config: AppConfig):
        """Initialize the converter with configuration.

        Args:
            config: Application configuration containing docling-serve settings.
        """
        self.config = config
        self.client = DoclingServeClient(
            base_url=config.providers.docling_serve.base_url,
            api_key=config.providers.docling_serve.api_key,
        )

    @property
    def supported_extensions(self) -> list[str]:
        """Return list of file extensions supported by this converter."""
        return self.docling_serve_extensions + TextFileHandler.text_extensions

    def _get_vlm_api_url(self, model: "ModelConfig") -> str:
        """Construct the OpenAI-compatible chat-completions URL for a VLM model.

        Args:
            model: Model configuration for the picture-description VLM.

        Raises:
            ValueError: If the model's provider has no known API endpoint.
        """
        # An explicit base_url always wins over provider defaults.
        if model.base_url:
            base = model.base_url.rstrip("/")
            return f"{base}/v1/chat/completions"

        if model.provider == "ollama":
            base = self.config.providers.ollama.base_url.rstrip("/")
            return f"{base}/v1/chat/completions"

        if model.provider == "openai":
            return "https://api.openai.com/v1/chat/completions"

        raise ValueError(f"Unsupported VLM provider: {model.provider}")

    def _build_conversion_data(self) -> dict[str, str | list[str]]:
        """Build form data for the conversion request from processing config."""
        opts = self.config.processing.conversion_options
        pic_desc = opts.picture_description

        data: dict[str, str | list[str]] = {
            "to_formats": "json",
            "do_ocr": str(opts.do_ocr).lower(),
            "force_ocr": str(opts.force_ocr).lower(),
            "do_table_structure": str(opts.do_table_structure).lower(),
            "table_mode": opts.table_mode,
            "table_cell_matching": str(opts.table_cell_matching).lower(),
            "images_scale": str(opts.images_scale),
            # Picture description requires the images even when image
            # generation is not otherwise enabled.
            "generate_picture_images": str(
                opts.generate_picture_images or pic_desc.enabled
            ).lower(),
            "do_picture_description": str(pic_desc.enabled).lower(),
        }

        if opts.ocr_lang:
            data["ocr_lang"] = opts.ocr_lang

        if pic_desc.enabled:
            prompt = self.config.prompts.picture_description
            picture_description_api = {
                "url": self._get_vlm_api_url(pic_desc.model),
                "params": {
                    "model": pic_desc.model.name,
                    "max_completion_tokens": pic_desc.max_tokens,
                },
                "prompt": prompt,
                "timeout": pic_desc.timeout,
            }
            data["picture_description_api"] = json.dumps(picture_description_api)

        return data

    async def _make_request(self, files: dict, name: str) -> "DoclingDocument":
        """Make an async request to docling-serve and poll for results.

        Args:
            files: Dictionary with files parameter for httpx
            name: Name of the document being converted (for error messages)

        Returns:
            DoclingDocument representation

        Raises:
            ValueError: If conversion fails or service is unavailable
        """
        from docling_core.types.doc.document import DoclingDocument

        data = self._build_conversion_data()
        result = await self.client.submit_and_poll(
            endpoint="/v1/convert/file/async",
            files=files,
            data=data,
            name=name,
        )

        # None status is tolerated: some docling-serve responses omit it.
        if result.get("status") not in ("success", "partial_success", None):
            errors = result.get("errors", [])
            raise ValueError(f"Conversion failed: {errors}")

        json_content = result.get("document", {}).get("json_content")

        if json_content is None:
            raise ValueError(
                f"docling-serve did not return JSON content for {name}. "
                "This may indicate an unsupported file format."
            )

        return DoclingDocument.model_validate(json_content)

    async def convert_file(self, path: Path) -> "DoclingDocument":
        """Convert a file to DoclingDocument using docling-serve.

        Args:
            path: Path to the file to convert.

        Returns:
            DoclingDocument representation of the file.

        Raises:
            ValueError: If the file cannot be converted or service is unavailable.
        """
        file_extension = path.suffix.lower()

        if file_extension in TextFileHandler.text_extensions:
            # Only the local read is wrapped: a conversion failure must not be
            # mislabeled as a file-read failure.
            try:
                content = await asyncio.to_thread(path.read_text, encoding="utf-8")
            except Exception as e:
                raise ValueError(f"Failed to read text file {path}: {e}") from e
            prepared_content = TextFileHandler.prepare_text_content(
                content, file_extension
            )
            return await self.convert_text(prepared_content, name=f"{path.stem}.md")

        def read_file():
            with open(path, "rb") as f:
                return f.read()

        # Read in a worker thread so the event loop is not blocked by file I/O.
        file_content = await asyncio.to_thread(read_file)
        files = {"files": (path.name, file_content, "application/octet-stream")}
        return await self._make_request(files, path.name)

    async def convert_text(
        self, text: str, name: str = "content.md", format: str = "md"
    ) -> "DoclingDocument":
        """Convert text content to DoclingDocument via docling-serve.

        Sends the text to docling-serve for conversion using the specified format.

        Args:
            text: The text content to convert.
            name: The name to use for the document (defaults to "content.md").
            format: The format of the text content ("md", "html", or "plain").
                Defaults to "md". Use "plain" for plain text without parsing.

        Returns:
            DoclingDocument representation of the text.

        Raises:
            ValueError: If the text cannot be converted or format is unsupported.
        """
        if format not in self.SUPPORTED_FORMATS:
            raise ValueError(
                f"Unsupported format: {format}. "
                f"Supported formats: {', '.join(self.SUPPORTED_FORMATS)}"
            )

        # Derive document name from format to tell docling which parser to use
        doc_name = f"content.{format}" if name == "content.md" else name

        # Plain text doesn't need remote parsing - create document directly
        if format == "plain":
            return TextFileHandler._create_simple_docling_document(text, doc_name)

        mime_type = "text/html" if format == "html" else "text/markdown"

        text_bytes = text.encode("utf-8")
        files = {"files": (doc_name, text_bytes, mime_type)}
        return await self._make_request(files, doc_name)
@@ -0,0 +1,237 @@
1
+ """Shared utilities for text file handling in converters."""
2
+
3
+ import asyncio
4
+ from io import BytesIO
5
+ from typing import TYPE_CHECKING, ClassVar
6
+
7
+ if TYPE_CHECKING:
8
+ from docling_core.types.doc.document import DoclingDocument
9
+
10
+
11
class TextFileHandler:
    """Handles conversion of text files to DoclingDocument format.

    This class provides shared functionality for converting plain text and code files
    to DoclingDocument format, with proper code block wrapping for syntax highlighting.
    """

    # Plain text extensions that we'll read directly
    text_extensions: ClassVar[list[str]] = [
        ".astro",
        ".bash",
        ".c",
        ".clj",
        ".cljs",
        ".cpp",
        ".cs",
        ".css",
        ".dart",
        ".elm",
        ".ex",
        ".exs",
        ".fs",
        ".fsx",
        ".go",
        ".gql",
        ".graphql",
        ".groovy",
        ".h",
        ".hcl",
        ".hpp",
        ".hs",
        ".java",
        ".jl",
        ".js",
        ".json",
        ".kt",
        ".less",
        ".lua",
        ".mdx",
        ".mjs",
        ".ml",
        ".mli",
        ".nim",
        ".nix",
        ".php",
        ".pl",
        ".pm",
        ".proto",
        ".ps1",
        ".py",
        ".r",
        ".rb",
        ".rs",
        ".sass",
        ".scala",
        ".scss",
        ".sh",
        ".sql",
        ".svelte",
        ".swift",
        ".tf",
        ".toml",
        ".ts",
        ".tsx",
        ".txt",
        ".vue",
        ".xml",
        ".yaml",
        ".yml",
        ".zig",
    ]

    # Code file extensions with their markdown language identifiers.
    # Extensions listed in text_extensions but absent here (e.g. .txt, .mdx)
    # are passed through without a code fence.
    code_markdown_identifier: ClassVar[dict[str, str]] = {
        ".astro": "astro",
        ".bash": "bash",
        ".c": "c",
        ".clj": "clojure",
        ".cljs": "clojure",
        ".cpp": "cpp",
        ".cs": "csharp",
        ".css": "css",
        ".dart": "dart",
        ".elm": "elm",
        ".ex": "elixir",
        ".exs": "elixir",
        ".fs": "fsharp",
        ".fsx": "fsharp",
        ".go": "go",
        ".gql": "graphql",
        ".graphql": "graphql",
        ".groovy": "groovy",
        ".h": "c",
        ".hcl": "hcl",
        ".hpp": "cpp",
        ".hs": "haskell",
        ".java": "java",
        ".jl": "julia",
        ".js": "javascript",
        ".json": "json",
        ".kt": "kotlin",
        ".less": "less",
        ".lua": "lua",
        ".mjs": "javascript",
        ".ml": "ocaml",
        ".mli": "ocaml",
        ".nim": "nim",
        ".nix": "nix",
        ".php": "php",
        ".pl": "perl",
        ".pm": "perl",
        ".proto": "protobuf",
        ".ps1": "powershell",
        ".py": "python",
        ".r": "r",
        ".rb": "ruby",
        ".rs": "rust",
        ".sass": "sass",
        ".scala": "scala",
        ".scss": "scss",
        ".sh": "bash",
        ".sql": "sql",
        ".svelte": "svelte",
        ".swift": "swift",
        ".tf": "hcl",
        ".toml": "toml",
        ".ts": "typescript",
        ".tsx": "tsx",
        ".vue": "vue",
        ".xml": "xml",
        ".yaml": "yaml",
        ".yml": "yaml",
        ".zig": "zig",
    }

    # Text formats accepted by text_to_docling_document().
    SUPPORTED_FORMATS = ("md", "html", "plain")

    @staticmethod
    def prepare_text_content(content: str, file_extension: str) -> str:
        """Prepare text content for conversion to DoclingDocument.

        Wraps code files in markdown code blocks with appropriate language identifiers.

        Args:
            content: The text content.
            file_extension: File extension (including dot, e.g., ".py").

        Returns:
            Prepared text content, possibly wrapped in code blocks.
        """
        if file_extension in TextFileHandler.code_markdown_identifier:
            language = TextFileHandler.code_markdown_identifier[file_extension]
            return f"```{language}\n{content}\n```"
        return content

    @staticmethod
    def _create_simple_docling_document(text: str, name: str) -> "DoclingDocument":
        """Create a simple DoclingDocument directly from text.

        Used as fallback when docling's format detection fails for plain text
        that doesn't contain markdown syntax.
        """
        from docling_core.types.doc.document import DoclingDocument
        from docling_core.types.doc.labels import DocItemLabel

        # Strip the extension so the document name matches docling conventions.
        doc_name = name.rsplit(".", 1)[0] if "." in name else name
        doc = DoclingDocument(name=doc_name)
        doc.add_text(label=DocItemLabel.TEXT, text=text)
        return doc

    @staticmethod
    def _sync_text_to_docling_document(
        text: str, name: str = "content.md", format: str = "md"
    ) -> "DoclingDocument":
        """Synchronous implementation of text to DoclingDocument conversion."""
        from docling.document_converter import DocumentConverter as DoclingDocConverter
        from docling.exceptions import ConversionError
        from docling_core.types.io import DocumentStream

        if format not in TextFileHandler.SUPPORTED_FORMATS:
            raise ValueError(
                f"Unsupported format: {format}. "
                f"Supported formats: {', '.join(TextFileHandler.SUPPORTED_FORMATS)}"
            )

        # Derive document name from format to tell docling which parser to use
        doc_name = f"content.{format}" if name == "content.md" else name

        # Plain text doesn't need parsing - create document directly
        if format == "plain":
            return TextFileHandler._create_simple_docling_document(text, doc_name)

        bytes_io = BytesIO(text.encode("utf-8"))
        doc_stream = DocumentStream(name=doc_name, stream=bytes_io)
        converter = DoclingDocConverter()
        try:
            result = converter.convert(doc_stream)
            return result.document
        except ConversionError:
            # Docling's format detection fails for plain text without markdown syntax.
            # Fall back to creating a simple document directly.
            return TextFileHandler._create_simple_docling_document(text, doc_name)

    @staticmethod
    async def text_to_docling_document(
        text: str, name: str = "content.md", format: str = "md"
    ) -> "DoclingDocument":
        """Convert text to DoclingDocument using docling's parser.

        Args:
            text: The text content to convert.
            name: The name to use for the document.
            format: The format of the text content ("md", "html", or "plain").
                Defaults to "md". Use "plain" for plain text without parsing.

        Returns:
            DoclingDocument representation of the text.

        Raises:
            ValueError: If the conversion fails or format is unsupported.
        """
        try:
            return await asyncio.to_thread(
                TextFileHandler._sync_text_to_docling_document, text, name, format
            )
        except ValueError:
            # Format-validation errors already carry a precise message;
            # don't rewrap them into a generic conversion failure.
            raise
        except Exception as e:
            raise ValueError(f"Failed to convert text to DoclingDocument: {e}") from e
@@ -1,11 +1,100 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ from pydantic_ai.embeddings import Embedder
4
+ from pydantic_ai.embeddings.openai import OpenAIEmbeddingModel
5
+ from pydantic_ai.providers.ollama import OllamaProvider
6
+ from pydantic_ai.providers.openai import OpenAIProvider
7
+
1
8
  from haiku.rag.config import AppConfig, Config
2
- from haiku.rag.embeddings.base import EmbedderBase
3
- from haiku.rag.embeddings.ollama import Embedder as OllamaEmbedder
4
9
 
10
+ if TYPE_CHECKING:
11
+ from haiku.rag.store.models.chunk import Chunk
12
+
13
+
14
class EmbedderWrapper:
    """Wrapper around pydantic-ai Embedder with explicit query/document methods."""

    def __init__(self, embedder: Embedder, vector_dim: int):
        # Keep the underlying embedder and its output dimensionality private.
        self._embedder = embedder
        self._vector_dim = vector_dim

    async def embed_query(self, text: str) -> list[float]:
        """Return the embedding vector for a single search query."""
        response = await self._embedder.embed_query(text)
        first_vector = response.embeddings[0]
        return [*first_vector]

    async def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Return embedding vectors for a batch of documents/chunks."""
        if not texts:
            # Avoid a round-trip to the backend for an empty batch.
            return []
        response = await self._embedder.embed_documents(texts)
        return [*map(list, response.embeddings)]
32
+
33
+
34
def contextualize(chunks: list["Chunk"]) -> list[str]:
    """Prepare chunk content for embedding/FTS by adding context.

    Prepends section headings to chunk content for better semantic search.

    Args:
        chunks: List of chunks to contextualize.

    Returns:
        List of contextualized text strings, one per chunk.
    """
    contextualized: list[str] = []
    for chunk in chunks:
        headings = chunk.get_chunk_metadata().headings
        # Headings become a newline-joined prefix; absent headings add nothing.
        prefix = "\n".join(headings) + "\n" if headings else ""
        contextualized.append(prefix + chunk.content)
    return contextualized
54
+
55
+
56
async def embed_chunks(
    chunks: list["Chunk"], config: AppConfig = Config
) -> list["Chunk"]:
    """Generate embeddings for chunks.

    Contextualizes chunks (prepends headings) before embedding for better
    semantic search. Returns new Chunk objects with embeddings set.

    Args:
        chunks: List of chunks to embed.
        config: Configuration for embedder selection.

    Returns:
        New list of Chunk objects with embedding field populated.

    Raises:
        ValueError: If the embedding backend returns a different number of
            vectors than chunks submitted (via ``zip(..., strict=True)``).
    """
    if not chunks:
        return []

    from haiku.rag.store.models.chunk import Chunk

    embedder = get_embedder(config)
    texts = contextualize(chunks)
    embeddings = await embedder.embed_documents(texts)

    # strict=True guards against silent truncation/misalignment if the
    # backend returns a mismatched number of embeddings.
    return [
        Chunk(
            id=chunk.id,
            document_id=chunk.document_id,
            content=chunk.content,
            metadata=chunk.metadata,
            order=chunk.order,
            document_uri=chunk.document_uri,
            document_title=chunk.document_title,
            document_meta=chunk.document_meta,
            embedding=embedding,
        )
        for chunk, embedding in zip(chunks, embeddings, strict=True)
    ]
94
+
95
+
96
def get_embedder(config: AppConfig = Config) -> EmbedderWrapper:
    """Factory function to get the appropriate embedder based on the configuration.

    Args:
        config: Configuration to use. Defaults to global Config.

    Returns:
        An embedder instance configured according to the config.

    Raises:
        ImportError: If the voyageai provider is selected but the optional
            'voyageai' extra is not installed.
        ValueError: If the configured provider is not supported.
    """
    embedding_model = config.embeddings.model
    provider = embedding_model.provider
    model_name = embedding_model.name
    vector_dim = embedding_model.vector_dim

    if provider == "ollama":
        # Use model-level base_url if set, otherwise fall back to providers config
        base_url = embedding_model.base_url or f"{config.providers.ollama.base_url}/v1"
        model = OpenAIEmbeddingModel(
            model_name,
            provider=OllamaProvider(base_url=base_url),
        )
        return EmbedderWrapper(Embedder(model), vector_dim)

    if provider == "openai":
        if embedding_model.base_url:
            model = OpenAIEmbeddingModel(
                model_name,
                provider=OpenAIProvider(base_url=embedding_model.base_url),
            )
            return EmbedderWrapper(Embedder(model), vector_dim)
        return EmbedderWrapper(Embedder(f"openai:{model_name}"), vector_dim)

    if provider == "voyageai":
        try:
            from haiku.rag.embeddings.voyageai import VoyageAIEmbeddingModel
        except ImportError as e:
            # Chain the original ImportError so the missing-module detail
            # is preserved in the traceback.
            raise ImportError(
                "VoyageAI embedder requires the 'voyageai' package. "
                "Please install haiku.rag with the 'voyageai' extra: "
                "uv pip install haiku.rag[voyageai]"
            ) from e
        model = VoyageAIEmbeddingModel(model_name)
        return EmbedderWrapper(Embedder(model), vector_dim)

    if provider == "cohere":
        return EmbedderWrapper(Embedder(f"cohere:{model_name}"), vector_dim)

    if provider == "sentence-transformers":
        return EmbedderWrapper(
            Embedder(f"sentence-transformers:{model_name}"), vector_dim
        )

    raise ValueError(f"Unsupported embedding provider: {provider}")