dao-ai 0.1.5__py3-none-any.whl → 0.1.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. dao_ai/apps/__init__.py +24 -0
  2. dao_ai/apps/handlers.py +105 -0
  3. dao_ai/apps/model_serving.py +29 -0
  4. dao_ai/apps/resources.py +1122 -0
  5. dao_ai/apps/server.py +39 -0
  6. dao_ai/cli.py +446 -16
  7. dao_ai/config.py +1034 -103
  8. dao_ai/evaluation.py +543 -0
  9. dao_ai/genie/__init__.py +55 -7
  10. dao_ai/genie/cache/__init__.py +34 -7
  11. dao_ai/genie/cache/base.py +143 -2
  12. dao_ai/genie/cache/context_aware/__init__.py +31 -0
  13. dao_ai/genie/cache/context_aware/base.py +1151 -0
  14. dao_ai/genie/cache/context_aware/in_memory.py +609 -0
  15. dao_ai/genie/cache/context_aware/persistent.py +802 -0
  16. dao_ai/genie/cache/context_aware/postgres.py +1166 -0
  17. dao_ai/genie/cache/core.py +1 -1
  18. dao_ai/genie/cache/lru.py +257 -75
  19. dao_ai/genie/cache/optimization.py +890 -0
  20. dao_ai/genie/core.py +235 -11
  21. dao_ai/memory/postgres.py +175 -39
  22. dao_ai/middleware/__init__.py +5 -0
  23. dao_ai/middleware/tool_selector.py +129 -0
  24. dao_ai/models.py +327 -370
  25. dao_ai/nodes.py +4 -4
  26. dao_ai/orchestration/core.py +33 -9
  27. dao_ai/orchestration/supervisor.py +23 -8
  28. dao_ai/orchestration/swarm.py +6 -1
  29. dao_ai/{prompts.py → prompts/__init__.py} +12 -61
  30. dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
  31. dao_ai/prompts/instruction_reranker.yaml +14 -0
  32. dao_ai/prompts/router.yaml +37 -0
  33. dao_ai/prompts/verifier.yaml +46 -0
  34. dao_ai/providers/base.py +28 -2
  35. dao_ai/providers/databricks.py +352 -33
  36. dao_ai/state.py +1 -0
  37. dao_ai/tools/__init__.py +5 -3
  38. dao_ai/tools/genie.py +103 -26
  39. dao_ai/tools/instructed_retriever.py +366 -0
  40. dao_ai/tools/instruction_reranker.py +202 -0
  41. dao_ai/tools/mcp.py +539 -97
  42. dao_ai/tools/router.py +89 -0
  43. dao_ai/tools/slack.py +13 -2
  44. dao_ai/tools/sql.py +7 -3
  45. dao_ai/tools/unity_catalog.py +32 -10
  46. dao_ai/tools/vector_search.py +493 -160
  47. dao_ai/tools/verifier.py +159 -0
  48. dao_ai/utils.py +182 -2
  49. dao_ai/vector_search.py +9 -1
  50. {dao_ai-0.1.5.dist-info → dao_ai-0.1.20.dist-info}/METADATA +10 -8
  51. dao_ai-0.1.20.dist-info/RECORD +89 -0
  52. dao_ai/agent_as_code.py +0 -22
  53. dao_ai/genie/cache/semantic.py +0 -970
  54. dao_ai-0.1.5.dist-info/RECORD +0 -70
  55. {dao_ai-0.1.5.dist-info → dao_ai-0.1.20.dist-info}/WHEEL +0 -0
  56. {dao_ai-0.1.5.dist-info → dao_ai-0.1.20.dist-info}/entry_points.txt +0 -0
  57. {dao_ai-0.1.5.dist-info → dao_ai-0.1.20.dist-info}/licenses/LICENSE +0 -0
dao_ai/tools/router.py ADDED
@@ -0,0 +1,89 @@
+"""
+Query router for selecting execution mode based on query characteristics.
+
+Routes to internal execution modes within the same retriever instance:
+- standard: Single similarity_search for simple queries
+- instructed: Decompose -> Parallel Search -> RRF for constrained queries
+"""
+
+from pathlib import Path
+from typing import Any, Literal
+
+import mlflow
+import yaml
+from langchain_core.language_models import BaseChatModel
+from langchain_core.runnables import Runnable
+from loguru import logger
+from mlflow.entities import SpanType
+from pydantic import BaseModel, ConfigDict, Field
+
+# Load prompt template
+_PROMPT_PATH = Path(__file__).parent.parent / "prompts" / "router.yaml"
+
+
+def _load_prompt_template() -> dict[str, Any]:
+    """Load the router prompt template from YAML."""
+    with open(_PROMPT_PATH) as f:
+        return yaml.safe_load(f)
+
+
+class RouterDecision(BaseModel):
+    """Classification of a search query into an execution mode.
+
+    Analyze whether the query contains explicit constraints that map to
+    filterable metadata columns, or is a simple semantic search.
+    """
+
+    model_config = ConfigDict(extra="forbid")
+    mode: Literal["standard", "instructed"] = Field(
+        description=(
+            "The execution mode. "
+            "Use 'standard' for simple semantic searches without constraints. "
+            "Use 'instructed' when the query contains explicit constraints "
+            "that can be translated to metadata filters."
+        )
+    )
+
+
+@mlflow.trace(name="route_query", span_type=SpanType.LLM)
+def route_query(
+    llm: BaseChatModel,
+    query: str,
+    schema_description: str,
+) -> Literal["standard", "instructed"]:
+    """
+    Determine the execution mode for a search query.
+
+    Args:
+        llm: Language model for routing decision
+        query: User's search query
+        schema_description: Column names, types, and filter syntax
+
+    Returns:
+        "standard" for simple queries, "instructed" for constrained queries
+    """
+    prompt_config = _load_prompt_template()
+    prompt_template = prompt_config["template"]
+
+    prompt = prompt_template.format(
+        schema_description=schema_description,
+        query=query,
+    )
+
+    logger.trace("Routing query", query=query[:100])
+
+    # Use LangChain's with_structured_output for automatic strategy selection
+    # (JSON schema vs tool calling based on model capabilities)
+    try:
+        structured_llm: Runnable[str, RouterDecision] = llm.with_structured_output(
+            RouterDecision
+        )
+        decision: RouterDecision = structured_llm.invoke(prompt)
+    except Exception as e:
+        logger.warning("Router failed, defaulting to standard mode", error=str(e))
+        return "standard"
+
+    logger.debug("Router decision", mode=decision.mode, query=query[:50])
+    mlflow.set_tag("router.mode", decision.mode)
+
+    return decision.mode
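For orientation, a minimal usage sketch of the new router. The serving endpoint name and the schema text below are illustrative assumptions, not values shipped in the package:

# Hypothetical usage of route_query; endpoint and schema_description are
# assumptions for illustration only.
from databricks_langchain import ChatDatabricks

from dao_ai.tools.router import route_query

llm = ChatDatabricks(endpoint="databricks-meta-llama-3-3-70b-instruct")

schema_description = """
Columns:
  - category (string): product category, filter with category = '<value>'
  - price (double): unit price, filter with price < <value>
"""

mode = route_query(
    llm=llm,
    query="trail running shoes under $100",
    schema_description=schema_description,
)
# Expected to return "instructed", since the query carries explicit,
# filterable constraints; a query like "what shoes are good for trails?"
# would be routed to "standard".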
dao_ai/tools/slack.py CHANGED
@@ -1,11 +1,13 @@
 from typing import Any, Callable, Optional
 
 from databricks.sdk.service.serving import ExternalFunctionRequestHttpMethod
+from langchain.tools import ToolRuntime
 from langchain_core.tools import tool
 from loguru import logger
 from requests import Response
 
 from dao_ai.config import ConnectionModel
+from dao_ai.state import Context
 
 
 def _find_channel_id_by_name(
@@ -129,8 +131,17 @@ def create_send_slack_message_tool(
         name_or_callable=name,
         description=description,
     )
-    def send_slack_message(text: str) -> str:
-        response: Response = connection.workspace_client.serving_endpoints.http_request(
+    def send_slack_message(
+        text: str,
+        runtime: ToolRuntime[Context] = None,
+    ) -> str:
+        from databricks.sdk import WorkspaceClient
+
+        # Get workspace client with OBO support via context
+        context: Context | None = runtime.context if runtime else None
+        workspace_client: WorkspaceClient = connection.workspace_client_from(context)
+
+        response: Response = workspace_client.serving_endpoints.http_request(
             conn=connection.name,
             method=ExternalFunctionRequestHttpMethod.POST,
             path="/api/chat.postMessage",
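The same on-behalf-of (OBO) pattern recurs in the sql.py and unity_catalog.py changes below: the tool accepts an injected ToolRuntime[Context] and resolves a per-request WorkspaceClient from it via workspace_client_from. A self-contained sketch of the shape of such a tool; FakeConnection and the tool body are placeholders, only the workspace_client_from(context) call mirrors the API shown in this diff:

# Sketch of the OBO pattern shared by these tool changes (illustrative only).
from langchain.tools import ToolRuntime, tool

from dao_ai.state import Context


class FakeConnection:
    """Stand-in for dao_ai's ConnectionModel, used here for illustration."""

    def workspace_client_from(self, context: Context | None):
        # The real model returns an OBO WorkspaceClient when the context
        # carries end-user credentials, else an ambient-auth client.
        return object()


connection = FakeConnection()


@tool(name_or_callable="example_obo_tool", description="Illustrative OBO-aware tool.")
def example_obo_tool(text: str, runtime: ToolRuntime[Context] = None) -> str:
    # The agent runtime injects the Context; the tool never sees raw credentials.
    context: Context | None = runtime.context if runtime else None
    workspace_client = connection.workspace_client_from(context)
    return f"would send {text!r} using {type(workspace_client).__name__}"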
dao_ai/tools/sql.py CHANGED
@@ -7,10 +7,11 @@ pre-configured SQL statements against a Databricks SQL warehouse.
 
 from databricks.sdk import WorkspaceClient
 from databricks.sdk.service.sql import StatementResponse, StatementState
-from langchain.tools import tool
+from langchain.tools import ToolRuntime, tool
 from loguru import logger
 
 from dao_ai.config import WarehouseModel, value_of
+from dao_ai.state import Context
 
 
 def create_execute_statement_tool(
@@ -63,7 +64,6 @@ def create_execute_statement_tool(
     description = f"Execute a pre-configured SQL query against the {warehouse.name} warehouse and return the results."
 
     warehouse_id: str = value_of(warehouse.warehouse_id)
-    workspace_client: WorkspaceClient = warehouse.workspace_client
 
     logger.debug(
         "Creating SQL execution tool",
@@ -74,7 +74,7 @@ def create_execute_statement_tool(
     )
 
     @tool(name_or_callable=name, description=description)
-    def execute_statement_tool() -> str:
+    def execute_statement_tool(runtime: ToolRuntime[Context] = None) -> str:
         """
         Execute the pre-configured SQL statement against the Databricks SQL warehouse.
 
@@ -88,6 +88,10 @@ def create_execute_statement_tool(
             sql_preview=statement[:100] + "..." if len(statement) > 100 else statement,
         )
 
+        # Get workspace client with OBO support via context
+        context: Context | None = runtime.context if runtime else None
+        workspace_client: WorkspaceClient = warehouse.workspace_client_from(context)
+
         try:
             # Execute the SQL statement
             statement_response: StatementResponse = (
dao_ai/tools/unity_catalog.py CHANGED
@@ -1,10 +1,11 @@
-from typing import Any, Dict, Optional, Sequence, Set
+from typing import Annotated, Any, Dict, Optional, Sequence, Set
 
 from databricks.sdk import WorkspaceClient
 from databricks.sdk.service.catalog import FunctionInfo, PermissionsChange, Privilege
 from databricks_langchain import DatabricksFunctionClient, UCFunctionToolkit
+from langchain.tools import ToolRuntime
 from langchain_core.runnables.base import RunnableLike
-from langchain_core.tools import StructuredTool
+from langchain_core.tools import InjectedToolArg, StructuredTool
 from loguru import logger
 from pydantic import BaseModel
 from unitycatalog.ai.core.base import FunctionExecutionResult
@@ -15,6 +16,7 @@ from dao_ai.config import (
     UnityCatalogFunctionModel,
     value_of,
 )
+from dao_ai.state import Context
 from dao_ai.utils import normalize_host
 
 
@@ -35,13 +37,11 @@ def create_uc_tools(
         A sequence of BaseTool objects that wrap the specified UC functions
     """
     original_function_model: UnityCatalogFunctionModel | None = None
-    workspace_client: WorkspaceClient | None = None
     function_name: str
 
     if isinstance(function, UnityCatalogFunctionModel):
         original_function_model = function
         function_name = function.resource.full_name
-        workspace_client = function.resource.workspace_client
     else:
         function_name = function
 
@@ -56,6 +56,12 @@ def create_uc_tools(
         # Use with_partial_args directly with UnityCatalogFunctionModel
         tools = [with_partial_args(original_function_model)]
     else:
+        # For standard UC toolkit, we need workspace_client at creation time
+        # Use the resource's workspace_client (will use ambient auth if no OBO)
+        workspace_client: WorkspaceClient | None = None
+        if original_function_model:
+            workspace_client = original_function_model.resource.workspace_client
+
         # Fallback to standard UC toolkit approach
         client: DatabricksFunctionClient = DatabricksFunctionClient(
             client=workspace_client
@@ -356,7 +362,6 @@ def with_partial_args(
     # Get function info from the resource
     function_name: str = uc_function.resource.full_name
    tool_name: str = uc_function.resource.name or function_name.replace(".", "_")
-    workspace_client: WorkspaceClient = uc_function.resource.workspace_client
 
     logger.debug(
         "Creating UC tool with partial args",
@@ -365,7 +370,7 @@ def with_partial_args(
         partial_args=list(resolved_args.keys()),
     )
 
-    # Grant permissions if we have credentials
+    # Grant permissions if we have credentials (using ambient auth for setup)
     if "client_id" in resolved_args:
         client_id: str = resolved_args["client_id"]
         host: Optional[str] = resolved_args.get("host")
@@ -376,14 +381,18 @@ def with_partial_args(
                 "Failed to grant permissions", function_name=function_name, error=str(e)
             )
 
-    # Create the client for function execution using the resource's workspace client
-    client: DatabricksFunctionClient = DatabricksFunctionClient(client=workspace_client)
+    # Get workspace client for schema introspection (uses ambient auth at definition time)
+    # Actual execution will use OBO via context
+    setup_workspace_client: WorkspaceClient = uc_function.resource.workspace_client
+    setup_client: DatabricksFunctionClient = DatabricksFunctionClient(
+        client=setup_workspace_client
+    )
 
     # Try to get the function schema for better tool definition
     schema_model: type[BaseModel]
     tool_description: str
     try:
-        function_info: FunctionInfo = client.get_function(function_name)
+        function_info: FunctionInfo = setup_client.get_function(function_name)
         schema_info = generate_function_input_params_schema(function_info)
         tool_description = (
             function_info.comment or f"Unity Catalog function: {function_name}"
@@ -419,8 +428,21 @@ def with_partial_args(
         tool_description = f"Unity Catalog function: {function_name}"
 
     # Create a wrapper function that calls _execute_uc_function with partial args
-    def uc_function_wrapper(**kwargs) -> str:
+    # Uses InjectedToolArg to ensure runtime is injected but hidden from the LLM
+    def uc_function_wrapper(
+        runtime: Annotated[ToolRuntime[Context], InjectedToolArg] = None,
+        **kwargs: Any,
+    ) -> str:
         """Wrapper function that executes Unity Catalog function with partial args."""
+        # Get workspace client with OBO support via context
+        context: Context | None = runtime.context if runtime else None
+        workspace_client: WorkspaceClient = uc_function.resource.workspace_client_from(
+            context
+        )
+        client: DatabricksFunctionClient = DatabricksFunctionClient(
+            client=workspace_client
+        )
+
         return _execute_uc_function(
             client=client,
             function_name=function_name,
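The Annotated[..., InjectedToolArg] signature on uc_function_wrapper keeps runtime out of the argument schema the model sees while still letting the agent runtime supply it at call time. A small self-contained sketch of that mechanism; the lookup_item tool is hypothetical and not part of the package:

# Hypothetical tool showing why runtime is annotated with InjectedToolArg:
# injected arguments are excluded from the schema presented to the LLM but
# can still be supplied (by the agent runtime, or directly) at invocation.
from typing import Annotated, Any

from langchain_core.tools import InjectedToolArg, tool


@tool
def lookup_item(item_id: str, runtime: Annotated[Any, InjectedToolArg] = None) -> str:
    """Look up an item by id."""
    return f"looked up {item_id} (runtime injected: {runtime is not None})"


# The LLM-facing schema exposes only item_id; runtime stays hidden.
print(list(lookup_item.tool_call_schema.model_json_schema()["properties"]))
# A direct invocation can still pass the injected argument explicitly:
print(lookup_item.invoke({"item_id": "42", "runtime": "ctx"}))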