dao-ai 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl
This diff shows the content of publicly available package versions as released to their respective public registries. It is provided for informational purposes only and reflects the changes between the two package versions.
- dao_ai/config.py +371 -27
- dao_ai/graph.py +29 -4
- dao_ai/nodes.py +29 -20
- dao_ai/providers/databricks.py +536 -35
- dao_ai/tools/mcp.py +46 -27
- dao_ai/utils.py +56 -1
- {dao_ai-0.0.25.dist-info → dao_ai-0.0.27.dist-info}/METADATA +4 -2
- {dao_ai-0.0.25.dist-info → dao_ai-0.0.27.dist-info}/RECORD +11 -11
- {dao_ai-0.0.25.dist-info → dao_ai-0.0.27.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.25.dist-info → dao_ai-0.0.27.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.25.dist-info → dao_ai-0.0.27.dist-info}/licenses/LICENSE +0 -0
dao_ai/config.py
CHANGED
@@ -30,12 +30,15 @@ from databricks_langchain import (
     DatabricksFunctionClient,
 )
 from langchain_core.language_models import LanguageModelLike
+from langchain_core.messages import BaseMessage, messages_from_dict
 from langchain_core.runnables.base import RunnableLike
 from langchain_openai import ChatOpenAI
 from langgraph.checkpoint.base import BaseCheckpointSaver
 from langgraph.graph.state import CompiledStateGraph
 from langgraph.store.base import BaseStore
 from loguru import logger
+from mlflow.genai.datasets import EvaluationDataset, create_dataset, get_dataset
+from mlflow.genai.prompts import PromptVersion, load_prompt
 from mlflow.models import ModelConfig
 from mlflow.models.resources import (
     DatabricksFunction,
@@ -49,6 +52,9 @@ from mlflow.models.resources import (
     DatabricksVectorSearchIndex,
 )
 from mlflow.pyfunc import ChatModel, ResponsesAgent
+from mlflow.types.responses import (
+    ResponsesAgentRequest,
+)
 from pydantic import (
     BaseModel,
     ConfigDict,
@@ -324,6 +330,10 @@ class LLMModel(BaseModel, IsDatabricksResource):
         "serving.serving-endpoints",
     ]
 
+    @property
+    def uri(self) -> str:
+        return f"databricks:/{self.name}"
+
     def as_resources(self) -> Sequence[DatabricksResource]:
         return [
             DatabricksServingEndpoint(
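The new uri property makes a serving endpoint addressable as an MLflow-style model URI. A minimal standalone sketch of the scheme it produces (the endpoint name is invented):

name = "databricks-meta-llama-3-3-70b-instruct"  # hypothetical serving endpoint
print(f"databricks:/{name}")  # databricks:/databricks-meta-llama-3-3-70b-instruct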
@@ -387,6 +397,13 @@ class VectorSearchEndpoint(BaseModel):
     name: str
     type: VectorSearchEndpointType = VectorSearchEndpointType.STANDARD
 
+    @field_serializer("type")
+    def serialize_type(self, value: VectorSearchEndpointType) -> str:
+        """Ensure enum is serialized to string value."""
+        if isinstance(value, VectorSearchEndpointType):
+            return value.value
+        return str(value)
+
 
 class IndexModel(BaseModel, HasFullName, IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
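The serializer guards against model_dump() emitting a raw enum member instead of its string value. The same pattern, reduced to a self-contained pydantic v2 sketch (the Color and Item models are invented, not part of dao-ai):

from enum import Enum

from pydantic import BaseModel, field_serializer


class Color(str, Enum):
    RED = "red"


class Item(BaseModel):
    color: Color = Color.RED

    @field_serializer("color")
    def serialize_color(self, value: Color) -> str:
        # Collapse the enum member to its plain string value.
        return value.value if isinstance(value, Color) else str(value)


print(Item().model_dump())  # {'color': 'red'}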
@@ -988,22 +1005,120 @@ class TransportType(str, Enum):
 class McpFunctionModel(BaseFunctionModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     type: Literal[FunctionType.MCP] = FunctionType.MCP
-
     transport: TransportType = TransportType.STREAMABLE_HTTP
     command: Optional[str] = "python"
     url: Optional[AnyVariable] = None
-    connection: Optional[ConnectionModel] = None
     headers: dict[str, AnyVariable] = Field(default_factory=dict)
     args: list[str] = Field(default_factory=list)
     pat: Optional[AnyVariable] = None
     client_id: Optional[AnyVariable] = None
     client_secret: Optional[AnyVariable] = None
     workspace_host: Optional[AnyVariable] = None
+    connection: Optional[ConnectionModel] = None
+    functions: Optional[SchemaModel] = None
+    genie_room: Optional[GenieRoomModel] = None
+    sql: Optional[bool] = None
+    vector_search: Optional[VectorStoreModel] = None
 
     @property
     def full_name(self) -> str:
         return self.name
 
+    def _get_workspace_host(self) -> str:
+        """
+        Get the workspace host, either from config or from workspace client.
+
+        If connection is provided, uses its workspace client.
+        Otherwise, falls back to creating a new workspace client.
+
+        Returns:
+            str: The workspace host URL without trailing slash
+        """
+        from databricks.sdk import WorkspaceClient
+
+        # Try to get workspace_host from config
+        workspace_host: str | None = (
+            value_of(self.workspace_host) if self.workspace_host else None
+        )
+
+        # If no workspace_host in config, get it from workspace client
+        if not workspace_host:
+            # Use connection's workspace client if available
+            if self.connection:
+                workspace_host = self.connection.workspace_client.config.host
+            else:
+                # Create a default workspace client
+                w: WorkspaceClient = WorkspaceClient()
+                workspace_host = w.config.host
+
+        # Remove trailing slash
+        return workspace_host.rstrip("/")
+
+    @property
+    def mcp_url(self) -> str:
+        """
+        Get the MCP URL for this function.
+
+        Returns the URL based on the configured source:
+        - If url is set, returns it directly
+        - If connection is set, constructs URL from connection
+        - If genie_room is set, constructs Genie MCP URL
+        - If sql is set, constructs DBSQL MCP URL (serverless)
+        - If vector_search is set, constructs Vector Search MCP URL
+        - If functions is set, constructs UC Functions MCP URL
+
+        URL patterns (per https://docs.databricks.com/aws/en/generative-ai/mcp/managed-mcp):
+        - Genie: https://{host}/api/2.0/mcp/genie/{space_id}
+        - DBSQL: https://{host}/api/2.0/mcp/sql (serverless, workspace-level)
+        - Vector Search: https://{host}/api/2.0/mcp/vector-search/{catalog}/{schema}
+        - UC Functions: https://{host}/api/2.0/mcp/functions/{catalog}/{schema}
+        - Connection: https://{host}/api/2.0/mcp/external/{connection_name}
+        """
+        # Direct URL provided
+        if self.url:
+            return self.url
+
+        # Get workspace host (from config, connection, or default workspace client)
+        workspace_host: str = self._get_workspace_host()
+
+        # UC Connection
+        if self.connection:
+            connection_name: str = self.connection.name
+            return f"{workspace_host}/api/2.0/mcp/external/{connection_name}"
+
+        # Genie Room
+        if self.genie_room:
+            space_id: str = value_of(self.genie_room.space_id)
+            return f"{workspace_host}/api/2.0/mcp/genie/{space_id}"
+
+        # DBSQL MCP server (serverless, workspace-level)
+        if self.sql:
+            return f"{workspace_host}/api/2.0/mcp/sql"
+
+        # Vector Search
+        if self.vector_search:
+            if (
+                not self.vector_search.index
+                or not self.vector_search.index.schema_model
+            ):
+                raise ValueError(
+                    "vector_search must have an index with a schema (catalog/schema) configured"
+                )
+            catalog: str = self.vector_search.index.schema_model.catalog_name
+            schema: str = self.vector_search.index.schema_model.schema_name
+            return f"{workspace_host}/api/2.0/mcp/vector-search/{catalog}/{schema}"
+
+        # UC Functions MCP server
+        if self.functions:
+            catalog: str = self.functions.catalog_name
+            schema: str = self.functions.schema_name
+            return f"{workspace_host}/api/2.0/mcp/functions/{catalog}/{schema}"
+
+        raise ValueError(
+            "No URL source configured. Provide one of: url, connection, genie_room, "
+            "sql, vector_search, or functions"
+        )
+
     @field_serializer("transport")
     def serialize_transport(self, value) -> str:
         if isinstance(value, TransportType):
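The docstring above fixes the managed-MCP URL shapes. A standalone sketch of what the property derives for each source (host and identifiers invented):

host = "https://example.cloud.databricks.com"  # hypothetical workspace host
print(f"{host}/api/2.0/mcp/genie/01ef1234abcd")      # genie_room (space id)
print(f"{host}/api/2.0/mcp/sql")                     # sql (serverless DBSQL)
print(f"{host}/api/2.0/mcp/vector-search/main/rag")  # vector_search catalog/schema
print(f"{host}/api/2.0/mcp/functions/main/tools")    # functions catalog/schema
print(f"{host}/api/2.0/mcp/external/my_connection")  # connection name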
@@ -1011,32 +1126,56 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
             return str(value)
 
     @model_validator(mode="after")
-    def validate_mutually_exclusive(self):
-
-
-
-
-
-        )
-
-
-
+    def validate_mutually_exclusive(self) -> "McpFunctionModel":
+        """Validate that exactly one URL source is provided."""
+        # Count how many URL sources are provided
+        url_sources: list[tuple[str, Any]] = [
+            ("url", self.url),
+            ("connection", self.connection),
+            ("genie_room", self.genie_room),
+            ("sql", self.sql),
+            ("vector_search", self.vector_search),
+            ("functions", self.functions),
+        ]
+
+        provided_sources: list[str] = [
+            name for name, value in url_sources if value is not None
+        ]
+
+        if self.transport == TransportType.STREAMABLE_HTTP:
+            if len(provided_sources) == 0:
+                raise ValueError(
+                    "For STREAMABLE_HTTP transport, exactly one of the following must be provided: "
+                    "url, connection, genie_room, sql, vector_search, or functions"
+                )
+            if len(provided_sources) > 1:
+                raise ValueError(
+                    f"For STREAMABLE_HTTP transport, only one URL source can be provided. "
+                    f"Found: {', '.join(provided_sources)}. "
+                    f"Please provide only one of: url, connection, genie_room, sql, vector_search, or functions"
+                )
+
+        if self.transport == TransportType.STDIO:
+            if not self.command:
+                raise ValueError("command must be provided for STDIO transport")
+            if not self.args:
+                raise ValueError("args must be provided for STDIO transport")
+
         return self
 
     @model_validator(mode="after")
-    def update_url(self):
+    def update_url(self) -> "McpFunctionModel":
         self.url = value_of(self.url)
         return self
 
     @model_validator(mode="after")
-    def update_headers(self):
+    def update_headers(self) -> "McpFunctionModel":
         for key, value in self.headers.items():
             self.headers[key] = value_of(value)
         return self
 
     @model_validator(mode="after")
-    def validate_auth_methods(self):
+    def validate_auth_methods(self) -> "McpFunctionModel":
         oauth_fields: Sequence[Any] = [
             self.client_id,
             self.client_secret,
@@ -1052,10 +1191,7 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
                 "Please provide either OAuth credentials or user credentials."
            )
 
-
-        raise ValueError(
-            "Workspace host must be provided when using OAuth or user credentials."
-        )
+        # Note: workspace_host is optional - it will be derived from workspace client if not provided
 
         return self
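The tightened validator enforces exactly one URL source for STREAMABLE_HTTP and a command plus args for STDIO. The "exactly one source" check, reduced to a self-contained pydantic sketch (field names borrowed from the diff, everything else invented):

from typing import Optional

from pydantic import BaseModel, model_validator


class McpSource(BaseModel):
    url: Optional[str] = None
    sql: Optional[bool] = None
    genie_room: Optional[str] = None

    @model_validator(mode="after")
    def exactly_one(self) -> "McpSource":
        sources = [("url", self.url), ("sql", self.sql), ("genie_room", self.genie_room)]
        provided = [name for name, value in sources if value is not None]
        if len(provided) != 1:
            raise ValueError(f"exactly one URL source required, got: {provided}")
        return self


McpSource(sql=True)                    # ok
# McpSource(sql=True, genie_room="x")  # raises: got ['sql', 'genie_room']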
@@ -1181,17 +1317,32 @@ class PromptModel(BaseModel, HasFullName):
         from dao_ai.providers.databricks import DatabricksProvider
 
         provider: DatabricksProvider = DatabricksProvider()
-
-        return
+        prompt_version = provider.get_prompt(self)
+        return prompt_version.to_single_brace_format()
 
     @property
     def full_name(self) -> str:
+        prompt_name: str = self.name
         if self.schema_model:
-
-
-
-
-
+            prompt_name = f"{self.schema_model.full_name}.{prompt_name}"
+        return prompt_name
+
+    @property
+    def uri(self) -> str:
+        prompt_uri: str = f"prompts:/{self.full_name}"
+
+        if self.alias:
+            prompt_uri = f"prompts:/{self.full_name}@{self.alias}"
+        elif self.version:
+            prompt_uri = f"prompts:/{self.full_name}/{self.version}"
+        else:
+            prompt_uri = f"prompts:/{self.full_name}@latest"
+
+        return prompt_uri
+
+    def as_prompt(self) -> PromptVersion:
+        prompt_version: PromptVersion = load_prompt(self.uri)
+        return prompt_version
 
     @model_validator(mode="after")
     def validate_mutually_exclusive(self):
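The uri property resolves alias before version, falling back to the "@latest" alias. A standalone mirror of that precedence (prompt name invented):

def prompt_uri(full_name: str, alias: str | None = None, version: int | None = None) -> str:
    # Mirrors PromptModel.uri: alias wins, then version, then "@latest".
    if alias:
        return f"prompts:/{full_name}@{alias}"
    if version:
        return f"prompts:/{full_name}/{version}"
    return f"prompts:/{full_name}@latest"


print(prompt_uri("main.agents.greeting", alias="production"))  # prompts:/main.agents.greeting@production
print(prompt_uri("main.agents.greeting", version=3))           # prompts:/main.agents.greeting/3
print(prompt_uri("main.agents.greeting"))                      # prompts:/main.agents.greeting@latest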
@@ -1213,6 +1364,17 @@ class AgentModel(BaseModel):
     pre_agent_hook: Optional[FunctionHook] = None
     post_agent_hook: Optional[FunctionHook] = None
 
+    def as_runnable(self) -> RunnableLike:
+        from dao_ai.nodes import create_agent_node
+
+        return create_agent_node(self)
+
+    def as_responses_agent(self) -> ResponsesAgent:
+        from dao_ai.models import create_responses_agent
+
+        graph: CompiledStateGraph = self.as_runnable()
+        return create_responses_agent(graph)
+
 
 class SupervisorModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
@@ -1330,6 +1492,19 @@ class ChatPayload(BaseModel):
 
         return self
 
+    def as_messages(self) -> Sequence[BaseMessage]:
+        return messages_from_dict(
+            [{"type": m.role, "content": m.content} for m in self.messages]
+        )
+
+    def as_agent_request(self) -> ResponsesAgentRequest:
+        from mlflow.types.responses_helpers import Message as _Message
+
+        return ResponsesAgentRequest(
+            input=[_Message(role=m.role, content=m.content) for m in self.messages],
+            custom_inputs=self.custom_inputs,
+        )
+
 
 class ChatHistoryModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
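For reference, messages_from_dict rebuilds BaseMessage objects from serialized dicts; a minimal self-contained example using the canonical {"type": ..., "data": {...}} shape that langchain-core's messages_to_dict produces (message text invented):

from langchain_core.messages import messages_from_dict

msgs = messages_from_dict(
    [{"type": "human", "data": {"content": "What is the return policy?"}}]
)
print(type(msgs[0]).__name__, msgs[0].content)  # HumanMessage What is the return policy?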
@@ -1459,6 +1634,174 @@ class EvaluationModel(BaseModel):
     guidelines: list[GuidelineModel] = Field(default_factory=list)
 
 
+class EvaluationDatasetExpectationsModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    expected_response: Optional[str] = None
+    expected_facts: Optional[list[str]] = None
+
+    @model_validator(mode="after")
+    def validate_mutually_exclusive(self):
+        if self.expected_response is not None and self.expected_facts is not None:
+            raise ValueError("Cannot specify both expected_response and expected_facts")
+        return self
+
+
+class EvaluationDatasetEntryModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    inputs: ChatPayload
+    expectations: EvaluationDatasetExpectationsModel
+
+    def to_mlflow_format(self) -> dict[str, Any]:
+        """
+        Convert to MLflow evaluation dataset format.
+
+        Flattens the expectations fields to the top level alongside inputs,
+        which is the format expected by MLflow's Correctness scorer.
+
+        Returns:
+            dict: Flattened dictionary with inputs and expectation fields at top level
+        """
+        result: dict[str, Any] = {"inputs": self.inputs.model_dump()}
+
+        # Flatten expectations to top level for MLflow compatibility
+        if self.expectations.expected_response is not None:
+            result["expected_response"] = self.expectations.expected_response
+        if self.expectations.expected_facts is not None:
+            result["expected_facts"] = self.expectations.expected_facts
+
+        return result
+
+
+class EvaluationDatasetModel(BaseModel, HasFullName):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
+    name: str
+    data: Optional[list[EvaluationDatasetEntryModel]] = Field(default_factory=list)
+    overwrite: Optional[bool] = False
+
+    def as_dataset(self, w: WorkspaceClient | None = None) -> EvaluationDataset:
+        evaluation_dataset: EvaluationDataset
+        needs_creation: bool = False
+
+        try:
+            evaluation_dataset = get_dataset(name=self.full_name)
+            if self.overwrite:
+                logger.warning(f"Overwriting dataset {self.full_name}")
+                workspace_client: WorkspaceClient = w if w else WorkspaceClient()
+                logger.debug(f"Dropping table: {self.full_name}")
+                workspace_client.tables.delete(full_name=self.full_name)
+                needs_creation = True
+        except Exception:
+            logger.warning(
+                f"Dataset {self.full_name} not found, will create new dataset"
+            )
+            needs_creation = True
+
+        # Create dataset if needed (either new or after overwrite)
+        if needs_creation:
+            evaluation_dataset = create_dataset(name=self.full_name)
+        if self.data:
+            logger.debug(
+                f"Merging {len(self.data)} entries into dataset {self.full_name}"
+            )
+            # Use to_mlflow_format() to flatten expectations for MLflow compatibility
+            evaluation_dataset.merge_records(
+                [e.to_mlflow_format() for e in self.data]
+            )
+
+        return evaluation_dataset
+
+    @property
+    def full_name(self) -> str:
+        if self.schema_model:
+            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}.{self.name}"
+        return self.name
+
+
+class PromptOptimizationModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    prompt: Optional[PromptModel] = None
+    agent: AgentModel
+    dataset: (
+        EvaluationDatasetModel | str
+    )  # Reference to dataset name (looked up in OptimizationsModel.training_datasets or MLflow)
+    reflection_model: Optional[LLMModel | str] = None
+    num_candidates: Optional[int] = 50
+    scorer_model: Optional[LLMModel | str] = None
+
+    def optimize(self, w: WorkspaceClient | None = None) -> PromptModel:
+        """
+        Optimize the prompt using MLflow's prompt optimization.
+
+        Args:
+            w: Optional WorkspaceClient for Databricks operations
+
+        Returns:
+            PromptModel: The optimized prompt model with new URI
+        """
+        from dao_ai.providers.base import ServiceProvider
+        from dao_ai.providers.databricks import DatabricksProvider
+
+        provider: ServiceProvider = DatabricksProvider(w=w)
+        optimized_prompt: PromptModel = provider.optimize_prompt(self)
+        return optimized_prompt
+
+    @model_validator(mode="after")
+    def set_defaults(self):
+        # If no prompt is specified, try to use the agent's prompt
+        if self.prompt is None:
+            if isinstance(self.agent.prompt, PromptModel):
+                self.prompt = self.agent.prompt
+            else:
+                raise ValueError(
+                    f"Prompt optimization '{self.name}' requires either an explicit prompt "
+                    f"or an agent with a prompt configured"
+                )
+
+        if self.reflection_model is None:
+            self.reflection_model = self.agent.model
+
+        if self.scorer_model is None:
+            self.scorer_model = self.agent.model
+
+        return self
+
+
+class OptimizationsModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    training_datasets: dict[str, EvaluationDatasetModel] = Field(default_factory=dict)
+    prompt_optimizations: dict[str, PromptOptimizationModel] = Field(
+        default_factory=dict
+    )
+
+    def optimize(self, w: WorkspaceClient | None = None) -> dict[str, PromptModel]:
+        """
+        Optimize all prompts in this configuration.
+
+        This method:
+        1. Ensures all training datasets are created/registered in MLflow
+        2. Runs each prompt optimization
+
+        Args:
+            w: Optional WorkspaceClient for Databricks operations
+
+        Returns:
+            dict[str, PromptModel]: Dictionary mapping optimization names to optimized prompts
+        """
+        # First, ensure all training datasets are created/registered in MLflow
+        logger.info(f"Ensuring {len(self.training_datasets)} training datasets exist")
+        for dataset_name, dataset_model in self.training_datasets.items():
+            logger.debug(f"Creating/updating dataset: {dataset_name}")
+            dataset_model.as_dataset()
+
+        # Run optimizations
+        results: dict[str, PromptModel] = {}
+        for name, optimization in self.prompt_optimizations.items():
+            results[name] = optimization.optimize(w)
+        return results
+
+
 class DatasetFormat(str, Enum):
     CSV = "csv"
     DELTA = "delta"
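The flattening performed by to_mlflow_format(), as a self-contained sketch (payload contents invented): expectation fields move to the top level next to inputs, which is what MLflow's Correctness scorer reads.

entry = {
    "inputs": {"messages": [{"role": "user", "content": "What sizes are available?"}]},
    "expectations": {"expected_facts": ["Sizes S-XL are stocked"], "expected_response": None},
}

record = {"inputs": entry["inputs"]}
# Only populated expectation fields are hoisted to the top level.
record.update({k: v for k, v in entry["expectations"].items() if v is not None})
print(record)  # inputs plus expected_facts at the top level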
@@ -1537,6 +1880,7 @@ class AppConfig(BaseModel):
     agents: dict[str, AgentModel] = Field(default_factory=dict)
     app: Optional[AppModel] = None
     evaluation: Optional[EvaluationModel] = None
+    optimizations: Optional[OptimizationsModel] = None
     datasets: Optional[list[DatasetModel]] = Field(default_factory=list)
     unity_catalog_functions: Optional[list[UnityCatalogFunctionSqlModel]] = Field(
         default_factory=list
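With the new optimizations field on AppConfig, an OptimizationsModel can also be driven directly. A hedged sketch (assumes dao-ai 0.0.27 is installed; with nothing configured, optimize() just logs and returns an empty mapping):

from dao_ai.config import OptimizationsModel

opts = OptimizationsModel()
print(opts.optimize())  # {}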
dao_ai/graph.py
CHANGED
@@ -62,11 +62,19 @@ def _handoffs_for_agent(agent: AgentModel, config: AppConfig) -> Sequence[BaseTool]:
         logger.debug(
             f"Creating handoff tool from agent {agent.name} to {handoff_to_agent.name}"
         )
+
+        # Use handoff_prompt if provided, otherwise create default description
+        handoff_description = handoff_to_agent.handoff_prompt or (
+            handoff_to_agent.description
+            if handoff_to_agent.description
+            else "general assistance and questions"
+        )
+
         handoff_tools.append(
             swarm_handoff_tool(
                 agent_name=handoff_to_agent.name,
                 description=f"Ask {handoff_to_agent.name} for help with: "
-
+                + handoff_description,
             )
         )
     return handoff_tools
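The description fallback chain is: handoff_prompt, then description, then a generic default. Reduced to a self-contained sketch:

def describe(handoff_prompt: str | None, description: str | None) -> str:
    # handoff_prompt wins; then description; then a generic default.
    return handoff_prompt or (description if description else "general assistance and questions")


print(describe(None, None))                     # general assistance and questions
print(describe(None, "returns and exchanges"))  # returns and exchanges
print(describe("escalations only", "returns"))  # escalations only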
@@ -79,13 +87,25 @@ def _create_supervisor_graph(config: AppConfig) -> CompiledStateGraph:
     for registered_agent in config.app.agents:
         agents.append(
             create_agent_node(
-
+                agent=registered_agent,
+                memory=config.app.orchestration.memory
+                if config.app.orchestration
+                else None,
+                chat_history=config.app.chat_history,
+                additional_tools=[],
             )
         )
+        # Use handoff_prompt if provided, otherwise create default description
+        handoff_description = registered_agent.handoff_prompt or (
+            registered_agent.description
+            if registered_agent.description
+            else f"General assistance with {registered_agent.name} related tasks"
+        )
+
         tools.append(
             supervisor_handoff_tool(
                 agent_name=registered_agent.name,
-                description=
+                description=handoff_description,
             )
         )
 
@@ -169,7 +189,12 @@ def _create_swarm_graph(config: AppConfig) -> CompiledStateGraph:
         )
         agents.append(
             create_agent_node(
-
+                agent=registered_agent,
+                memory=config.app.orchestration.memory
+                if config.app.orchestration
+                else None,
+                chat_history=config.app.chat_history,
+                additional_tools=handoff_tools,
             )
         )
 