dao-ai 0.1.5__py3-none-any.whl → 0.1.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/apps/__init__.py +24 -0
- dao_ai/apps/handlers.py +105 -0
- dao_ai/apps/model_serving.py +29 -0
- dao_ai/apps/resources.py +1122 -0
- dao_ai/apps/server.py +39 -0
- dao_ai/cli.py +446 -16
- dao_ai/config.py +1034 -103
- dao_ai/evaluation.py +543 -0
- dao_ai/genie/__init__.py +55 -7
- dao_ai/genie/cache/__init__.py +34 -7
- dao_ai/genie/cache/base.py +143 -2
- dao_ai/genie/cache/context_aware/__init__.py +31 -0
- dao_ai/genie/cache/context_aware/base.py +1151 -0
- dao_ai/genie/cache/context_aware/in_memory.py +609 -0
- dao_ai/genie/cache/context_aware/persistent.py +802 -0
- dao_ai/genie/cache/context_aware/postgres.py +1166 -0
- dao_ai/genie/cache/core.py +1 -1
- dao_ai/genie/cache/lru.py +257 -75
- dao_ai/genie/cache/optimization.py +890 -0
- dao_ai/genie/core.py +235 -11
- dao_ai/memory/postgres.py +175 -39
- dao_ai/middleware/__init__.py +5 -0
- dao_ai/middleware/tool_selector.py +129 -0
- dao_ai/models.py +327 -370
- dao_ai/nodes.py +4 -4
- dao_ai/orchestration/core.py +33 -9
- dao_ai/orchestration/supervisor.py +23 -8
- dao_ai/orchestration/swarm.py +6 -1
- dao_ai/{prompts.py → prompts/__init__.py} +12 -61
- dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
- dao_ai/prompts/instruction_reranker.yaml +14 -0
- dao_ai/prompts/router.yaml +37 -0
- dao_ai/prompts/verifier.yaml +46 -0
- dao_ai/providers/base.py +28 -2
- dao_ai/providers/databricks.py +352 -33
- dao_ai/state.py +1 -0
- dao_ai/tools/__init__.py +5 -3
- dao_ai/tools/genie.py +103 -26
- dao_ai/tools/instructed_retriever.py +366 -0
- dao_ai/tools/instruction_reranker.py +202 -0
- dao_ai/tools/mcp.py +539 -97
- dao_ai/tools/router.py +89 -0
- dao_ai/tools/slack.py +13 -2
- dao_ai/tools/sql.py +7 -3
- dao_ai/tools/unity_catalog.py +32 -10
- dao_ai/tools/vector_search.py +493 -160
- dao_ai/tools/verifier.py +159 -0
- dao_ai/utils.py +182 -2
- dao_ai/vector_search.py +9 -1
- {dao_ai-0.1.5.dist-info → dao_ai-0.1.20.dist-info}/METADATA +10 -8
- dao_ai-0.1.20.dist-info/RECORD +89 -0
- dao_ai/agent_as_code.py +0 -22
- dao_ai/genie/cache/semantic.py +0 -970
- dao_ai-0.1.5.dist-info/RECORD +0 -70
- {dao_ai-0.1.5.dist-info → dao_ai-0.1.20.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.5.dist-info → dao_ai-0.1.20.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.5.dist-info → dao_ai-0.1.20.dist-info}/licenses/LICENSE +0 -0
dao_ai/tools/router.py
ADDED
@@ -0,0 +1,89 @@
+"""
+Query router for selecting execution mode based on query characteristics.
+
+Routes to internal execution modes within the same retriever instance:
+- standard: Single similarity_search for simple queries
+- instructed: Decompose -> Parallel Search -> RRF for constrained queries
+"""
+
+from pathlib import Path
+from typing import Any, Literal
+
+import mlflow
+import yaml
+from langchain_core.language_models import BaseChatModel
+from langchain_core.runnables import Runnable
+from loguru import logger
+from mlflow.entities import SpanType
+from pydantic import BaseModel, ConfigDict, Field
+
+# Load prompt template
+_PROMPT_PATH = Path(__file__).parent.parent / "prompts" / "router.yaml"
+
+
+def _load_prompt_template() -> dict[str, Any]:
+    """Load the router prompt template from YAML."""
+    with open(_PROMPT_PATH) as f:
+        return yaml.safe_load(f)
+
+
+class RouterDecision(BaseModel):
+    """Classification of a search query into an execution mode.
+
+    Analyze whether the query contains explicit constraints that map to
+    filterable metadata columns, or is a simple semantic search.
+    """
+
+    model_config = ConfigDict(extra="forbid")
+    mode: Literal["standard", "instructed"] = Field(
+        description=(
+            "The execution mode. "
+            "Use 'standard' for simple semantic searches without constraints. "
+            "Use 'instructed' when the query contains explicit constraints "
+            "that can be translated to metadata filters."
+        )
+    )
+
+
+@mlflow.trace(name="route_query", span_type=SpanType.LLM)
+def route_query(
+    llm: BaseChatModel,
+    query: str,
+    schema_description: str,
+) -> Literal["standard", "instructed"]:
+    """
+    Determine the execution mode for a search query.
+
+    Args:
+        llm: Language model for routing decision
+        query: User's search query
+        schema_description: Column names, types, and filter syntax
+
+    Returns:
+        "standard" for simple queries, "instructed" for constrained queries
+    """
+    prompt_config = _load_prompt_template()
+    prompt_template = prompt_config["template"]
+
+    prompt = prompt_template.format(
+        schema_description=schema_description,
+        query=query,
+    )
+
+    logger.trace("Routing query", query=query[:100])
+
+    # Use LangChain's with_structured_output for automatic strategy selection
+    # (JSON schema vs tool calling based on model capabilities)
+    try:
+        structured_llm: Runnable[str, RouterDecision] = llm.with_structured_output(
+            RouterDecision
+        )
+        decision: RouterDecision = structured_llm.invoke(prompt)
+    except Exception as e:
+        logger.warning("Router failed, defaulting to standard mode", error=str(e))
+        return "standard"
+
+    logger.debug("Router decision", mode=decision.mode, query=query[:50])
+    mlflow.set_tag("router.mode", decision.mode)
+
+    return decision.mode
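For orientation, a minimal sketch of calling the new router; the model endpoint and schema description are illustrative placeholders, not values from the package:

```python
from databricks_langchain import ChatDatabricks

from dao_ai.tools.router import route_query

# Placeholder endpoint; any BaseChatModel supporting structured output works.
llm = ChatDatabricks(endpoint="databricks-meta-llama-3-3-70b-instruct")
schema = "price (DOUBLE): unit price; category (STRING): product category"

mode = route_query(llm, query="red running shoes under $50", schema_description=schema)
# -> "instructed"; on any routing failure the function falls back to "standard"
```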
dao_ai/tools/slack.py
CHANGED
@@ -1,11 +1,13 @@
 from typing import Any, Callable, Optional
 
 from databricks.sdk.service.serving import ExternalFunctionRequestHttpMethod
+from langchain.tools import ToolRuntime
 from langchain_core.tools import tool
 from loguru import logger
 from requests import Response
 
 from dao_ai.config import ConnectionModel
+from dao_ai.state import Context
 
 
 def _find_channel_id_by_name(
@@ -129,8 +131,17 @@ def create_send_slack_message_tool(
         name_or_callable=name,
         description=description,
     )
-    def send_slack_message(
-
+    def send_slack_message(
+        text: str,
+        runtime: ToolRuntime[Context] = None,
+    ) -> str:
+        from databricks.sdk import WorkspaceClient
+
+        # Get workspace client with OBO support via context
+        context: Context | None = runtime.context if runtime else None
+        workspace_client: WorkspaceClient = connection.workspace_client_from(context)
+
+        response: Response = workspace_client.serving_endpoints.http_request(
             conn=connection.name,
             method=ExternalFunctionRequestHttpMethod.POST,
             path="/api/chat.postMessage",
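This change, and the sql.py and unity_catalog.py changes below, share one pattern: LangChain 1.x injects a `ToolRuntime` carrying the request `Context`, and the tool resolves a per-request `WorkspaceClient` from it (on-behalf-of when a user context is present, ambient auth otherwise). A standalone sketch of the injection mechanics, with an assumed `Context` shape rather than dao_ai's actual schema:

```python
from dataclasses import dataclass

from langchain.tools import ToolRuntime, tool


@dataclass
class Context:
    # Assumed field for illustration: a forwarded user token enabling
    # on-behalf-of (OBO) auth; dao_ai's real Context may differ.
    user_access_token: str | None = None


@tool
def whoami(runtime: ToolRuntime[Context]) -> str:
    """Report which identity a downstream Databricks call would act as."""
    ctx = runtime.context
    return "obo: end user" if ctx and ctx.user_access_token else "ambient: service principal"
```

Because the parameter is typed `ToolRuntime`, LangChain keeps it out of the model-facing schema and supplies it at execution time.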
dao_ai/tools/sql.py
CHANGED
@@ -7,10 +7,11 @@ pre-configured SQL statements against a Databricks SQL warehouse.
 
 from databricks.sdk import WorkspaceClient
 from databricks.sdk.service.sql import StatementResponse, StatementState
-from langchain.tools import tool
+from langchain.tools import ToolRuntime, tool
 from loguru import logger
 
 from dao_ai.config import WarehouseModel, value_of
+from dao_ai.state import Context
 
 
 def create_execute_statement_tool(
@@ -63,7 +64,6 @@ def create_execute_statement_tool(
     description = f"Execute a pre-configured SQL query against the {warehouse.name} warehouse and return the results."
 
     warehouse_id: str = value_of(warehouse.warehouse_id)
-    workspace_client: WorkspaceClient = warehouse.workspace_client
 
     logger.debug(
         "Creating SQL execution tool",
@@ -74,7 +74,7 @@ def create_execute_statement_tool(
     )
 
     @tool(name_or_callable=name, description=description)
-    def execute_statement_tool() -> str:
+    def execute_statement_tool(runtime: ToolRuntime[Context] = None) -> str:
         """
         Execute the pre-configured SQL statement against the Databricks SQL warehouse.
 
@@ -88,6 +88,10 @@ def create_execute_statement_tool(
             sql_preview=statement[:100] + "..." if len(statement) > 100 else statement,
         )
 
+        # Get workspace client with OBO support via context
+        context: Context | None = runtime.context if runtime else None
+        workspace_client: WorkspaceClient = warehouse.workspace_client_from(context)
+
         try:
             # Execute the SQL statement
             statement_response: StatementResponse = (
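For context, the statement execution elided from these hunks goes through the SDK's statement-execution API; roughly as follows, where the warehouse id and SQL are placeholders and the synchronous `wait_timeout` is an assumption about the tool's behavior:

```python
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.sql import StatementResponse, StatementState

w = WorkspaceClient()  # the tool would substitute the OBO client resolved above
resp: StatementResponse = w.statement_execution.execute_statement(
    warehouse_id="<warehouse-id>",  # placeholder
    statement="SELECT 1",           # the tool runs its pre-configured statement
    wait_timeout="30s",
)
if resp.status and resp.status.state == StatementState.SUCCEEDED:
    print(resp.result.data_array)
```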
dao_ai/tools/unity_catalog.py
CHANGED
@@ -1,10 +1,11 @@
-from typing import Any, Dict, Optional, Sequence, Set
+from typing import Annotated, Any, Dict, Optional, Sequence, Set
 
 from databricks.sdk import WorkspaceClient
 from databricks.sdk.service.catalog import FunctionInfo, PermissionsChange, Privilege
 from databricks_langchain import DatabricksFunctionClient, UCFunctionToolkit
+from langchain.tools import ToolRuntime
 from langchain_core.runnables.base import RunnableLike
-from langchain_core.tools import StructuredTool
+from langchain_core.tools import InjectedToolArg, StructuredTool
 from loguru import logger
 from pydantic import BaseModel
 from unitycatalog.ai.core.base import FunctionExecutionResult
@@ -15,6 +16,7 @@ from dao_ai.config import (
     UnityCatalogFunctionModel,
     value_of,
 )
+from dao_ai.state import Context
 from dao_ai.utils import normalize_host
 
 
@@ -35,13 +37,11 @@ def create_uc_tools(
         A sequence of BaseTool objects that wrap the specified UC functions
     """
     original_function_model: UnityCatalogFunctionModel | None = None
-    workspace_client: WorkspaceClient | None = None
     function_name: str
 
     if isinstance(function, UnityCatalogFunctionModel):
         original_function_model = function
         function_name = function.resource.full_name
-        workspace_client = function.resource.workspace_client
     else:
         function_name = function
 
@@ -56,6 +56,12 @@ def create_uc_tools(
         # Use with_partial_args directly with UnityCatalogFunctionModel
         tools = [with_partial_args(original_function_model)]
     else:
+        # For standard UC toolkit, we need workspace_client at creation time
+        # Use the resource's workspace_client (will use ambient auth if no OBO)
+        workspace_client: WorkspaceClient | None = None
+        if original_function_model:
+            workspace_client = original_function_model.resource.workspace_client
+
         # Fallback to standard UC toolkit approach
         client: DatabricksFunctionClient = DatabricksFunctionClient(
             client=workspace_client
@@ -356,7 +362,6 @@ def with_partial_args(
     # Get function info from the resource
     function_name: str = uc_function.resource.full_name
     tool_name: str = uc_function.resource.name or function_name.replace(".", "_")
-    workspace_client: WorkspaceClient = uc_function.resource.workspace_client
 
     logger.debug(
         "Creating UC tool with partial args",
@@ -365,7 +370,7 @@
         partial_args=list(resolved_args.keys()),
     )
 
-    # Grant permissions if we have credentials
+    # Grant permissions if we have credentials (using ambient auth for setup)
     if "client_id" in resolved_args:
         client_id: str = resolved_args["client_id"]
         host: Optional[str] = resolved_args.get("host")
@@ -376,14 +381,18 @@
             "Failed to grant permissions", function_name=function_name, error=str(e)
         )
 
-    #
-
+    # Get workspace client for schema introspection (uses ambient auth at definition time)
+    # Actual execution will use OBO via context
+    setup_workspace_client: WorkspaceClient = uc_function.resource.workspace_client
+    setup_client: DatabricksFunctionClient = DatabricksFunctionClient(
+        client=setup_workspace_client
+    )
 
     # Try to get the function schema for better tool definition
     schema_model: type[BaseModel]
     tool_description: str
     try:
-        function_info: FunctionInfo = client.get_function(function_name)
+        function_info: FunctionInfo = setup_client.get_function(function_name)
         schema_info = generate_function_input_params_schema(function_info)
         tool_description = (
             function_info.comment or f"Unity Catalog function: {function_name}"
@@ -419,8 +428,21 @@
         tool_description = f"Unity Catalog function: {function_name}"
 
     # Create a wrapper function that calls _execute_uc_function with partial args
-
+    # Uses InjectedToolArg to ensure runtime is injected but hidden from the LLM
+    def uc_function_wrapper(
+        runtime: Annotated[ToolRuntime[Context], InjectedToolArg] = None,
+        **kwargs: Any,
+    ) -> str:
         """Wrapper function that executes Unity Catalog function with partial args."""
+        # Get workspace client with OBO support via context
+        context: Context | None = runtime.context if runtime else None
+        workspace_client: WorkspaceClient = uc_function.resource.workspace_client_from(
+            context
+        )
+        client: DatabricksFunctionClient = DatabricksFunctionClient(
+            client=workspace_client
+        )
+
         return _execute_uc_function(
             client=client,
             function_name=function_name,