dao-ai 0.1.2__py3-none-any.whl → 0.1.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. dao_ai/apps/__init__.py +24 -0
  2. dao_ai/apps/handlers.py +105 -0
  3. dao_ai/apps/model_serving.py +29 -0
  4. dao_ai/apps/resources.py +1122 -0
  5. dao_ai/apps/server.py +39 -0
  6. dao_ai/cli.py +546 -37
  7. dao_ai/config.py +1179 -139
  8. dao_ai/evaluation.py +543 -0
  9. dao_ai/genie/__init__.py +55 -7
  10. dao_ai/genie/cache/__init__.py +34 -7
  11. dao_ai/genie/cache/base.py +143 -2
  12. dao_ai/genie/cache/context_aware/__init__.py +31 -0
  13. dao_ai/genie/cache/context_aware/base.py +1151 -0
  14. dao_ai/genie/cache/context_aware/in_memory.py +609 -0
  15. dao_ai/genie/cache/context_aware/persistent.py +802 -0
  16. dao_ai/genie/cache/context_aware/postgres.py +1166 -0
  17. dao_ai/genie/cache/core.py +1 -1
  18. dao_ai/genie/cache/lru.py +257 -75
  19. dao_ai/genie/cache/optimization.py +890 -0
  20. dao_ai/genie/core.py +235 -11
  21. dao_ai/memory/postgres.py +175 -39
  22. dao_ai/middleware/__init__.py +38 -0
  23. dao_ai/middleware/assertions.py +3 -3
  24. dao_ai/middleware/context_editing.py +230 -0
  25. dao_ai/middleware/core.py +4 -4
  26. dao_ai/middleware/guardrails.py +3 -3
  27. dao_ai/middleware/human_in_the_loop.py +3 -2
  28. dao_ai/middleware/message_validation.py +4 -4
  29. dao_ai/middleware/model_call_limit.py +77 -0
  30. dao_ai/middleware/model_retry.py +121 -0
  31. dao_ai/middleware/pii.py +157 -0
  32. dao_ai/middleware/summarization.py +1 -1
  33. dao_ai/middleware/tool_call_limit.py +210 -0
  34. dao_ai/middleware/tool_retry.py +174 -0
  35. dao_ai/middleware/tool_selector.py +129 -0
  36. dao_ai/models.py +327 -370
  37. dao_ai/nodes.py +9 -16
  38. dao_ai/orchestration/core.py +33 -9
  39. dao_ai/orchestration/supervisor.py +29 -13
  40. dao_ai/orchestration/swarm.py +6 -1
  41. dao_ai/{prompts.py → prompts/__init__.py} +12 -61
  42. dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
  43. dao_ai/prompts/instruction_reranker.yaml +14 -0
  44. dao_ai/prompts/router.yaml +37 -0
  45. dao_ai/prompts/verifier.yaml +46 -0
  46. dao_ai/providers/base.py +28 -2
  47. dao_ai/providers/databricks.py +363 -33
  48. dao_ai/state.py +1 -0
  49. dao_ai/tools/__init__.py +5 -3
  50. dao_ai/tools/genie.py +103 -26
  51. dao_ai/tools/instructed_retriever.py +366 -0
  52. dao_ai/tools/instruction_reranker.py +202 -0
  53. dao_ai/tools/mcp.py +539 -97
  54. dao_ai/tools/router.py +89 -0
  55. dao_ai/tools/slack.py +13 -2
  56. dao_ai/tools/sql.py +7 -3
  57. dao_ai/tools/unity_catalog.py +32 -10
  58. dao_ai/tools/vector_search.py +493 -160
  59. dao_ai/tools/verifier.py +159 -0
  60. dao_ai/utils.py +182 -2
  61. dao_ai/vector_search.py +46 -1
  62. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/METADATA +45 -9
  63. dao_ai-0.1.20.dist-info/RECORD +89 -0
  64. dao_ai/agent_as_code.py +0 -22
  65. dao_ai/genie/cache/semantic.py +0 -970
  66. dao_ai-0.1.2.dist-info/RECORD +0 -64
  67. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/WHEEL +0 -0
  68. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/entry_points.txt +0 -0
  69. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/licenses/LICENSE +0 -0
dao_ai/tools/genie.py CHANGED
@@ -6,7 +6,7 @@ interact with Databricks Genie.
 
  For the core Genie service and cache implementations, see:
  - dao_ai.genie: GenieService, GenieServiceBase
- - dao_ai.genie.cache: LRUCacheService, SemanticCacheService
+ - dao_ai.genie.cache: LRUCacheService, PostgresContextAwareGenieService, InMemoryContextAwareGenieService
  """
 
  import json
@@ -25,13 +25,19 @@ from pydantic import BaseModel
  from dao_ai.config import (
      AnyVariable,
      CompositeVariableModel,
+     GenieContextAwareCacheParametersModel,
+     GenieInMemorySemanticCacheParametersModel,
      GenieLRUCacheParametersModel,
      GenieRoomModel,
-     GenieSemanticCacheParametersModel,
      value_of,
  )
  from dao_ai.genie import GenieService, GenieServiceBase
- from dao_ai.genie.cache import CacheResult, LRUCacheService, SemanticCacheService
+ from dao_ai.genie.cache import (
+     CacheResult,
+     InMemoryContextAwareGenieService,
+     LRUCacheService,
+     PostgresContextAwareGenieService,
+ )
  from dao_ai.state import AgentState, Context, SessionState
 
 
@@ -64,7 +70,10 @@ def create_genie_tool(
      persist_conversation: bool = True,
      truncate_results: bool = False,
      lru_cache_parameters: GenieLRUCacheParametersModel | dict[str, Any] | None = None,
-     semantic_cache_parameters: GenieSemanticCacheParametersModel
+     semantic_cache_parameters: GenieContextAwareCacheParametersModel
+     | dict[str, Any]
+     | None = None,
+     in_memory_semantic_cache_parameters: GenieInMemorySemanticCacheParametersModel
      | dict[str, Any]
      | None = None,
  ) -> Callable[..., Command]:
@@ -84,7 +93,9 @@
          truncate_results: Whether to truncate large query results to fit token limits
          lru_cache_parameters: Optional LRU cache configuration for SQL query caching
          semantic_cache_parameters: Optional semantic cache configuration using pg_vector
-             for similarity-based query matching
+             for similarity-based query matching (requires PostgreSQL/Lakebase)
+         in_memory_semantic_cache_parameters: Optional in-memory semantic cache configuration
+             for similarity-based query matching (no database required)
 
      Returns:
          A LangGraph tool that processes natural language queries through Genie
@@ -97,6 +108,7 @@
          name=name,
          has_lru_cache=lru_cache_parameters is not None,
          has_semantic_cache=semantic_cache_parameters is not None,
+         has_in_memory_semantic_cache=in_memory_semantic_cache_parameters is not None,
      )
 
      if isinstance(genie_room, dict):
@@ -106,10 +118,15 @@
          lru_cache_parameters = GenieLRUCacheParametersModel(**lru_cache_parameters)
 
      if isinstance(semantic_cache_parameters, dict):
-         semantic_cache_parameters = GenieContextAwareCacheParametersModel(
+         semantic_cache_parameters = GenieContextAwareCacheParametersModel(
              **semantic_cache_parameters
          )
 
+     if isinstance(in_memory_semantic_cache_parameters, dict):
+         in_memory_semantic_cache_parameters = GenieInMemorySemanticCacheParametersModel(
+             **in_memory_semantic_cache_parameters
+         )
+
      space_id: AnyVariable = genie_room.space_id or os.environ.get(
          "DATABRICKS_GENIE_SPACE_ID"
      )
@@ -139,29 +156,61 @@ Returns:
      GenieResponse: A response object containing the conversation ID and result from Genie."""
      tool_description = tool_description + function_docs
 
-     genie: Genie = Genie(
-         space_id=space_id,
-         client=genie_room.workspace_client,
-         truncate_results=truncate_results,
-     )
+     # Cache for genie service - created lazily on first call
+     # This allows us to use workspace_client_from with runtime context for OBO
+     _cached_genie_service: GenieServiceBase | None = None
+
+     def _get_genie_service(context: Context | None) -> GenieServiceBase:
+         """Get or create the Genie service, using context for OBO auth if available."""
+         nonlocal _cached_genie_service
+
+         # Use cached service if available (for non-OBO or after first call)
+         # For OBO, we need fresh workspace client each time to use the user's token
+         if _cached_genie_service is not None and not genie_room.on_behalf_of_user:
+             return _cached_genie_service
 
-     genie_service: GenieServiceBase = GenieService(genie)
-
-     # Wrap with semantic cache first (checked second due to decorator pattern)
-     if semantic_cache_parameters is not None:
-         genie_service = SemanticCacheService(
-             impl=genie_service,
-             parameters=semantic_cache_parameters,
-             workspace_client=genie_room.workspace_client,  # Pass workspace client for conversation history
-         ).initialize()  # Eagerly initialize to fail fast and create table
-
-     # Wrap with LRU cache last (checked first - fast O(1) exact match)
-     if lru_cache_parameters is not None:
-         genie_service = LRUCacheService(
-             impl=genie_service,
-             parameters=lru_cache_parameters,
+         # Get workspace client using context for OBO support
+         from databricks.sdk import WorkspaceClient
+
+         workspace_client: WorkspaceClient = genie_room.workspace_client_from(context)
+
+         genie: Genie = Genie(
+             space_id=space_id,
+             client=workspace_client,
+             truncate_results=truncate_results,
          )
 
+         genie_service: GenieServiceBase = GenieService(genie)
+
+         # Wrap with context-aware cache first (checked second/third due to decorator pattern)
+         if semantic_cache_parameters is not None:
+             genie_service = PostgresContextAwareGenieService(
+                 impl=genie_service,
+                 parameters=semantic_cache_parameters,
+                 workspace_client=workspace_client,
+             ).initialize()
+
+         # Wrap with in-memory context-aware cache (alternative to PostgreSQL context-aware cache)
+         if in_memory_semantic_cache_parameters is not None:
+             genie_service = InMemoryContextAwareGenieService(
+                 impl=genie_service,
+                 parameters=in_memory_semantic_cache_parameters,
+                 workspace_client=workspace_client,
+             ).initialize()
+
+         # Wrap with LRU cache last (checked first - fast O(1) exact match)
+         if lru_cache_parameters is not None:
+             genie_service = LRUCacheService(
+                 impl=genie_service,
+                 parameters=lru_cache_parameters,
+             )
+
+         # Cache for non-OBO scenarios
+         if not genie_room.on_behalf_of_user:
+             _cached_genie_service = genie_service
+
+         return genie_service
+
      @tool(
          name_or_callable=tool_name,
          description=tool_description,
@@ -177,6 +226,10 @@ GenieResponse: A response object containing the conversation ID and result from
          # Access state through runtime
          state: AgentState = runtime.state
          tool_call_id: str = runtime.tool_call_id
+         context: Context | None = runtime.context
+
+         # Get genie service with OBO support via context
+         genie_service: GenieServiceBase = _get_genie_service(context)
 
          # Ensure space_id is a string for state keys
          space_id_str: str = str(space_id)
@@ -194,6 +247,14 @@ GenieResponse: A response object containing the conversation ID and result from
              conversation_id=existing_conversation_id,
          )
 
+         # Log the prompt being sent to Genie
+         logger.trace(
+             "Sending prompt to Genie",
+             space_id=space_id_str,
+             conversation_id=existing_conversation_id,
+             prompt=question[:500] + "..." if len(question) > 500 else question,
+         )
+
          # Call ask_question which always returns CacheResult with cache metadata
          cache_result: CacheResult = genie_service.ask_question(
              question, conversation_id=existing_conversation_id
@@ -211,6 +272,22 @@ GenieResponse: A response object containing the conversation ID and result from
              cache_key=cache_key,
          )
 
+         # Log truncated response for debugging
+         result_preview: str = str(genie_response.result)
+         if len(result_preview) > 500:
+             result_preview = result_preview[:500] + "..."
+         logger.trace(
+             "Genie response content",
+             question=question[:100] + "..." if len(question) > 100 else question,
+             query=genie_response.query,
+             description=(
+                 genie_response.description[:200] + "..."
+                 if genie_response.description and len(genie_response.description) > 200
+                 else genie_response.description
+             ),
+             result_preview=result_preview,
+         )
+
          # Update session state with cache information
          if persist_conversation:
              session.genie.update_space(
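
The refactor in this file replaces the single SemanticCacheService wrapper with a stack of decorator-style cache services around GenieService, where the layer applied last is consulted first. Below is a minimal, self-contained sketch of that wrapping order; the class names and bodies are illustrative stand-ins, not the dao_ai implementations.

class BaseGenie:
    """Stand-in for GenieService: always computes a fresh answer."""

    def ask(self, question: str) -> str:
        return f"genie answer for: {question}"


class SemanticLayer:
    """Stand-in for the context-aware/semantic cache: similarity lookup, then delegate."""

    def __init__(self, impl) -> None:
        self.impl = impl

    def ask(self, question: str) -> str:
        # A real implementation would check semantically similar prior questions here.
        return self.impl.ask(question)


class LRULayer:
    """Stand-in for LRUCacheService: O(1) exact-match lookup, checked first."""

    def __init__(self, impl) -> None:
        self.impl = impl
        self.exact: dict[str, str] = {}

    def ask(self, question: str) -> str:
        if question in self.exact:
            return self.exact[question]
        answer = self.impl.ask(question)  # falls through to the semantic layer, then Genie
        self.exact[question] = answer
        return answer


# Wrapped inside-out (Genie -> semantic -> LRU), so lookups run LRU -> semantic -> Genie.
service = LRULayer(SemanticLayer(BaseGenie()))
print(service.ask("total sales last week"))
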
dao_ai/tools/instructed_retriever.py ADDED
@@ -0,0 +1,366 @@
+ """
+ Instructed retriever for query decomposition and result fusion.
+
+ This module provides functions for decomposing user queries into multiple
+ subqueries with metadata filters and merging results using Reciprocal Rank Fusion.
+ """
+
+ import json
+ from datetime import datetime
+ from pathlib import Path
+ from typing import Any, Optional, Union
+
+ import mlflow
+ import yaml
+ from langchain_core.documents import Document
+ from langchain_core.language_models import BaseChatModel
+ from langchain_core.runnables import Runnable
+ from loguru import logger
+ from mlflow.entities import SpanType
+ from pydantic import BaseModel, ConfigDict, Field
+
+ from dao_ai.config import (
+     ColumnInfo,
+     DecomposedQueries,
+     FilterItem,
+     LLMModel,
+     SearchQuery,
+ )
+
+ # Module-level cache for LLM clients
+ _llm_cache: dict[str, BaseChatModel] = {}
+
+ # Load prompt template
+ _PROMPT_PATH = (
+     Path(__file__).parent.parent / "prompts" / "instructed_retriever_decomposition.yaml"
+ )
+
+
+ def _load_prompt_template() -> dict[str, Any]:
+     """Load the decomposition prompt template from YAML."""
+     with open(_PROMPT_PATH) as f:
+         return yaml.safe_load(f)
+
+
+ def _get_cached_llm(model_config: LLMModel) -> BaseChatModel:
+     """
+     Get or create cached LLM client for decomposition.
+
+     Uses full config as cache key to avoid collisions when same model name
+     has different parameters (temperature, API keys, etc.).
+     """
+     cache_key = model_config.model_dump_json()
+     if cache_key not in _llm_cache:
+         _llm_cache[cache_key] = model_config.as_chat_model()
+         logger.debug(
+             "Created new LLM client for decomposition", model=model_config.name
+         )
+     return _llm_cache[cache_key]
+
+
+ def _format_constraints(constraints: list[str] | None) -> str:
+     """Format constraints list for prompt injection."""
+     if not constraints:
+         return "No additional constraints."
+     return "\n".join(f"- {c}" for c in constraints)
+
+
+ def _format_examples(examples: list[dict[str, Any]] | None) -> str:
+     """Format few-shot examples for prompt injection.
+
+     Converts dict-style filters from config to FilterItem array format
+     to match the expected JSON schema output.
+     """
+     if not examples:
+         return "No examples provided."
+
+     formatted = []
+     for i, ex in enumerate(examples, 1):
+         query = ex.get("query", "")
+         filters = ex.get("filters", {})
+         # Convert dict to FilterItem array format
+         filter_items = [{"key": k, "value": v} for k, v in filters.items()]
+         formatted.append(
+             f'Example {i}:\n Query: "{query}"\n Filters: {json.dumps(filter_items)}'
+         )
+     return "\n".join(formatted)
+
+
+ def create_decomposition_schema(
+     columns: list[ColumnInfo] | None = None,
+ ) -> type[BaseModel]:
+     """Create schema-aware DecomposedQueries model with dynamic descriptions.
+
+     When columns are provided, the column names and valid operators are embedded
+     directly into the JSON schema that with_structured_output sends to the LLM.
+     This improves accuracy by making valid filter keys explicit in the schema.
+
+     Args:
+         columns: List of column metadata for dynamic schema generation
+
+     Returns:
+         A DecomposedQueries-compatible Pydantic model class
+     """
+     if not columns:
+         # Fall back to generic models
+         return DecomposedQueries
+
+     # Build column info with types for the schema description
+     column_info = ", ".join(f"{c.name} ({c.type})" for c in columns)
+
+     # Build operator list from column definitions (union of all column operators)
+     all_operators: set[str] = set()
+     for col in columns:
+         all_operators.update(col.operators)
+     # Remove empty string (equality) and sort for consistent output
+     named_operators = sorted(all_operators - {""})
+     operator_list = ", ".join(named_operators) if named_operators else "equality only"
+
+     # Build valid key examples with operators
+     key_examples: list[str] = []
+     for col in columns[:3]:  # Show examples for first 3 columns
+         key_examples.append(f"'{col.name}'")
+         if "<" in col.operators:
+             key_examples.append(f"'{col.name} <'")
+         if "NOT" in col.operators:
+             key_examples.append(f"'{col.name} NOT'")
+
+     # Create dynamic FilterItem with schema-aware description
+     class SchemaFilterItem(BaseModel):
+         """A metadata filter for vector search with schema-specific columns."""
+
+         model_config = ConfigDict(extra="forbid")
+         key: str = Field(
+             description=(
+                 f"Column name with optional operator suffix. "
+                 f"Valid columns: {column_info}. "
+                 f"Operators: (none) for equality, {operator_list}. "
+                 f"Examples: {', '.join(key_examples[:5])}"
+             )
+         )
+         value: Union[str, int, float, bool, list[Union[str, int, float, bool]]] = Field(
+             description="The filter value matching the column type."
+         )
+
+     # Create dynamic SearchQuery using SchemaFilterItem
+     class SchemaSearchQuery(BaseModel):
+         """A search query with schema-aware filters."""
+
+         model_config = ConfigDict(extra="forbid")
+         text: str = Field(
+             description=(
+                 "Natural language search query text optimized for semantic similarity. "
+                 "Should be focused on a single search intent. "
+                 "Do NOT include filter criteria in the text; use the filters field instead."
+             )
+         )
+         filters: Optional[list[SchemaFilterItem]] = Field(
+             default=None,
+             description=(
+                 f"Metadata filters to constrain search results. "
+                 f"Valid filter columns: {column_info}. "
+                 f"Set to null if no filters apply."
+             ),
+         )
+
+     # Create dynamic DecomposedQueries using SchemaSearchQuery
+     class SchemaDecomposedQueries(BaseModel):
+         """Decomposed search queries with schema-aware filters."""
+
+         model_config = ConfigDict(extra="forbid")
+         queries: list[SchemaSearchQuery] = Field(
+             description=(
+                 "List of search queries extracted from the user request. "
+                 "Each query should target a distinct search intent. "
+                 "Order queries by importance, with the most relevant first."
+             )
+         )
+
+     return SchemaDecomposedQueries
+
+
+ @mlflow.trace(name="decompose_query", span_type=SpanType.LLM)
+ def decompose_query(
+     llm: BaseChatModel,
+     query: str,
+     schema_description: str,
+     constraints: list[str] | None = None,
+     max_subqueries: int = 3,
+     examples: list[dict[str, Any]] | None = None,
+     previous_feedback: str | None = None,
+     columns: list[ColumnInfo] | None = None,
+ ) -> list[SearchQuery]:
+     """
+     Decompose a user query into multiple search queries with filters.
+
+     Uses structured output for reliable parsing and injects current time
+     for resolving relative date references. When columns are provided,
+     schema-aware Pydantic models are used for improved filter accuracy.
+
+     Args:
+         llm: Language model for decomposition
+         query: User's search query
+         schema_description: Column names, types, and valid filter syntax
+         constraints: Default constraints to apply
+         max_subqueries: Maximum number of subqueries to generate
+         examples: Few-shot examples for domain-specific filter translation
+         previous_feedback: Feedback from failed verification (for retry)
+         columns: Structured column info for dynamic schema generation
+
+     Returns:
+         List of SearchQuery objects with text and optional filters
+     """
+     current_time = datetime.now().isoformat()
+
+     # Load and format prompt
+     prompt_config = _load_prompt_template()
+     prompt_template = prompt_config["template"]
+
+     # Add previous feedback section if provided (for retry)
+     feedback_section = ""
+     if previous_feedback:
+         feedback_section = f"\n\n## Previous Attempt Feedback\nThe previous search attempt failed verification: {previous_feedback}\nAdjust your filters to address this feedback."
+
+     prompt = (
+         prompt_template.format(
+             current_time=current_time,
+             schema_description=schema_description,
+             constraints=_format_constraints(constraints),
+             examples=_format_examples(examples),
+             max_subqueries=max_subqueries,
+             query=query,
+         )
+         + feedback_section
+     )
+
+     logger.trace(
+         "Decomposing query",
+         query=query[:100],
+         max_subqueries=max_subqueries,
+         dynamic_schema=columns is not None,
+     )
+
+     # Create schema-aware model when columns are provided
+     DecompositionSchema: type[BaseModel] = create_decomposition_schema(columns)
+
+     # Use LangChain's with_structured_output for automatic strategy selection
+     # (JSON schema vs tool calling based on model capabilities)
+     try:
+         structured_llm: Runnable[str, BaseModel] = llm.with_structured_output(
+             DecompositionSchema
+         )
+         result: BaseModel = structured_llm.invoke(prompt)
+     except Exception as e:
+         logger.warning("Query decomposition failed", error=str(e))
+         raise
+
+     # Extract queries from result (works with both static and dynamic schemas)
+     subqueries: list[SearchQuery] = []
+     for query_obj in result.queries[:max_subqueries]:
+         # Convert dynamic schema objects to SearchQuery for consistent return type
+         filters: list[FilterItem] | None = None
+         if query_obj.filters:
+             filters = [FilterItem(key=f.key, value=f.value) for f in query_obj.filters]
+         subqueries.append(SearchQuery(text=query_obj.text, filters=filters))
+
+     # Log for observability
+     mlflow.set_tag("num_subqueries", len(subqueries))
+     mlflow.log_text(
+         json.dumps([sq.model_dump() for sq in subqueries], indent=2),
+         "decomposition.json",
+     )
+
+     logger.debug(
+         "Query decomposed",
+         num_subqueries=len(subqueries),
+         queries=[sq.text[:50] for sq in subqueries],
+     )
+
+     return subqueries
+
+
+ def rrf_merge(
+     results_lists: list[list[Document]],
+     k: int = 60,
+     primary_key: str | None = None,
+ ) -> list[Document]:
+     """
+     Merge results from multiple queries using Reciprocal Rank Fusion.
+
+     RRF is safer than raw score sorting because Databricks Vector Search
+     scores aren't normalized across query types (HYBRID vs ANN).
+
+     RRF Score = Σ 1 / (k + rank_i) for each result list
+
+     Args:
+         results_lists: List of document lists from different subqueries
+         k: RRF constant (lower values weight top ranks more heavily)
+         primary_key: Metadata key for document identity (for deduplication)
+
+     Returns:
+         Merged and deduplicated documents sorted by RRF score
+     """
+     if not results_lists:
+         return []
+
+     # Filter empty lists first
+     non_empty = [r for r in results_lists if r]
+     if not non_empty:
+         return []
+
+     # Single list optimization (still add RRF scores for consistency)
+     if len(non_empty) == 1:
+         docs_with_scores: list[Document] = []
+         for rank, doc in enumerate(non_empty[0]):
+             rrf_score = 1.0 / (k + rank + 1)
+             docs_with_scores.append(
+                 Document(
+                     page_content=doc.page_content,
+                     metadata={**doc.metadata, "rrf_score": rrf_score},
+                 )
+             )
+         return docs_with_scores
+
+     # Calculate RRF scores
+     # Key: document identifier, Value: (total_rrf_score, Document)
+     doc_scores: dict[str, tuple[float, Document]] = {}
+
+     def get_doc_id(doc: Document) -> str:
+         """Get unique identifier for document."""
+         if primary_key and primary_key in doc.metadata:
+             return str(doc.metadata[primary_key])
+         # Fallback to content hash
+         return str(hash(doc.page_content))
+
+     for result_list in non_empty:
+         for rank, doc in enumerate(result_list):
+             doc_id = get_doc_id(doc)
+             rrf_score = 1.0 / (k + rank + 1)  # rank is 0-indexed
+
+             if doc_id in doc_scores:
+                 # Accumulate RRF score for duplicates
+                 existing_score, existing_doc = doc_scores[doc_id]
+                 doc_scores[doc_id] = (existing_score + rrf_score, existing_doc)
+             else:
+                 doc_scores[doc_id] = (rrf_score, doc)
+
+     # Sort by RRF score descending
+     sorted_docs = sorted(doc_scores.values(), key=lambda x: x[0], reverse=True)
+
+     # Add RRF score to metadata
+     merged_docs: list[Document] = []
+     for rrf_score, doc in sorted_docs:
+         merged_doc = Document(
+             page_content=doc.page_content,
+             metadata={**doc.metadata, "rrf_score": rrf_score},
+         )
+         merged_docs.append(merged_doc)
+
+     logger.debug(
+         "RRF merge complete",
+         input_lists=len(results_lists),
+         total_docs=sum(len(r) for r in results_lists),
+         unique_docs=len(merged_docs),
+     )
+
+     return merged_docs
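
To make the RRF formula in the rrf_merge docstring concrete, the sketch below feeds two hypothetical subquery result lists through the function added above; the document contents and the "id" primary key are made up for illustration.

from langchain_core.documents import Document

from dao_ai.tools.instructed_retriever import rrf_merge

# Both lists return document "a": rank 0 in the first list, rank 1 in the second.
list_one = [
    Document(page_content="doc a", metadata={"id": "a"}),
    Document(page_content="doc b", metadata={"id": "b"}),
]
list_two = [
    Document(page_content="doc c", metadata={"id": "c"}),
    Document(page_content="doc a", metadata={"id": "a"}),
]

merged = rrf_merge([list_one, list_two], k=60, primary_key="id")

# With k=60 and 0-indexed ranks, "a" accumulates 1/61 + 1/62 ≈ 0.0325, while "c" gets
# a single 1/61 and "b" a single 1/62, so "a" ranks first in the merged output.
for doc in merged:
    print(doc.metadata["id"], round(doc.metadata["rrf_score"], 4))
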