dao-ai 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
@@ -0,0 +1,202 @@
+ """
+ Instruction-aware reranker for constraint-based document reordering.
+
+ Runs after FlashRank to apply user instructions and constraints to the ranking.
+ General-purpose component usable with any retrieval strategy.
+ """
+
+ from pathlib import Path
+ from typing import Any
+
+ import mlflow
+ import yaml
+ from langchain_core.documents import Document
+ from langchain_core.language_models import BaseChatModel
+ from loguru import logger
+ from mlflow.entities import SpanType
+
+ from dao_ai.config import ColumnInfo, RankingResult
+
+ # Load prompt template
+ _PROMPT_PATH = Path(__file__).parent.parent / "prompts" / "instruction_reranker.yaml"
+
+
+ def _load_prompt_template() -> dict[str, Any]:
+     """Load the instruction reranker prompt template from YAML."""
+     with open(_PROMPT_PATH) as f:
+         return yaml.safe_load(f)
+
+
+ def _format_documents(documents: list[Document]) -> str:
+     """Format documents for the reranking prompt."""
+     if not documents:
+         return "No documents to rerank."
+
+     formatted = []
+     for i, doc in enumerate(documents):
+         metadata_str = ", ".join(
+             f"{k}: {v}"
+             for k, v in doc.metadata.items()
+             if not k.startswith("_") and k not in ("rrf_score", "reranker_score")
+         )
+         content_preview = (
+             doc.page_content[:300] + "..."
+             if len(doc.page_content) > 300
+             else doc.page_content
+         )
+         formatted.append(
+             f"[{i}] Content: {content_preview}\n Metadata: {metadata_str}"
+         )
+
+     return "\n\n".join(formatted)
+
+
+ def _format_column_info(columns: list[ColumnInfo] | None) -> str:
+     """Format column info for the reranking prompt."""
+     if not columns:
+         return ""
+     return ", ".join(f"{c.name} ({c.type})" for c in columns)
+
+
+ @mlflow.trace(name="instruction_aware_rerank", span_type=SpanType.LLM)
+ def instruction_aware_rerank(
+     llm: BaseChatModel,
+     query: str,
+     documents: list[Document],
+     instructions: str | None = None,
+     schema_description: str | None = None,
+     columns: list[ColumnInfo] | None = None,
+     top_n: int | None = None,
+ ) -> list[Document]:
+     """
+     Rerank documents based on user instructions and constraints.
+
+     Args:
+         llm: Language model for reranking
+         query: User's search query
+         documents: Documents to rerank (typically FlashRank output)
+         instructions: Custom reranking instructions
+         schema_description: Column names and types for context
+         columns: Structured column info for dynamic instruction generation
+         top_n: Number of documents to return (None = all scored documents)
+
+     Returns:
+         Reranked documents with instruction_rerank_score in metadata
+     """
+     if not documents:
+         return []
+
+     prompt_config = _load_prompt_template()
+     prompt_template = prompt_config["template"]
+
+     # Build dynamic default instructions based on columns
+     if columns:
+         column_names = ", ".join(c.name for c in columns)
+         default_instructions = (
+             f"Prioritize results that best match the user's explicit constraints "
+             f"on these columns: {column_names}. Prefer more specific matches over general results."
+         )
+     else:
+         default_instructions = (
+             "Prioritize results that best match the user's explicit constraints. "
+             "Prefer more specific matches over general results."
+         )
+
+     # Build effective instructions - use columns for context (ignore verbose schema_description)
+     effective_instructions = instructions or default_instructions
+
+     # Add column context if available (simpler than full schema_description)
+     if columns:
+         effective_instructions += (
+             f"\n\nAvailable metadata fields: {_format_column_info(columns)}"
+         )
+
+     prompt = prompt_template.format(
+         query=query,
+         instructions=effective_instructions,
+         documents=_format_documents(documents),
+     )
+
+     logger.trace("Instruction reranking", query=query[:100], num_docs=len(documents))
+
+     logger.debug(
+         "Invoking structured output for reranking",
+         query=query[:50],
+         num_docs=len(documents),
+         prompt_length=len(prompt),
+     )
+
+     try:
+         structured_llm = llm.with_structured_output(RankingResult)
+         result: RankingResult = structured_llm.invoke(prompt)
+         logger.debug(
+             "Structured output succeeded",
+             num_rankings=len(result.rankings),
+         )
+     except Exception as e:
+         logger.warning(
+             "Structured output invocation failed",
+             error=str(e),
+             query=query[:50],
+         )
+         result = None
+     if result is None or not result.rankings:
+         logger.warning(
+             "Failed to get structured output from reranker, returning original order",
+             query=query[:50],
+         )
+         # Return fallback with decreasing scores based on original order
+         return [
+             Document(
+                 page_content=doc.page_content,
+                 metadata={
+                     **doc.metadata,
+                     "instruction_rerank_score": 1.0 - (i / len(documents)),
+                     "instruction_rerank_reason": "fallback: extraction failed",
+                 },
+             )
+             for i, doc in enumerate(documents[:top_n] if top_n else documents)
+         ]
+
+     # Build reranked document list
+     reranked: list[Document] = []
+     for ranking in result.rankings:
+         if ranking.index < 0 or ranking.index >= len(documents):
+             logger.warning("Invalid document index from reranker", index=ranking.index)
+             continue
+
+         original_doc = documents[ranking.index]
+         reranked_doc = Document(
+             page_content=original_doc.page_content,
+             metadata={
+                 **original_doc.metadata,
+                 "instruction_rerank_score": ranking.score,
+                 "instruction_rerank_reason": ranking.reason,
+             },
+         )
+         reranked.append(reranked_doc)
+
+     # Sort by score (highest first) - don't rely on LLM to sort
+     reranked.sort(
+         key=lambda d: d.metadata.get("instruction_rerank_score", 0),
+         reverse=True,
+     )
+
+     # Apply top_n limit after sorting
+     if top_n is not None and len(reranked) > top_n:
+         reranked = reranked[:top_n]
+
+     # Calculate and log average score
+     if reranked:
+         avg_score = sum(
+             d.metadata.get("instruction_rerank_score", 0) for d in reranked
+         ) / len(reranked)
+         mlflow.set_tag("reranker.instruction_avg_score", f"{avg_score:.3f}")
+
+     logger.debug(
+         "Instruction reranking complete",
+         input_count=len(documents),
+         output_count=len(reranked),
+     )
+
+     return reranked
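
For orientation, a minimal usage sketch of the reranker above (not part of the published package). The module path dao_ai.tools.instruction_reranker is assumed, since the diff viewer omits the file name for this hunk, and ChatOpenAI stands in for whichever LangChain chat model the pipeline is actually configured with.

from langchain_core.documents import Document
from langchain_openai import ChatOpenAI

# Assumed module path; the file name is elided in the diff above.
from dao_ai.tools.instruction_reranker import instruction_aware_rerank

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

# Documents as they might arrive from an upstream retriever (e.g., FlashRank output).
docs = [
    Document(page_content="Red running shoe, size 10", metadata={"color": "red", "size": 10}),
    Document(page_content="Blue running shoe, size 9", metadata={"color": "blue", "size": 9}),
]

reranked = instruction_aware_rerank(
    llm=llm,
    query="red shoes in size 10",
    documents=docs,
    instructions="Strongly prefer exact color matches.",
    top_n=5,
)

# Each result carries the score and reason the reranker attached.
for doc in reranked:
    print(doc.metadata["instruction_rerank_score"], doc.page_content)

If the structured-output call fails, the same invocation degrades gracefully: documents come back in their original order with decaying instruction_rerank_score values and a "fallback: extraction failed" reason.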
dao_ai/tools/router.py ADDED
@@ -0,0 +1,89 @@
+ """
+ Query router for selecting execution mode based on query characteristics.
+
+ Routes to internal execution modes within the same retriever instance:
+ - standard: Single similarity_search for simple queries
+ - instructed: Decompose -> Parallel Search -> RRF for constrained queries
+ """
+
+ from pathlib import Path
+ from typing import Any, Literal
+
+ import mlflow
+ import yaml
+ from langchain_core.language_models import BaseChatModel
+ from langchain_core.runnables import Runnable
+ from loguru import logger
+ from mlflow.entities import SpanType
+ from pydantic import BaseModel, ConfigDict, Field
+
+ # Load prompt template
+ _PROMPT_PATH = Path(__file__).parent.parent / "prompts" / "router.yaml"
+
+
+ def _load_prompt_template() -> dict[str, Any]:
+     """Load the router prompt template from YAML."""
+     with open(_PROMPT_PATH) as f:
+         return yaml.safe_load(f)
+
+
+ class RouterDecision(BaseModel):
+     """Classification of a search query into an execution mode.
+
+     Analyze whether the query contains explicit constraints that map to
+     filterable metadata columns, or is a simple semantic search.
+     """
+
+     model_config = ConfigDict(extra="forbid")
+     mode: Literal["standard", "instructed"] = Field(
+         description=(
+             "The execution mode. "
+             "Use 'standard' for simple semantic searches without constraints. "
+             "Use 'instructed' when the query contains explicit constraints "
+             "that can be translated to metadata filters."
+         )
+     )
+
+
+ @mlflow.trace(name="route_query", span_type=SpanType.LLM)
+ def route_query(
+     llm: BaseChatModel,
+     query: str,
+     schema_description: str,
+ ) -> Literal["standard", "instructed"]:
+     """
+     Determine the execution mode for a search query.
+
+     Args:
+         llm: Language model for routing decision
+         query: User's search query
+         schema_description: Column names, types, and filter syntax
+
+     Returns:
+         "standard" for simple queries, "instructed" for constrained queries
+     """
+     prompt_config = _load_prompt_template()
+     prompt_template = prompt_config["template"]
+
+     prompt = prompt_template.format(
+         schema_description=schema_description,
+         query=query,
+     )
+
+     logger.trace("Routing query", query=query[:100])
+
+     # Use LangChain's with_structured_output for automatic strategy selection
+     # (JSON schema vs tool calling based on model capabilities)
+     try:
+         structured_llm: Runnable[str, RouterDecision] = llm.with_structured_output(
+             RouterDecision
+         )
+         decision: RouterDecision = structured_llm.invoke(prompt)
+     except Exception as e:
+         logger.warning("Router failed, defaulting to standard mode", error=str(e))
+         return "standard"
+
+     logger.debug("Router decision", mode=decision.mode, query=query[:50])
+     mlflow.set_tag("router.mode", decision.mode)
+
+     return decision.mode
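
A companion sketch for the router (again, not part of the package): routing a query and branching on the returned mode. The schema string is an invented example, and ChatOpenAI is a stand-in model.

from langchain_openai import ChatOpenAI

from dao_ai.tools.router import route_query

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

# Invented schema description; in practice this comes from the retriever's config.
schema = "product_name (string), color (string), size (int), price (float)"

query = "blue jackets under $100"
mode = route_query(llm=llm, query=query, schema_description=schema)

if mode == "instructed":
    # Constrained query: decompose -> parallel search -> RRF, then rerank.
    ...
else:
    # Simple semantic query: a single similarity_search suffices.
    ...

Because any failure in the structured call falls back to "standard", the router can degrade but never block retrieval.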