haiku.rag 0.5.1__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of haiku.rag might be problematic.

haiku/rag/app.py CHANGED
@@ -32,9 +32,9 @@ class HaikuRAGApp:
  f"[b]Document with id [cyan]{doc.id}[/cyan] added successfully.[/b]"
  )
 
- async def add_document_from_source(self, file_path: Path):
+ async def add_document_from_source(self, source: str):
  async with HaikuRAG(db_path=self.db_path) as self.client:
- doc = await self.client.create_document_from_source(file_path)
+ doc = await self.client.create_document_from_source(source)
  self._rich_print_document(doc, truncate=True)
  self.console.print(
  f"[b]Document with id [cyan]{doc.id}[/cyan] added successfully.[/b]"
@@ -62,10 +62,10 @@ class HaikuRAGApp:
  for chunk, score in results:
  self._rich_print_search_result(chunk, score)
 
- async def ask(self, question: str):
+ async def ask(self, question: str, cite: bool = False):
  async with HaikuRAG(db_path=self.db_path) as self.client:
  try:
- answer = await self.client.ask(question)
+ answer = await self.client.ask(question, cite=cite)
  self.console.print(f"[bold blue]Question:[/bold blue] {question}")
  self.console.print()
  self.console.print("[bold green]Answer:[/bold green]")
haiku/rag/cli.py CHANGED
@@ -81,7 +81,7 @@ def add_document_text(
 
  @cli.command("add-src", help="Add a document from a file path or URL")
  def add_document_src(
- file_path: Path = typer.Argument(
+ source: str = typer.Argument(
  help="The file path or URL of the document to add",
  ),
  db: Path = typer.Option(
@@ -91,7 +91,7 @@ def add_document_src(
  ),
  ):
  app = HaikuRAGApp(db_path=db)
- asyncio.run(app.add_document_from_source(file_path=file_path))
+ asyncio.run(app.add_document_from_source(source=source))
 
 
  @cli.command("get", help="Get and display a document by its ID")
@@ -160,9 +160,14 @@ def ask(
  "--db",
  help="Path to the SQLite database file",
  ),
+ cite: bool = typer.Option(
+ False,
+ "--cite",
+ help="Include citations in the response",
+ ),
  ):
  app = HaikuRAGApp(db_path=db)
- asyncio.run(app.ask(question=question))
+ asyncio.run(app.ask(question=question, cite=cite))
 
 
  @cli.command("settings", help="Display current configuration settings")
haiku/rag/client.py CHANGED
@@ -319,7 +319,7 @@ class HaikuRAG:
  return await self.document_repository.list_all(limit=limit, offset=offset)
 
  async def search(
- self, query: str, limit: int = 5, k: int = 60, rerank=Config.RERANK
+ self, query: str, limit: int = 5, k: int = 60
  ) -> list[tuple[Chunk, float]]:
  """Search for relevant chunks using hybrid search (vector similarity + full-text search) with reranking.
 
@@ -331,8 +331,10 @@ class HaikuRAG:
  Returns:
  List of (chunk, score) tuples ordered by relevance.
  """
+ # Get reranker if available
+ reranker = get_reranker()
 
- if not rerank:
+ if reranker is None:
  return await self.chunk_repository.search_chunks_hybrid(query, limit, k)
 
  # Get more initial results (3X) for reranking
@@ -340,25 +342,151 @@
  query, limit * 3, k
  )
  # Apply reranking
- reranker = get_reranker()
  chunks = [chunk for chunk, _ in search_results]
  reranked_results = await reranker.rerank(query, chunks, top_n=limit)
 
  # Return reranked results with scores from reranker
  return reranked_results
 
- async def ask(self, question: str) -> str:
+ async def expand_context(
+ self, search_results: list[tuple[Chunk, float]]
+ ) -> list[tuple[Chunk, float]]:
+ """Expand search results with adjacent chunks, merging overlapping chunks.
+
+ Args:
+ search_results: List of (chunk, score) tuples from search.
+
+ Returns:
+ List of (chunk, score) tuples with expanded and merged context chunks.
+ """
+ if Config.CONTEXT_CHUNK_RADIUS == 0:
+ return search_results
+
+ # Group chunks by document_id to handle merging within documents
+ document_groups = {}
+ for chunk, score in search_results:
+ doc_id = chunk.document_id
+ if doc_id not in document_groups:
+ document_groups[doc_id] = []
+ document_groups[doc_id].append((chunk, score))
+
+ results = []
+
+ for doc_id, doc_chunks in document_groups.items():
+ # Get all expanded ranges for this document
+ expanded_ranges = []
+ for chunk, score in doc_chunks:
+ adjacent_chunks = await self.chunk_repository.get_adjacent_chunks(
+ chunk, Config.CONTEXT_CHUNK_RADIUS
+ )
+
+ all_chunks = adjacent_chunks + [chunk]
+
+ # Get the range of orders for this expanded chunk
+ orders = [c.metadata.get("order", 0) for c in all_chunks]
+ min_order = min(orders)
+ max_order = max(orders)
+
+ expanded_ranges.append(
+ {
+ "original_chunk": chunk,
+ "score": score,
+ "min_order": min_order,
+ "max_order": max_order,
+ "all_chunks": sorted(
+ all_chunks, key=lambda c: c.metadata.get("order", 0)
+ ),
+ }
+ )
+
+ # Merge overlapping/adjacent ranges
+ merged_ranges = self._merge_overlapping_ranges(expanded_ranges)
+
+ # Create merged chunks
+ for merged_range in merged_ranges:
+ combined_content_parts = [c.content for c in merged_range["all_chunks"]]
+
+ # Use the first original chunk for metadata
+ original_chunk = merged_range["original_chunks"][0]
+
+ merged_chunk = Chunk(
+ id=original_chunk.id,
+ document_id=original_chunk.document_id,
+ content="".join(combined_content_parts),
+ metadata=original_chunk.metadata,
+ document_uri=original_chunk.document_uri,
+ document_meta=original_chunk.document_meta,
+ )
+
+ # Use the highest score from merged chunks
+ best_score = max(merged_range["scores"])
+ results.append((merged_chunk, best_score))
+
+ return results
+
+ def _merge_overlapping_ranges(self, expanded_ranges):
+ """Merge overlapping or adjacent expanded ranges."""
+ if not expanded_ranges:
+ return []
+
+ # Sort by min_order
+ sorted_ranges = sorted(expanded_ranges, key=lambda x: x["min_order"])
+ merged = []
+
+ current = {
+ "min_order": sorted_ranges[0]["min_order"],
+ "max_order": sorted_ranges[0]["max_order"],
+ "original_chunks": [sorted_ranges[0]["original_chunk"]],
+ "scores": [sorted_ranges[0]["score"]],
+ "all_chunks": sorted_ranges[0]["all_chunks"],
+ }
+
+ for range_info in sorted_ranges[1:]:
+ # Check if ranges overlap or are adjacent (max_order + 1 >= min_order)
+ if current["max_order"] >= range_info["min_order"] - 1:
+ # Merge ranges
+ current["max_order"] = max(
+ current["max_order"], range_info["max_order"]
+ )
+ current["original_chunks"].append(range_info["original_chunk"])
+ current["scores"].append(range_info["score"])
+
+ # Merge all_chunks and deduplicate by order
+ all_chunks_dict = {}
+ for chunk in current["all_chunks"] + range_info["all_chunks"]:
+ order = chunk.metadata.get("order", 0)
+ all_chunks_dict[order] = chunk
+ current["all_chunks"] = [
+ all_chunks_dict[order] for order in sorted(all_chunks_dict.keys())
+ ]
+ else:
+ # No overlap, add current to merged and start new
+ merged.append(current)
+ current = {
+ "min_order": range_info["min_order"],
+ "max_order": range_info["max_order"],
+ "original_chunks": [range_info["original_chunk"]],
+ "scores": [range_info["score"]],
+ "all_chunks": range_info["all_chunks"],
+ }
+
+ # Add the last range
+ merged.append(current)
+ return merged
+
+ async def ask(self, question: str, cite: bool = False) -> str:
  """Ask a question using the configured QA agent.
 
  Args:
  question: The question to ask.
+ cite: Whether to include citations in the response.
 
  Returns:
  The generated answer as a string.
  """
  from haiku.rag.qa import get_qa_agent
 
- qa_agent = get_qa_agent(self)
+ qa_agent = get_qa_agent(self, use_citations=cite)
  return await qa_agent.answer(question)
 
  async def rebuild_database(self) -> AsyncGenerator[int, None]:
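
The new `expand_context` / `_merge_overlapping_ranges` pair is essentially an interval merge over chunk `order` windows. A standalone sketch of that merge, with illustrative names and values rather than the library API:

```python
# Illustrative sketch of the interval merge performed by _merge_overlapping_ranges:
# windows that overlap or touch are collapsed into a single context block.
def merge_windows(windows: list[tuple[int, int]]) -> list[tuple[int, int]]:
    merged: list[tuple[int, int]] = []
    for lo, hi in sorted(windows):
        if merged and merged[-1][1] >= lo - 1:  # overlap or adjacency
            merged[-1] = (merged[-1][0], max(merged[-1][1], hi))
        else:
            merged.append((lo, hi))
    return merged


# Hits at orders 4 and 6 with CONTEXT_CHUNK_RADIUS=1 expand to (3, 5) and (5, 7),
# which merge into one block; a distant hit stays separate.
print(merge_windows([(3, 5), (5, 7), (12, 14)]))  # [(3, 7), (12, 14)]
```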
haiku/rag/config.py CHANGED
@@ -19,14 +19,14 @@ class AppConfig(BaseModel):
  EMBEDDINGS_MODEL: str = "mxbai-embed-large"
  EMBEDDINGS_VECTOR_DIM: int = 1024
 
- RERANK: bool = True
- RERANK_PROVIDER: str = "mxbai"
- RERANK_MODEL: str = "mixedbread-ai/mxbai-rerank-base-v2"
+ RERANK_PROVIDER: str = "ollama"
+ RERANK_MODEL: str = "qwen3"
 
  QA_PROVIDER: str = "ollama"
  QA_MODEL: str = "qwen3"
 
  CHUNK_SIZE: int = 256
+ CONTEXT_CHUNK_RADIUS: int = 0
 
  OLLAMA_BASE_URL: str = "http://localhost:11434"
 
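
For anyone relying on the previous defaults: reranking now defaults to Ollama with `qwen3`, and the MxBAI reranker moves behind the `mxbai` extra (see the METADATA changes below). A hedged sketch of restoring the old behaviour, assuming the `mxbai` extra is installed and that the `Config` instance can be mutated before the first reranker is created:

```python
# Sketch only; the exact override mechanics depend on how AppConfig is populated in your setup.
from haiku.rag.config import Config

Config.RERANK_PROVIDER = "mxbai"  # previous default provider
Config.RERANK_MODEL = "mixedbread-ai/mxbai-rerank-base-v2"  # previous default model
Config.CONTEXT_CHUNK_RADIUS = 1  # new knob: also fetch one neighbouring chunk per hit (0 = off)
```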
haiku/rag/qa/__init__.py CHANGED
@@ -4,12 +4,16 @@ from haiku.rag.qa.base import QuestionAnswerAgentBase
  from haiku.rag.qa.ollama import QuestionAnswerOllamaAgent
 
 
- def get_qa_agent(client: HaikuRAG, model: str = "") -> QuestionAnswerAgentBase:
+ def get_qa_agent(
+ client: HaikuRAG, model: str = "", use_citations: bool = False
+ ) -> QuestionAnswerAgentBase:
  """
  Factory function to get the appropriate QA agent based on the configuration.
  """
  if Config.QA_PROVIDER == "ollama":
- return QuestionAnswerOllamaAgent(client, model or Config.QA_MODEL)
+ return QuestionAnswerOllamaAgent(
+ client, model or Config.QA_MODEL, use_citations
+ )
 
  if Config.QA_PROVIDER == "openai":
  try:
@@ -20,7 +24,9 @@ def get_qa_agent(client: HaikuRAG, model: str = "") -> QuestionAnswerAgentBase:
  "Please install haiku.rag with the 'openai' extra:"
  "uv pip install haiku.rag[openai]"
  )
- return QuestionAnswerOpenAIAgent(client, model or Config.QA_MODEL)
+ return QuestionAnswerOpenAIAgent(
+ client, model or Config.QA_MODEL, use_citations
+ )
 
  if Config.QA_PROVIDER == "anthropic":
  try:
@@ -31,6 +37,8 @@ def get_qa_agent(client: HaikuRAG, model: str = "") -> QuestionAnswerAgentBase:
  "Please install haiku.rag with the 'anthropic' extra:"
  "uv pip install haiku.rag[anthropic]"
  )
- return QuestionAnswerAnthropicAgent(client, model or Config.QA_MODEL)
+ return QuestionAnswerAnthropicAgent(
+ client, model or Config.QA_MODEL, use_citations
+ )
 
  raise ValueError(f"Unsupported QA provider: {Config.QA_PROVIDER}")
haiku/rag/qa/anthropic.py CHANGED
@@ -1,19 +1,29 @@
  from collections.abc import Sequence
 
  try:
- from anthropic import AsyncAnthropic
- from anthropic.types import MessageParam, TextBlock, ToolParam, ToolUseBlock
+ from anthropic import AsyncAnthropic # type: ignore
+ from anthropic.types import ( # type: ignore
+ MessageParam,
+ TextBlock,
+ ToolParam,
+ ToolUseBlock,
+ )
 
  from haiku.rag.client import HaikuRAG
  from haiku.rag.qa.base import QuestionAnswerAgentBase
 
  class QuestionAnswerAnthropicAgent(QuestionAnswerAgentBase):
- def __init__(self, client: HaikuRAG, model: str = "claude-3-5-haiku-20241022"):
- super().__init__(client, model or self._model)
+ def __init__(
+ self,
+ client: HaikuRAG,
+ model: str = "claude-3-5-haiku-20241022",
+ use_citations: bool = False,
+ ):
+ super().__init__(client, model or self._model, use_citations)
  self.tools: Sequence[ToolParam] = [
  ToolParam(
  name="search_documents",
- description="Search the knowledge base for relevant documents",
+ description="Search the knowledge base for relevant documents. Returns a JSON array with content, score, and document_uri for each result.",
  input_schema={
  "type": "object",
  "properties": {
@@ -69,18 +79,10 @@ try:
  else 3
  )
 
- search_results = await self._client.search(
+ context = await self._search_and_expand(
  query, limit=limit
  )
 
- context_chunks = []
- for chunk, score in search_results:
- context_chunks.append(
- f"Content: {chunk.content}\nScore: {score:.4f}"
- )
-
- context = "\n\n".join(context_chunks)
-
  tool_results.append(
  {
  "type": "tool_result",
haiku/rag/qa/base.py CHANGED
@@ -1,26 +1,50 @@
+ import json
+
  from haiku.rag.client import HaikuRAG
- from haiku.rag.qa.prompts import SYSTEM_PROMPT
+ from haiku.rag.qa.prompts import SYSTEM_PROMPT, SYSTEM_PROMPT_WITH_CITATIONS
 
 
  class QuestionAnswerAgentBase:
  _model: str = ""
  _system_prompt: str = SYSTEM_PROMPT
 
- def __init__(self, client: HaikuRAG, model: str = ""):
+ def __init__(self, client: HaikuRAG, model: str = "", use_citations: bool = False):
  self._model = model
  self._client = client
+ self._system_prompt = (
+ SYSTEM_PROMPT_WITH_CITATIONS if use_citations else SYSTEM_PROMPT
+ )
 
  async def answer(self, question: str) -> str:
  raise NotImplementedError(
  "QABase is an abstract class. Please implement the answer method in a subclass."
  )
 
+ async def _search_and_expand(self, query: str, limit: int = 3) -> str:
+ """Search for documents and expand context, then format as JSON"""
+ search_results = await self._client.search(query, limit=limit)
+ expanded_results = await self._client.expand_context(search_results)
+ return self._format_search_results(expanded_results)
+
+ def _format_search_results(self, search_results) -> str:
+ """Format search results as JSON list of {content, score, document_uri}"""
+ formatted_results = []
+ for chunk, score in search_results:
+ formatted_results.append(
+ {
+ "content": chunk.content,
+ "score": score,
+ "document_uri": chunk.document_uri,
+ }
+ )
+ return json.dumps(formatted_results, indent=2)
+
  tools = [
  {
  "type": "function",
  "function": {
  "name": "search_documents",
- "description": "Search the knowledge base for relevant documents",
+ "description": "Search the knowledge base for relevant documents. Returns a JSON array of search results.",
  "parameters": {
  "type": "object",
  "properties": {
@@ -36,6 +60,30 @@ class QuestionAnswerAgentBase:
  },
  "required": ["query"],
  },
+ "returns": {
+ "type": "string",
+ "description": "JSON array of search results",
+ "schema": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "content": {
+ "type": "string",
+ "description": "The document text content",
+ },
+ "score": {
+ "type": "number",
+ "description": "Relevance score (higher is more relevant)",
+ },
+ "document_uri": {
+ "type": "string",
+ "description": "Source URI/path of the document",
+ },
+ },
+ },
+ },
+ },
  },
  }
  ]
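
For reference, the JSON that the new `_search_and_expand` helper hands to the model would look roughly like this (values made up for illustration):

```python
# Illustrative output of _format_search_results; the real content, score and
# document_uri come from the (expanded) search results.
import json

results = [
    {
        "content": "haiku.rag is a Retrieval Augmented Generation library built on SQLite...",
        "score": 0.87,
        "document_uri": "/docs/README.md",
    },
    {
        "content": "Configuration is read into AppConfig, including CHUNK_SIZE and QA_MODEL...",
        "score": 0.52,
        "document_uri": "/docs/configuration.md",
    },
]
print(json.dumps(results, indent=2))
```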
haiku/rag/qa/ollama.py CHANGED
@@ -8,8 +8,13 @@ OLLAMA_OPTIONS = {"temperature": 0.0, "seed": 42, "num_ctx": 16384}
 
 
  class QuestionAnswerOllamaAgent(QuestionAnswerAgentBase):
- def __init__(self, client: HaikuRAG, model: str = Config.QA_MODEL):
- super().__init__(client, model or self._model)
+ def __init__(
+ self,
+ client: HaikuRAG,
+ model: str = Config.QA_MODEL,
+ use_citations: bool = False,
+ ):
+ super().__init__(client, model or self._model, use_citations)
 
  async def answer(self, question: str) -> str:
  ollama_client = AsyncClient(host=Config.OLLAMA_BASE_URL)
@@ -39,16 +44,7 @@ class QuestionAnswerOllamaAgent(QuestionAnswerAgentBase):
  query = args.get("query", question)
  limit = int(args.get("limit", 3))
 
- search_results = await self._client.search(query, limit=limit)
-
- context_chunks = []
- for chunk, score in search_results:
- context_chunks.append(
- f"Content: {chunk.content}\nScore: {score:.4f}"
- )
-
- context = "\n\n".join(context_chunks)
-
+ context = await self._search_and_expand(query, limit=limit)
  messages.append(
  {
  "role": "tool",
haiku/rag/qa/openai.py CHANGED
@@ -1,22 +1,29 @@
  from collections.abc import Sequence
 
  try:
- from openai import AsyncOpenAI
- from openai.types.chat import (
+ from openai import AsyncOpenAI # type: ignore
+ from openai.types.chat import ( # type: ignore
  ChatCompletionAssistantMessageParam,
  ChatCompletionMessageParam,
  ChatCompletionSystemMessageParam,
  ChatCompletionToolMessageParam,
  ChatCompletionUserMessageParam,
  )
- from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
+ from openai.types.chat.chat_completion_tool_param import ( # type: ignore
+ ChatCompletionToolParam,
+ )
 
  from haiku.rag.client import HaikuRAG
  from haiku.rag.qa.base import QuestionAnswerAgentBase
 
  class QuestionAnswerOpenAIAgent(QuestionAnswerAgentBase):
- def __init__(self, client: HaikuRAG, model: str = "gpt-4o-mini"):
- super().__init__(client, model or self._model)
+ def __init__(
+ self,
+ client: HaikuRAG,
+ model: str = "gpt-4o-mini",
+ use_citations: bool = False,
+ ):
+ super().__init__(client, model or self._model, use_citations)
  self.tools: Sequence[ChatCompletionToolParam] = [
  ChatCompletionToolParam(tool) for tool in self.tools
  ]
@@ -70,17 +77,7 @@ try:
  query = args.get("query", question)
  limit = int(args.get("limit", 3))
 
- search_results = await self._client.search(
- query, limit=limit
- )
-
- context_chunks = []
- for chunk, score in search_results:
- context_chunks.append(
- f"Content: {chunk.content}\nScore: {score:.4f}"
- )
-
- context = "\n\n".join(context_chunks)
+ context = await self._search_and_expand(query, limit=limit)
 
  messages.append(
  ChatCompletionToolMessageParam(
haiku/rag/qa/prompts.py CHANGED
@@ -19,3 +19,40 @@ Guidelines:
 
  Be concise, and always maintain accuracy over completeness. Prefer short, direct answers that are well-supported by the documents.
  """
+
+ SYSTEM_PROMPT_WITH_CITATIONS = """
+ You are a knowledgeable assistant that helps users find information from a document knowledge base.
+
+ IMPORTANT: You MUST use the search_documents tool for every question. Do not answer any question without first searching the knowledge base.
+
+ Your process:
+ 1. IMMEDIATELY call the search_documents tool with relevant keywords from the user's question
+ 2. Review the search results and their relevance scores
+ 3. If you need additional context, perform follow-up searches with different keywords
+ 4. Provide a short and to the point comprehensive answer based only on the retrieved documents
+ 5. Always include citations for the sources used in your answer
+
+ Guidelines:
+ - Base your answers strictly on the provided document content
+ - If multiple documents contain relevant information, synthesize them coherently
+ - Indicate when information is incomplete or when you need to search for additional context
+ - If the retrieved documents don't contain sufficient information, clearly state: "I cannot find enough information in the knowledge base to answer this question."
+ - For complex questions, consider breaking them down and performing multiple searches
+ - Stick to the answer, do not ellaborate or provide context unless explicitly asked for it.
+ - ALWAYS include citations at the end of your response using the format below
+
+ Citation Format:
+ After your answer, include a "Citations:" section that lists:
+ - The document URI from each search result used
+ - A brief excerpt (first 50-100 characters) of the content that supported your answer
+ - Format: "Citations:\n- [document_uri]: [content_excerpt]..."
+
+ Example response format:
+ [Your answer here]
+
+ Citations:
+ - /path/to/document1.pdf: "This document explains that AFMAN stands for Air Force Manual..."
+ - /path/to/document2.pdf: "The manual provides guidance on military procedures and..."
+
+ Be concise, and always maintain accuracy over completeness. Prefer short, direct answers that are well-supported by the documents.
+ """
haiku/rag/reranking/__init__.py CHANGED
@@ -1,37 +1,40 @@
  from haiku.rag.config import Config
  from haiku.rag.reranking.base import RerankerBase
 
- try:
- from haiku.rag.reranking.cohere import CohereReranker
- except ImportError:
- pass
-
  _reranker: RerankerBase | None = None
 
 
- def get_reranker() -> RerankerBase:
+ def get_reranker() -> RerankerBase | None:
  """
  Factory function to get the appropriate reranker based on the configuration.
+ Returns None if if reranking is disabled.
  """
  global _reranker
  if _reranker is not None:
  return _reranker
+
  if Config.RERANK_PROVIDER == "mxbai":
- from haiku.rag.reranking.mxbai import MxBAIReranker
+ try:
+ from haiku.rag.reranking.mxbai import MxBAIReranker
 
- _reranker = MxBAIReranker()
- return _reranker
+ _reranker = MxBAIReranker()
+ return _reranker
+ except ImportError:
+ return None
 
  if Config.RERANK_PROVIDER == "cohere":
  try:
  from haiku.rag.reranking.cohere import CohereReranker
+
+ _reranker = CohereReranker()
+ return _reranker
  except ImportError:
- raise ImportError(
- "Cohere reranker requires the 'cohere' package. "
- "Please install haiku.rag with the 'cohere' extra:"
- "uv pip install haiku.rag[cohere]"
- )
- _reranker = CohereReranker()
+ return None
+
+ if Config.RERANK_PROVIDER == "ollama":
+ from haiku.rag.reranking.ollama import OllamaReranker
+
+ _reranker = OllamaReranker()
  return _reranker
 
- raise ValueError(f"Unsupported reranker provider: {Config.RERANK_PROVIDER}")
+ return None
haiku/rag/reranking/ollama.py ADDED
@@ -0,0 +1,84 @@
+ import json
+
+ from ollama import AsyncClient
+ from pydantic import BaseModel
+
+ from haiku.rag.config import Config
+ from haiku.rag.reranking.base import RerankerBase
+ from haiku.rag.store.models.chunk import Chunk
+
+ OLLAMA_OPTIONS = {"temperature": 0.0, "seed": 42, "num_ctx": 16384}
+
+
+ class RerankResult(BaseModel):
+ """Individual rerank result with index and relevance score."""
+
+ index: int
+ relevance_score: float
+
+
+ class RerankResponse(BaseModel):
+ """Response from the reranking model containing ranked results."""
+
+ results: list[RerankResult]
+
+
+ class OllamaReranker(RerankerBase):
+ def __init__(self, model: str = Config.RERANK_MODEL):
+ self._model = model
+ self._client = AsyncClient(host=Config.OLLAMA_BASE_URL)
+
+ async def rerank(
+ self, query: str, chunks: list[Chunk], top_n: int = 10
+ ) -> list[tuple[Chunk, float]]:
+ if not chunks:
+ return []
+
+ documents = []
+ for i, chunk in enumerate(chunks):
+ documents.append({"index": i, "content": chunk.content})
+
+ # Create the prompt for reranking
+ system_prompt = """You are a document reranking assistant. Given a query and a list of document chunks, you must rank them by relevance to the query.
+
+ Return your response as a JSON object with a "results" array. Each result should have:
+ - "index": the original index of the document (integer)
+ - "relevance_score": a score between 0.0 and 1.0 indicating relevance (float, where 1.0 is most relevant)
+
+ Only return the top documents up to the requested limit, ordered by decreasing relevance score."""
+
+ documents_text = ""
+ for doc in documents:
+ documents_text += f"Index {doc['index']}: {doc['content']}\n\n"
+
+ user_prompt = f"""Query: {query}
+
+ Documents to rerank:
+ {documents_text.strip()}
+
+ Please rank these documents by relevance to the query and return the top {top_n} results as JSON."""
+
+ messages = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": user_prompt},
+ ]
+
+ try:
+ response = await self._client.chat(
+ model=self._model,
+ messages=messages,
+ format=RerankResponse.model_json_schema(),
+ options=OLLAMA_OPTIONS,
+ )
+
+ content = response["message"]["content"]
+
+ parsed_response = RerankResponse.model_validate(json.loads(content))
+ return [
+ (chunks[result.index], result.relevance_score)
+ for result in parsed_response.results[:top_n]
+ ]
+
+ except Exception:
+ # Fallback: return chunks in original order with same score
+ return [(chunks[i], 1.0) for i in range(min(top_n, len(chunks)))]
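
A hypothetical usage sketch for the new reranker; it assumes a local Ollama server with the configured model pulled, and that `Chunk` can be constructed with just `content`, `document_id` and `metadata` as in the client code above.

```python
# Sketch only: not part of the package.
import asyncio

from haiku.rag.reranking.ollama import OllamaReranker
from haiku.rag.store.models.chunk import Chunk


async def main():
    chunks = [
        Chunk(content="SQLite backs the haiku.rag store.", document_id=1, metadata={"order": 0}),
        Chunk(content="Ollama serves local models for QA and reranking.", document_id=1, metadata={"order": 1}),
    ]
    reranker = OllamaReranker()  # uses Config.RERANK_MODEL and Config.OLLAMA_BASE_URL
    ranked = await reranker.rerank("Which database does haiku.rag use?", chunks, top_n=1)
    for chunk, score in ranked:
        print(f"{score:.2f} {chunk.content}")


asyncio.run(main())
```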
haiku/rag/store/repositories/chunk.py CHANGED
@@ -468,3 +468,49 @@ class ChunkRepository(BaseRepository[Chunk]):
  )
  for chunk_id, document_id, content, metadata_json, document_uri, document_metadata_json in rows
  ]
+
+ async def get_adjacent_chunks(self, chunk: Chunk, num_adjacent: int) -> list[Chunk]:
+ """Get adjacent chunks before and after the given chunk within the same document."""
+ if self.store._connection is None:
+ raise ValueError("Store connection is not available")
+ if chunk.document_id is None:
+ return []
+
+ cursor = self.store._connection.cursor()
+ chunk_order = chunk.metadata.get("order")
+ if chunk_order is None:
+ return []
+
+ # Get adjacent chunks within the same document
+ cursor.execute(
+ """
+ SELECT c.id, c.document_id, c.content, c.metadata, d.uri, d.metadata as document_metadata
+ FROM chunks c
+ JOIN documents d ON c.document_id = d.id
+ WHERE c.document_id = :document_id
+ AND JSON_EXTRACT(c.metadata, '$.order') BETWEEN :start_order AND :end_order
+ AND c.id != :chunk_id
+ ORDER BY JSON_EXTRACT(c.metadata, '$.order')
+ """,
+ {
+ "document_id": chunk.document_id,
+ "start_order": max(0, chunk_order - num_adjacent),
+ "end_order": chunk_order + num_adjacent,
+ "chunk_id": chunk.id,
+ },
+ )
+
+ rows = cursor.fetchall()
+ return [
+ Chunk(
+ id=chunk_id,
+ document_id=document_id,
+ content=content,
+ metadata=json.loads(metadata_json) if metadata_json else {},
+ document_uri=document_uri,
+ document_meta=json.loads(document_metadata_json)
+ if document_metadata_json
+ else {},
+ )
+ for chunk_id, document_id, content, metadata_json, document_uri, document_metadata_json in rows
+ ]
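
The `BETWEEN` clause above selects a symmetric window of chunk orders around each hit (the hit itself is excluded by id and re-added by `expand_context`). A tiny sketch of the window it computes:

```python
# Illustrative helper mirroring the start_order/end_order parameters above.
def adjacent_order_window(chunk_order: int, num_adjacent: int) -> range:
    start = max(0, chunk_order - num_adjacent)
    return range(start, chunk_order + num_adjacent + 1)


# With CONTEXT_CHUNK_RADIUS=2, a hit at order 7 pulls orders 5 through 9.
print(list(adjacent_order_window(7, 2)))  # [5, 6, 7, 8, 9]
```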
haiku_rag-0.5.1.dist-info/METADATA → haiku_rag-0.5.4.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: haiku.rag
- Version: 0.5.1
+ Version: 0.5.4
  Summary: Retrieval Augmented Generation (RAG) with SQLite
  Author-email: Yiorgis Gozadinos <ggozadinos@gmail.com>
  License: MIT
@@ -21,8 +21,7 @@ Requires-Python: >=3.11
  Requires-Dist: docling>=2.15.0
  Requires-Dist: fastmcp>=2.8.1
  Requires-Dist: httpx>=0.28.1
- Requires-Dist: mxbai-rerank>=0.1.6
- Requires-Dist: ollama>=0.5.1
+ Requires-Dist: ollama>=0.5.3
  Requires-Dist: pydantic>=2.11.7
  Requires-Dist: python-dotenv>=1.1.0
  Requires-Dist: rich>=14.0.0
@@ -34,6 +33,8 @@ Provides-Extra: anthropic
  Requires-Dist: anthropic>=0.56.0; extra == 'anthropic'
  Provides-Extra: cohere
  Requires-Dist: cohere>=5.16.1; extra == 'cohere'
+ Provides-Extra: mxbai
+ Requires-Dist: mxbai-rerank>=0.1.6; extra == 'mxbai'
  Provides-Extra: openai
  Requires-Dist: openai>=1.0.0; extra == 'openai'
  Provides-Extra: voyageai
@@ -75,6 +76,9 @@ haiku-rag search "query"
  # Ask questions
  haiku-rag ask "Who is the author of haiku.rag?"
 
+ # Ask questions with citations
+ haiku-rag ask "Who is the author of haiku.rag?" --cite
+
  # Rebuild database (re-chunk and re-embed all documents)
  haiku-rag rebuild
 
@@ -100,6 +104,10 @@ async with HaikuRAG("database.db") as client:
  # Ask questions
  answer = await client.ask("Who is the author of haiku.rag?")
  print(answer)
+
+ # Ask questions with citations
+ answer = await client.ask("Who is the author of haiku.rag?", cite=True)
+ print(answer)
  ```
 
  ## MCP Server
haiku_rag-0.5.1.dist-info/RECORD → haiku_rag-0.5.4.dist-info/RECORD CHANGED
@@ -1,9 +1,9 @@
  haiku/rag/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- haiku/rag/app.py,sha256=FpLVyP1-zAq_XPmU8CPVLkuIAeuhBOGvMqhYS8RbN40,7649
+ haiku/rag/app.py,sha256=k45EOz-rbYg_8RSII3btqsZo2TpGqj3ysamFehhaCGo,7673
  haiku/rag/chunker.py,sha256=PVe6ysv8UlacUd4Zb3_8RFWIaWDXnzBAy2VDJ4TaUsE,1555
- haiku/rag/cli.py,sha256=rk4uUwN_FdMC-rai9_R2sgXXMI3TIWKRtdWWHg_WoWM,5865
- haiku/rag/client.py,sha256=pFcrPkQo1h1zJ76jts-72goP_kGVtnJNfLuoT8qpsb8,15795
- haiku/rag/config.py,sha256=8mlQ8gYFxxq1q9gi9tjY9StjqhfhiHkO1FvS4b0et0E,1633
+ haiku/rag/cli.py,sha256=mGpdnEH8rS-rZLGmE4MbcDci8uexci7UkGTdCxrz1Lg,5987
+ haiku/rag/client.py,sha256=CTc4OJ-rnAI3pcjQgazK7B06wkNLP6wYXD1spQtXXzg,20961
+ haiku/rag/config.py,sha256=oLrmwGp1OjcKPpJFnf9GgTpoBSOXalFWO6PCKFwQe0w,1615
  haiku/rag/logging.py,sha256=zTTGpGq5tPdcd7RpCbd9EGw1IZlQDbYkrCg9t9pqRc4,580
  haiku/rag/mcp.py,sha256=tMN6fNX7ZtAER1R6DL1GkC9HZozTC4HzuQs199p7icI,4551
  haiku/rag/monitor.py,sha256=r386nkhdlsU8UECwIuVwnrSlgMk3vNIuUZGNIzkZuec,2770
@@ -14,16 +14,17 @@ haiku/rag/embeddings/base.py,sha256=NTQvuzbZPu0LBo5wAu3qGyJ4xXUaRAt1fjBO0ygWn_Y,
  haiku/rag/embeddings/ollama.py,sha256=y6-lp0XpbnyIjoOEdtSzMdEVkU5glOwnWQ1FkpUZnpI,370
  haiku/rag/embeddings/openai.py,sha256=i4Ui5hAJkcKqJkH9L3jJo7fuGYHn07td532w-ksg_T8,431
  haiku/rag/embeddings/voyageai.py,sha256=0hiRTIqu-bpl-4OaCtMHvWfPdgbrzhnfZJowSV8pLRA,415
- haiku/rag/qa/__init__.py,sha256=f9ZU7YDzJJoyglV1hGja1j9B6NcWerAImuKO1gFP-qs,1487
- haiku/rag/qa/anthropic.py,sha256=6I6cf6ySNkYbmDFdy22sA8r3GO5moiiH75tJnHcgJQA,4448
- haiku/rag/qa/base.py,sha256=4ZTM_l5FAZ9cA0f8NeqRJiUAmjatwCTmSoclFw0gTFQ,1349
- haiku/rag/qa/ollama.py,sha256=EGUi4urSx9nrnsr5j-qHVDVOnvRTbSMKUbMvXEMIcxM,2381
- haiku/rag/qa/openai.py,sha256=dF32sGgVt8mZi5oVxByaeECs9NqLjvDiZnnpJBsrHm8,3968
- haiku/rag/qa/prompts.py,sha256=8uYMxHzbzI9vo2FPkCSSNTh_RNL96WkBbUWPCMBlLpo,1315
- haiku/rag/reranking/__init__.py,sha256=DsPCdU94wRzDCYl6hz2DySOMWwOvNxKviqKAUfyykK8,1118
+ haiku/rag/qa/__init__.py,sha256=vC9S6cvZtPz-UfA_v4DMwI7eam6567BXNrUwHsMo_i8,1633
+ haiku/rag/qa/anthropic.py,sha256=o0RVn7lcdYvoCUGXh551jeuoB3ANJSZ7uz2R_h_pZ2w,4321
+ haiku/rag/qa/base.py,sha256=dCX14ifJW4QMCNFP_pmss9SYWM9Qm1cSWZrMl6A-2C8,3541
+ haiku/rag/qa/ollama.py,sha256=3T9ciKWpCIY7jejvdrsMC_wIvGRWQEWA0AwKjOlX35M,2131
+ haiku/rag/qa/openai.py,sha256=4BFc8pzFI-CTDxxKMskMxMKkacvUoRTVWI8kKntl3Jw,3718
+ haiku/rag/qa/prompts.py,sha256=WTA66brySfzIkuDZ_hRQQKGx12ngIu9nUDKMNGg2-Bg,3321
+ haiku/rag/reranking/__init__.py,sha256=fwC3pauteJwh9Ulm2270QvwAdwr4NMr4RUEuolC-wKU,1063
  haiku/rag/reranking/base.py,sha256=LM9yUSSJ414UgBZhFTgxGprlRqzfTe4I1vgjricz2JY,405
  haiku/rag/reranking/cohere.py,sha256=1iTdiaa8vvb6oHVB2qpWzUOVkyfUcimVSZp6Qr4aq4c,1049
  haiku/rag/reranking/mxbai.py,sha256=46sVTsTIkzIX9THgM3u8HaEmgY7evvEyB-N54JTHvK8,867
+ haiku/rag/reranking/ollama.py,sha256=tCrLlNNDBCZu7J3to1gvBq-sOvN1flYEA7E3H3Jq0mU,2790
  haiku/rag/store/__init__.py,sha256=hq0W0DAC7ysqhWSP2M2uHX8cbG6kbr-sWHxhq6qQcY0,103
  haiku/rag/store/engine.py,sha256=cOMBToLilI1Di1qQrFzGLqtRMsuvtiX0Q5RNIEzQy9w,6232
  haiku/rag/store/models/__init__.py,sha256=s0E72zneGlowvZrFWaNxHYjOAUjgWdLxzdYsnvNRVlY,88
@@ -31,13 +32,13 @@ haiku/rag/store/models/chunk.py,sha256=9-vIxW75-kMTelIhgVIMd_WhP-Drc1q65vjaWMP8w
  haiku/rag/store/models/document.py,sha256=TVXVY-nQs-1vCORQEs9rA7zOtndeGC4dgCoujLAS054,396
  haiku/rag/store/repositories/__init__.py,sha256=uIBhxjQh-4o3O-ck8b7BQ58qXQTuJdPvrDIHVhY5T1A,263
  haiku/rag/store/repositories/base.py,sha256=cm3VyQXhtxvRfk1uJHpA0fDSxMpYN-mjQmRiDiLsQ68,1008
- haiku/rag/store/repositories/chunk.py,sha256=DIIdpHVemvxZOPHOLBL7pJGWY4VyNrUiQSWPWt24BYo,16974
+ haiku/rag/store/repositories/chunk.py,sha256=R8dvNy3po2FspZvkWKZTGlqccbekLjY39GroXRfAU18,18808
  haiku/rag/store/repositories/document.py,sha256=ki8LiDukwU1469Yw51i0rQFvBzUQeYkFYWs3Ly83akc,8815
  haiku/rag/store/repositories/settings.py,sha256=qZLXvLsErnCWL0nBQQNfRnatHzCKhtUDLvUK9k-W_fU,2463
  haiku/rag/store/upgrades/__init__.py,sha256=kKS1YWT_P-CYKhKtokOLTIFNKf9jlfjFFr8lyIMeogM,100
  haiku/rag/store/upgrades/v0_3_4.py,sha256=GLogKZdZ40NX1vBHKdOJju7fFzNUCHoEnjSZg17Hm2U,663
- haiku_rag-0.5.1.dist-info/METADATA,sha256=X4r-1CBCTef3_T9HWPgCHi5XumqOSF4tlHfUpxO533E,4198
- haiku_rag-0.5.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- haiku_rag-0.5.1.dist-info/entry_points.txt,sha256=G1U3nAkNd5YDYd4v0tuYFbriz0i-JheCsFuT9kIoGCI,48
- haiku_rag-0.5.1.dist-info/licenses/LICENSE,sha256=eXZrWjSk9PwYFNK9yUczl3oPl95Z4V9UXH7bPN46iPo,1065
- haiku_rag-0.5.1.dist-info/RECORD,,
+ haiku_rag-0.5.4.dist-info/METADATA,sha256=hUovrigbcJX6I3vewMVXut3QaI-PXe5BiDzs84noBts,4455
+ haiku_rag-0.5.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ haiku_rag-0.5.4.dist-info/entry_points.txt,sha256=G1U3nAkNd5YDYd4v0tuYFbriz0i-JheCsFuT9kIoGCI,48
+ haiku_rag-0.5.4.dist-info/licenses/LICENSE,sha256=eXZrWjSk9PwYFNK9yUczl3oPl95Z4V9UXH7bPN46iPo,1065
+ haiku_rag-0.5.4.dist-info/RECORD,,