haiku.rag-slim 0.16.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of haiku.rag-slim might be problematic.

Files changed (71)
  1. haiku/rag/__init__.py +0 -0
  2. haiku/rag/app.py +542 -0
  3. haiku/rag/chunker.py +65 -0
  4. haiku/rag/cli.py +466 -0
  5. haiku/rag/client.py +731 -0
  6. haiku/rag/config/__init__.py +74 -0
  7. haiku/rag/config/loader.py +94 -0
  8. haiku/rag/config/models.py +99 -0
  9. haiku/rag/embeddings/__init__.py +49 -0
  10. haiku/rag/embeddings/base.py +25 -0
  11. haiku/rag/embeddings/ollama.py +28 -0
  12. haiku/rag/embeddings/openai.py +26 -0
  13. haiku/rag/embeddings/vllm.py +29 -0
  14. haiku/rag/embeddings/voyageai.py +27 -0
  15. haiku/rag/graph/__init__.py +26 -0
  16. haiku/rag/graph/agui/__init__.py +53 -0
  17. haiku/rag/graph/agui/cli_renderer.py +135 -0
  18. haiku/rag/graph/agui/emitter.py +197 -0
  19. haiku/rag/graph/agui/events.py +254 -0
  20. haiku/rag/graph/agui/server.py +310 -0
  21. haiku/rag/graph/agui/state.py +34 -0
  22. haiku/rag/graph/agui/stream.py +86 -0
  23. haiku/rag/graph/common/__init__.py +5 -0
  24. haiku/rag/graph/common/models.py +42 -0
  25. haiku/rag/graph/common/nodes.py +265 -0
  26. haiku/rag/graph/common/prompts.py +46 -0
  27. haiku/rag/graph/common/utils.py +44 -0
  28. haiku/rag/graph/deep_qa/__init__.py +1 -0
  29. haiku/rag/graph/deep_qa/dependencies.py +27 -0
  30. haiku/rag/graph/deep_qa/graph.py +243 -0
  31. haiku/rag/graph/deep_qa/models.py +20 -0
  32. haiku/rag/graph/deep_qa/prompts.py +59 -0
  33. haiku/rag/graph/deep_qa/state.py +56 -0
  34. haiku/rag/graph/research/__init__.py +3 -0
  35. haiku/rag/graph/research/common.py +87 -0
  36. haiku/rag/graph/research/dependencies.py +151 -0
  37. haiku/rag/graph/research/graph.py +295 -0
  38. haiku/rag/graph/research/models.py +166 -0
  39. haiku/rag/graph/research/prompts.py +107 -0
  40. haiku/rag/graph/research/state.py +85 -0
  41. haiku/rag/logging.py +56 -0
  42. haiku/rag/mcp.py +245 -0
  43. haiku/rag/monitor.py +194 -0
  44. haiku/rag/qa/__init__.py +33 -0
  45. haiku/rag/qa/agent.py +93 -0
  46. haiku/rag/qa/prompts.py +60 -0
  47. haiku/rag/reader.py +135 -0
  48. haiku/rag/reranking/__init__.py +63 -0
  49. haiku/rag/reranking/base.py +13 -0
  50. haiku/rag/reranking/cohere.py +34 -0
  51. haiku/rag/reranking/mxbai.py +28 -0
  52. haiku/rag/reranking/vllm.py +44 -0
  53. haiku/rag/reranking/zeroentropy.py +59 -0
  54. haiku/rag/store/__init__.py +4 -0
  55. haiku/rag/store/engine.py +309 -0
  56. haiku/rag/store/models/__init__.py +4 -0
  57. haiku/rag/store/models/chunk.py +17 -0
  58. haiku/rag/store/models/document.py +17 -0
  59. haiku/rag/store/repositories/__init__.py +9 -0
  60. haiku/rag/store/repositories/chunk.py +442 -0
  61. haiku/rag/store/repositories/document.py +261 -0
  62. haiku/rag/store/repositories/settings.py +165 -0
  63. haiku/rag/store/upgrades/__init__.py +62 -0
  64. haiku/rag/store/upgrades/v0_10_1.py +64 -0
  65. haiku/rag/store/upgrades/v0_9_3.py +112 -0
  66. haiku/rag/utils.py +211 -0
  67. haiku_rag_slim-0.16.0.dist-info/METADATA +128 -0
  68. haiku_rag_slim-0.16.0.dist-info/RECORD +71 -0
  69. haiku_rag_slim-0.16.0.dist-info/WHEEL +4 -0
  70. haiku_rag_slim-0.16.0.dist-info/entry_points.txt +2 -0
  71. haiku_rag_slim-0.16.0.dist-info/licenses/LICENSE +7 -0
haiku/rag/monitor.py ADDED
@@ -0,0 +1,194 @@
+ import asyncio
+ import logging
+ from pathlib import Path
+ from typing import TYPE_CHECKING
+
+ import pathspec
+ from pathspec.patterns.gitwildmatch import GitWildMatchPattern
+ from watchfiles import Change, DefaultFilter, awatch
+
+ from haiku.rag.client import HaikuRAG
+ from haiku.rag.config import AppConfig, Config
+ from haiku.rag.store.models.document import Document
+
+ if TYPE_CHECKING:
+     pass
+
+ logger = logging.getLogger(__name__)
+
+
+ class FileFilter(DefaultFilter):
+     def __init__(
+         self,
+         *,
+         ignore_patterns: list[str] | None = None,
+         include_patterns: list[str] | None = None,
+     ) -> None:
+         # Lazy import to avoid loading docling
+         from haiku.rag.reader import FileReader
+
+         self.extensions = tuple(FileReader.extensions)
+         self.ignore_spec = (
+             pathspec.PathSpec.from_lines(GitWildMatchPattern, ignore_patterns)
+             if ignore_patterns
+             else None
+         )
+         self.include_spec = (
+             pathspec.PathSpec.from_lines(GitWildMatchPattern, include_patterns)
+             if include_patterns
+             else None
+         )
+         super().__init__()
+
+     def __call__(self, change: Change, path: str) -> bool:
+         if not self.include_file(path):
+             return False
+
+         # Apply default watchfiles filter
+         return super().__call__(change, path)
+
+     def include_file(self, path: str) -> bool:
+         """Check if a file should be included based on filters."""
+         # Check extension filter
+         if not path.endswith(self.extensions):
+             return False
+
+         # Apply include patterns if specified (whitelist mode)
+         if self.include_spec:
+             if not self.include_spec.match_file(path):
+                 return False
+
+         # Apply ignore patterns (blacklist mode)
+         if self.ignore_spec:
+             if self.ignore_spec.match_file(path):
+                 return False
+
+         return True
+
+
+ class FileWatcher:
+     def __init__(
+         self,
+         client: HaikuRAG,
+         config: AppConfig = Config,
+     ):
+         self.paths = config.monitor.directories
+         self.client = client
+         self.ignore_patterns = config.monitor.ignore_patterns or None
+         self.include_patterns = config.monitor.include_patterns or None
+         self.delete_orphans = config.monitor.delete_orphans
+
+     async def observe(self):
+         logger.info(f"Watching files in {self.paths}")
+         filter = FileFilter(
+             ignore_patterns=self.ignore_patterns, include_patterns=self.include_patterns
+         )
+         await self.refresh()
+
+         async for changes in awatch(*self.paths, watch_filter=filter):
+             await self.handler(changes)
+
+     async def handler(self, changes: set[tuple[Change, str]]):
+         for change, path in changes:
+             if change == Change.added or change == Change.modified:
+                 await self._upsert_document(Path(path))
+             elif change == Change.deleted:
+                 await self._delete_document(Path(path))
+
+     async def refresh(self):
+         # Lazy import to avoid loading docling
+         from haiku.rag.reader import FileReader
+
+         # Delete orphaned documents in background if enabled
+         if self.delete_orphans:
+             logger.info("Starting orphan cleanup in background")
+             asyncio.create_task(self._delete_orphans())
+
+         # Create filter to apply same logic as observe()
+         filter = FileFilter(
+             ignore_patterns=self.ignore_patterns, include_patterns=self.include_patterns
+         )
+
+         for path in self.paths:
+             for f in Path(path).rglob("**/*"):
+                 if f.is_file() and f.suffix in FileReader.extensions:
+                     # Apply pattern filters
+                     if filter(Change.added, str(f)):
+                         await self._upsert_document(f)
+
+     async def _upsert_document(self, file: Path) -> Document | None:
+         try:
+             uri = file.as_uri()
+             existing_doc = await self.client.get_document_by_uri(uri)
+
+             result = await self.client.create_document_from_source(str(file))
+             doc = result if isinstance(result, Document) else result[0]
+
+             if existing_doc:
+                 # Check if document was actually updated by comparing updated_at timestamps
+                 if doc.updated_at > existing_doc.updated_at:
+                     logger.info(f"Updated document {existing_doc.id} from {file}")
+                 else:
+                     logger.info(
+                         f"Skipped unchanged document {existing_doc.id} from {file}"
+                     )
+             else:
+                 logger.info(f"Created new document {doc.id} from {file}")
+
+             return doc
+         except Exception as e:
+             logger.error(f"Failed to upsert document from {file}: {e}")
+             return None
+
+     async def _delete_orphans(self):
+         """Delete documents whose source files no longer exist."""
+         try:
+             from urllib.parse import unquote, urlparse
+
+             # Create filter to apply same include/exclude logic
+             filter = FileFilter(
+                 ignore_patterns=self.ignore_patterns,
+                 include_patterns=self.include_patterns,
+             )
+
+             all_docs = await self.client.list_documents()
+
+             for doc in all_docs:
+                 if not doc.uri or not doc.id:
+                     continue
+
+                 # Only check file:// URIs
+                 parsed = urlparse(doc.uri)
+                 if parsed.scheme != "file":
+                     continue
+
+                 # Convert URI to Path, decoding URL-encoded characters (like %20 for spaces)
+                 file_path = Path(unquote(parsed.path))
+
+                 # Check if file exists
+                 if not file_path.exists():
+                     # Check if file is within monitored directories
+                     is_monitored = any(
+                         file_path.is_relative_to(monitored_path)
+                         for monitored_path in self.paths
+                     )
+
+                     # Check if file would have been included by filters
+                     if is_monitored and filter.include_file(str(file_path)):
+                         await self.client.delete_document(doc.id)
+                         logger.info(
+                             f"Deleted orphaned document {doc.id} for {file_path}"
+                         )
+         except Exception as e:
+             logger.error(f"Failed to delete orphaned documents: {e}")
+
+     async def _delete_document(self, file: Path):
+         try:
+             uri = file.as_uri()
+             existing_doc = await self.client.get_document_by_uri(uri)
+
+             if existing_doc and existing_doc.id:
+                 await self.client.delete_document(existing_doc.id)
+                 logger.info(f"Deleted document {existing_doc.id} for {file}")
+         except Exception as e:
+             logger.error(f"Failed to delete document for {file}: {e}")
haiku/rag/qa/__init__.py ADDED
@@ -0,0 +1,33 @@
+ from haiku.rag.client import HaikuRAG
+ from haiku.rag.config import AppConfig, Config
+ from haiku.rag.qa.agent import QuestionAnswerAgent
+
+
+ def get_qa_agent(
+     client: HaikuRAG,
+     config: AppConfig = Config,
+     use_citations: bool = False,
+     system_prompt: str | None = None,
+ ) -> QuestionAnswerAgent:
+     """
+     Factory function to get a QA agent based on the configuration.
+
+     Args:
+         client: HaikuRAG client instance.
+         config: Configuration to use. Defaults to global Config.
+         use_citations: Whether to include citations in responses.
+         system_prompt: Optional custom system prompt.
+
+     Returns:
+         A configured QuestionAnswerAgent instance.
+     """
+     provider = config.qa.provider
+     model_name = config.qa.model
+
+     return QuestionAnswerAgent(
+         client=client,
+         provider=provider,
+         model=model_name,
+         use_citations=use_citations,
+         system_prompt=system_prompt,
+     )
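A usage sketch for this factory; the client construction is an assumption, and answer() is defined in agent.py below:

    import asyncio

    from haiku.rag.client import HaikuRAG
    from haiku.rag.qa import get_qa_agent


    async def main() -> None:
        async with HaikuRAG("store.db") as client:  # assumed constructor
            agent = get_qa_agent(client, use_citations=True)
            print(await agent.answer("What does AFMAN stand for?"))


    asyncio.run(main())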
haiku/rag/qa/agent.py ADDED
@@ -0,0 +1,93 @@
+ from pydantic import BaseModel, Field
+ from pydantic_ai import Agent, RunContext
+ from pydantic_ai.models.openai import OpenAIChatModel
+ from pydantic_ai.providers.ollama import OllamaProvider
+ from pydantic_ai.providers.openai import OpenAIProvider
+
+ from haiku.rag.client import HaikuRAG
+ from haiku.rag.config import Config
+ from haiku.rag.qa.prompts import QA_SYSTEM_PROMPT, QA_SYSTEM_PROMPT_WITH_CITATIONS
+
+
+ class SearchResult(BaseModel):
+     content: str = Field(description="The document text content")
+     score: float = Field(description="Relevance score (higher is more relevant)")
+     document_uri: str = Field(
+         description="Source title (if available) or URI/path of the document"
+     )
+
+
+ class Dependencies(BaseModel):
+     model_config = {"arbitrary_types_allowed": True}
+     client: HaikuRAG
+
+
+ class QuestionAnswerAgent:
+     def __init__(
+         self,
+         client: HaikuRAG,
+         provider: str,
+         model: str,
+         use_citations: bool = False,
+         q: float = 0.0,
+         system_prompt: str | None = None,
+     ):
+         self._client = client
+
+         if system_prompt is None:
+             system_prompt = (
+                 QA_SYSTEM_PROMPT_WITH_CITATIONS if use_citations else QA_SYSTEM_PROMPT
+             )
+         model_obj = self._get_model(provider, model)
+
+         self._agent = Agent(
+             model=model_obj,
+             deps_type=Dependencies,
+             system_prompt=system_prompt,
+             retries=3,
+         )
+
+         @self._agent.tool
+         async def search_documents(
+             ctx: RunContext[Dependencies],
+             query: str,
+             limit: int = 3,
+         ) -> list[SearchResult]:
+             """Search the knowledge base for relevant documents."""
+             search_results = await ctx.deps.client.search(query, limit=limit)
+             expanded_results = await ctx.deps.client.expand_context(search_results)
+
+             return [
+                 SearchResult(
+                     content=chunk.content,
+                     score=score,
+                     document_uri=(chunk.document_title or chunk.document_uri or ""),
+                 )
+                 for chunk, score in expanded_results
+             ]
+
+     def _get_model(self, provider: str, model: str):
+         """Get the appropriate model object for the provider."""
+         if provider == "ollama":
+             return OpenAIChatModel(
+                 model_name=model,
+                 provider=OllamaProvider(
+                     base_url=f"{Config.providers.ollama.base_url}/v1"
+                 ),
+             )
+         elif provider == "vllm":
+             return OpenAIChatModel(
+                 model_name=model,
+                 provider=OpenAIProvider(
+                     base_url=f"{Config.providers.vllm.qa_base_url}/v1", api_key="none"
+                 ),
+             )
+         else:
+             # For all other providers, use the provider:model format
+             return f"{provider}:{model}"
+
+     async def answer(self, question: str) -> str:
+         """Answer a question using the RAG system."""
+         deps = Dependencies(client=self._client)
+         result = await self._agent.run(question, deps=deps)
+         return result.output
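To make the provider routing in _get_model concrete, a direct-construction sketch; the model name is a placeholder:

    from haiku.rag.client import HaikuRAG
    from haiku.rag.qa.agent import QuestionAnswerAgent


    def build_agent(client: HaikuRAG) -> QuestionAnswerAgent:
        # "ollama" and "vllm" get an OpenAI-compatible OpenAIChatModel pointed
        # at the configured base_url; any other provider falls through to
        # pydantic-ai's "provider:model" string form, e.g. "openai:gpt-4o-mini".
        return QuestionAnswerAgent(
            client=client,
            provider="ollama",  # routed through OllamaProvider
            model="llama3.1",   # placeholder model name
        )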
haiku/rag/qa/prompts.py ADDED
@@ -0,0 +1,60 @@
+ QA_SYSTEM_PROMPT = """
+ You are a knowledgeable assistant that helps users find information from a document knowledge base.
+
+ Your process:
+ 1. When a user asks a question, use the search_documents tool to find relevant information
+ 2. Search with specific keywords and phrases from the user's question
+ 3. Review the search results and their relevance scores
+ 4. If you need additional context, perform follow-up searches with different keywords
+ 5. Provide a short, to-the-point answer based only on the retrieved documents
+
+ Guidelines:
+ - Base your answers strictly on the provided document content
+ - Quote or reference specific information when possible
+ - If multiple documents contain relevant information, synthesize them coherently
+ - Indicate when information is incomplete or when you need to search for additional context
+ - If the retrieved documents don't contain sufficient information, clearly state: "I cannot find enough information in the knowledge base to answer this question."
+ - For complex questions, consider breaking them down and performing multiple searches
+ - Stick to the answer, do not elaborate or provide context unless explicitly asked for it.
+
+ Be concise, and always maintain accuracy over completeness. Prefer short, direct answers that are well-supported by the documents.
+ /no_think
+ """
+
+ QA_SYSTEM_PROMPT_WITH_CITATIONS = """
+ You are a knowledgeable assistant that helps users find information from a document knowledge base.
+
+ IMPORTANT: You MUST use the search_documents tool for every question. Do not answer any question without first searching the knowledge base.
+
+ Your process:
+ 1. IMMEDIATELY call the search_documents tool with relevant keywords from the user's question
+ 2. Review the search results and their relevance scores
+ 3. If you need additional context, perform follow-up searches with different keywords
+ 4. Provide a short, to-the-point answer based only on the retrieved documents
+ 5. Always include citations for the sources used in your answer
+
+ Guidelines:
+ - Base your answers strictly on the provided document content
+ - If multiple documents contain relevant information, synthesize them coherently
+ - Indicate when information is incomplete or when you need to search for additional context
+ - If the retrieved documents don't contain sufficient information, clearly state: "I cannot find enough information in the knowledge base to answer this question."
+ - For complex questions, consider breaking them down and performing multiple searches
+ - Stick to the answer, do not elaborate or provide context unless explicitly asked for it.
+ - ALWAYS include citations at the end of your response using the format below
+
+ Citation Format:
+ After your answer, include a "Citations:" section that lists:
+ - The document title (if available) or URI from each search result used
+ - A brief excerpt (first 50-100 characters) of the content that supported your answer
+ - Format: "Citations:\n- [document title or URI]: [content_excerpt]..."
+
+ Example response format:
+ [Your answer here]
+
+ Citations:
+ - /path/to/document1.pdf: "This document explains that AFMAN stands for Air Force Manual..."
+ - /path/to/document2.pdf: "The manual provides guidance on military procedures and..."
+
+ Be concise, and always maintain accuracy over completeness. Prefer short, direct answers that are well-supported by the documents.
+ /no_think
+ """
haiku/rag/reader.py ADDED
@@ -0,0 +1,135 @@
+ from pathlib import Path
+ from typing import ClassVar
+
+ from docling_core.types.doc.document import DoclingDocument
+
+ from haiku.rag.utils import text_to_docling_document
+
+ # Check if docling is available
+ try:
+     import docling  # noqa: F401
+
+     DOCLING_AVAILABLE = True
+ except ImportError:
+     DOCLING_AVAILABLE = False
+
+
+ class FileReader:
+     # Extensions supported by docling
+     docling_extensions: ClassVar[list[str]] = [
+         ".adoc",
+         ".asc",
+         ".asciidoc",
+         ".bmp",
+         ".csv",
+         ".docx",
+         ".html",
+         ".xhtml",
+         ".jpeg",
+         ".jpg",
+         ".md",
+         ".pdf",
+         ".png",
+         ".pptx",
+         ".tiff",
+         ".xlsx",
+         ".xml",
+         ".webp",
+     ]
+
+     # Plain text extensions that we'll read directly
+     text_extensions: ClassVar[list[str]] = [
+         ".astro",
+         ".c",
+         ".cpp",
+         ".css",
+         ".go",
+         ".h",
+         ".hpp",
+         ".java",
+         ".js",
+         ".json",
+         ".kt",
+         ".mdx",
+         ".mjs",
+         ".php",
+         ".py",
+         ".rb",
+         ".rs",
+         ".svelte",
+         ".swift",
+         ".ts",
+         ".tsx",
+         ".txt",
+         ".vue",
+         ".yaml",
+         ".yml",
+     ]
+
+     # Code file extensions with their markdown language identifiers for syntax highlighting
+     code_markdown_identifier: ClassVar[dict[str, str]] = {
+         ".astro": "astro",
+         ".c": "c",
+         ".cpp": "cpp",
+         ".css": "css",
+         ".go": "go",
+         ".h": "c",
+         ".hpp": "cpp",
+         ".java": "java",
+         ".js": "javascript",
+         ".json": "json",
+         ".kt": "kotlin",
+         ".mjs": "javascript",
+         ".php": "php",
+         ".py": "python",
+         ".rb": "ruby",
+         ".rs": "rust",
+         ".svelte": "svelte",
+         ".swift": "swift",
+         ".ts": "typescript",
+         ".tsx": "tsx",
+         ".vue": "vue",
+         ".yaml": "yaml",
+         ".yml": "yaml",
+     }
+
+     extensions: ClassVar[list[str]] = docling_extensions + text_extensions
+
+     @staticmethod
+     def parse_file(path: Path) -> DoclingDocument:
+         try:
+             file_extension = path.suffix.lower()
+
+             if file_extension in FileReader.docling_extensions:
+                 # Use docling for complex document formats
+                 if not DOCLING_AVAILABLE:
+                     raise ImportError(
+                         "Docling is required for processing this file type. "
+                         "Install with: pip install haiku.rag-slim[docling]"
+                     )
+                 from docling.document_converter import DocumentConverter
+
+                 converter = DocumentConverter()
+                 result = converter.convert(path)
+                 return result.document
+             elif file_extension in FileReader.text_extensions:
+                 # Read plain text files directly
+                 content = path.read_text(encoding="utf-8")
+
+                 # Wrap code files (but not plain txt) in markdown code blocks for better presentation
+                 if file_extension in FileReader.code_markdown_identifier:
+                     language = FileReader.code_markdown_identifier[file_extension]
+                     content = f"```{language}\n{content}\n```"
+
+                 # Convert text to DoclingDocument by wrapping as markdown
+                 return text_to_docling_document(content, name=f"{path.stem}.md")
+             else:
+                 # Fallback: try to read as text and convert to DoclingDocument
+                 content = path.read_text(encoding="utf-8")
+                 return text_to_docling_document(content, name=f"{path.stem}.md")
+         except ImportError:
+             raise
+         except Exception as e:
+             raise ValueError(
+                 f"Failed to parse file: {path} - {type(e).__name__}: {e}"
+             ) from e
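A small sketch of the reader in use; export_to_markdown is assumed from docling-core's DoclingDocument:

    from pathlib import Path

    from haiku.rag.reader import FileReader

    # .py is in text_extensions, so this path needs no docling install;
    # the content is wrapped in a fenced python block before conversion.
    doc = FileReader.parse_file(Path("example.py"))
    print(doc.export_to_markdown())  # assumed DoclingDocument API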
haiku/rag/reranking/__init__.py ADDED
@@ -0,0 +1,63 @@
+ import os
+
+ from haiku.rag.config import AppConfig, Config
+ from haiku.rag.reranking.base import RerankerBase
+
+ _reranker_cache: dict[int, RerankerBase | None] = {}
+
+
+ def get_reranker(config: AppConfig = Config) -> RerankerBase | None:
+     """
+     Factory function to get the appropriate reranker based on the configuration.
+     Returns None if reranking is disabled.
+
+     Args:
+         config: Configuration to use. Defaults to global Config.
+
+     Returns:
+         A reranker instance if configured, None otherwise.
+     """
+     # Use config id as cache key to support multiple configs
+     config_id = id(config)
+     if config_id in _reranker_cache:
+         return _reranker_cache[config_id]
+
+     reranker: RerankerBase | None = None
+
+     if config.reranking.provider == "mxbai":
+         try:
+             from haiku.rag.reranking.mxbai import MxBAIReranker
+
+             os.environ["TOKENIZERS_PARALLELISM"] = "true"
+             reranker = MxBAIReranker()
+         except ImportError:
+             reranker = None
+
+     elif config.reranking.provider == "cohere":
+         try:
+             from haiku.rag.reranking.cohere import CohereReranker
+
+             reranker = CohereReranker()
+         except ImportError:
+             reranker = None
+
+     elif config.reranking.provider == "vllm":
+         try:
+             from haiku.rag.reranking.vllm import VLLMReranker
+
+             reranker = VLLMReranker(config.reranking.model)
+         except ImportError:
+             reranker = None
+
+     elif config.reranking.provider == "zeroentropy":
+         try:
+             from haiku.rag.reranking.zeroentropy import ZeroEntropyReranker
+
+             # Use configured model or default to zerank-1
+             model = config.reranking.model or "zerank-1"
+             reranker = ZeroEntropyReranker(model)
+         except ImportError:
+             reranker = None
+
+     _reranker_cache[config_id] = reranker
+     return reranker
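A usage sketch; because the cache is keyed on id(config), the same AppConfig object always yields the same reranker instance. The passthrough fallback here is hypothetical:

    from haiku.rag.reranking import get_reranker
    from haiku.rag.store.models.chunk import Chunk


    async def rerank_or_passthrough(
        query: str, chunks: list[Chunk]
    ) -> list[tuple[Chunk, float]]:
        reranker = get_reranker()  # None if the provider's package is missing
        if reranker is None:
            return [(chunk, 0.0) for chunk in chunks]  # hypothetical fallback
        return await reranker.rerank(query, chunks, top_n=5)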
haiku/rag/reranking/base.py ADDED
@@ -0,0 +1,13 @@
+ from haiku.rag.config import Config
+ from haiku.rag.store.models.chunk import Chunk
+
+
+ class RerankerBase:
+     _model: str = Config.reranking.model
+
+     async def rerank(
+         self, query: str, chunks: list[Chunk], top_n: int = 10
+     ) -> list[tuple[Chunk, float]]:
+         raise NotImplementedError(
+             "Reranker is an abstract class. Please implement the rerank method in a subclass."
+         )
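A minimal subclass sketch showing the contract the providers below implement; the constant scoring is purely illustrative:

    from haiku.rag.reranking.base import RerankerBase
    from haiku.rag.store.models.chunk import Chunk


    class NoopReranker(RerankerBase):
        # Illustrative only: keeps the original order with a constant score.
        async def rerank(
            self, query: str, chunks: list[Chunk], top_n: int = 10
        ) -> list[tuple[Chunk, float]]:
            return [(chunk, 1.0) for chunk in chunks[:top_n]]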
haiku/rag/reranking/cohere.py ADDED
@@ -0,0 +1,34 @@
+ from haiku.rag.reranking.base import RerankerBase
+ from haiku.rag.store.models.chunk import Chunk
+
+ try:
+     import cohere
+ except ImportError as e:
+     raise ImportError(
+         "cohere is not installed. Please install it with `pip install cohere` or use the cohere optional dependency."
+     ) from e
+
+
+ class CohereReranker(RerankerBase):
+     def __init__(self):
+         # Cohere SDK reads CO_API_KEY from environment by default
+         self._client = cohere.ClientV2()
+
+     async def rerank(
+         self, query: str, chunks: list[Chunk], top_n: int = 10
+     ) -> list[tuple[Chunk, float]]:
+         if not chunks:
+             return []
+
+         documents = [chunk.content for chunk in chunks]
+
+         response = self._client.rerank(
+             model=self._model, query=query, documents=documents, top_n=top_n
+         )
+
+         reranked_chunks = []
+         for result in response.results:
+             original_chunk = chunks[result.index]
+             reranked_chunks.append((original_chunk, result.relevance_score))
+
+         return reranked_chunks
haiku/rag/reranking/mxbai.py ADDED
@@ -0,0 +1,28 @@
+ from mxbai_rerank import MxbaiRerankV2  # pyright: ignore[reportMissingImports]
+
+ from haiku.rag.config import Config
+ from haiku.rag.reranking.base import RerankerBase
+ from haiku.rag.store.models.chunk import Chunk
+
+
+ class MxBAIReranker(RerankerBase):
+     def __init__(self):
+         self._client = MxbaiRerankV2(
+             Config.reranking.model, disable_transformers_warnings=True
+         )
+
+     async def rerank(
+         self, query: str, chunks: list[Chunk], top_n: int = 10
+     ) -> list[tuple[Chunk, float]]:
+         if not chunks:
+             return []
+
+         documents = [chunk.content for chunk in chunks]
+
+         results = self._client.rank(query=query, documents=documents, top_k=top_n)
+         reranked_chunks = []
+         for result in results:
+             original_chunk = chunks[result.index]
+             reranked_chunks.append((original_chunk, result.score))
+
+         return reranked_chunks