dao-ai 0.0.25__py3-none-any.whl → 0.0.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/agent_as_code.py +3 -0
- dao_ai/config.py +431 -27
- dao_ai/graph.py +29 -4
- dao_ai/nodes.py +29 -20
- dao_ai/providers/databricks.py +536 -35
- dao_ai/tools/genie.py +2 -3
- dao_ai/tools/mcp.py +46 -27
- dao_ai/tools/vector_search.py +232 -22
- dao_ai/utils.py +57 -1
- {dao_ai-0.0.25.dist-info → dao_ai-0.0.28.dist-info}/METADATA +6 -3
- {dao_ai-0.0.25.dist-info → dao_ai-0.0.28.dist-info}/RECORD +14 -14
- {dao_ai-0.0.25.dist-info → dao_ai-0.0.28.dist-info}/WHEEL +1 -1
- {dao_ai-0.0.25.dist-info → dao_ai-0.0.28.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.25.dist-info → dao_ai-0.0.28.dist-info}/licenses/LICENSE +0 -0
dao_ai/agent_as_code.py
CHANGED
dao_ai/config.py
CHANGED
@@ -30,12 +30,15 @@ from databricks_langchain import (
     DatabricksFunctionClient,
 )
 from langchain_core.language_models import LanguageModelLike
+from langchain_core.messages import BaseMessage, messages_from_dict
 from langchain_core.runnables.base import RunnableLike
 from langchain_openai import ChatOpenAI
 from langgraph.checkpoint.base import BaseCheckpointSaver
 from langgraph.graph.state import CompiledStateGraph
 from langgraph.store.base import BaseStore
 from loguru import logger
+from mlflow.genai.datasets import EvaluationDataset, create_dataset, get_dataset
+from mlflow.genai.prompts import PromptVersion, load_prompt
 from mlflow.models import ModelConfig
 from mlflow.models.resources import (
     DatabricksFunction,

@@ -49,6 +52,9 @@ from mlflow.models.resources import (
     DatabricksVectorSearchIndex,
 )
 from mlflow.pyfunc import ChatModel, ResponsesAgent
+from mlflow.types.responses import (
+    ResponsesAgentRequest,
+)
 from pydantic import (
     BaseModel,
     ConfigDict,

@@ -324,6 +330,10 @@ class LLMModel(BaseModel, IsDatabricksResource):
         "serving.serving-endpoints",
     ]

+    @property
+    def uri(self) -> str:
+        return f"databricks:/{self.name}"
+
     def as_resources(self) -> Sequence[DatabricksResource]:
         return [
             DatabricksServingEndpoint(

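The new `uri` property on `LLMModel` turns an endpoint name into a `databricks:/...` model URI. A minimal sketch of what the property yields, using a hypothetical stand-in model rather than the real class (which carries more fields and Databricks resource plumbing):

```python
# Hypothetical stand-in for dao_ai.config.LLMModel; only the piece added in this diff.
from pydantic import BaseModel


class LLMModelSketch(BaseModel):
    name: str

    @property
    def uri(self) -> str:
        # Same format as the property added above.
        return f"databricks:/{self.name}"


llm = LLMModelSketch(name="databricks-meta-llama-3-3-70b-instruct")
assert llm.uri == "databricks:/databricks-meta-llama-3-3-70b-instruct"
```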
@@ -387,6 +397,13 @@ class VectorSearchEndpoint(BaseModel):
     name: str
     type: VectorSearchEndpointType = VectorSearchEndpointType.STANDARD

+    @field_serializer("type")
+    def serialize_type(self, value: VectorSearchEndpointType) -> str:
+        """Ensure enum is serialized to string value."""
+        if isinstance(value, VectorSearchEndpointType):
+            return value.value
+        return str(value)
+

 class IndexModel(BaseModel, HasFullName, IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")

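The `field_serializer` guarantees that `model_dump()` emits a plain string for the endpoint type even if an enum instance is still held on the model. A small sketch of the behavior with stand-in names (not the package's own classes):

```python
# Stand-in enum/model illustrating the serializer added above.
from enum import Enum

from pydantic import BaseModel, ConfigDict, field_serializer


class EndpointType(str, Enum):  # hypothetical stand-in for VectorSearchEndpointType
    STANDARD = "STANDARD"


class EndpointSketch(BaseModel):
    model_config = ConfigDict(use_enum_values=True, extra="forbid")
    name: str
    type: EndpointType = EndpointType.STANDARD

    @field_serializer("type")
    def serialize_type(self, value: EndpointType) -> str:
        # Always emit the string value, whether value is the enum or already a str.
        return value.value if isinstance(value, EndpointType) else str(value)


print(EndpointSketch(name="dao-vs-endpoint").model_dump())
# -> {'name': 'dao-vs-endpoint', 'type': 'STANDARD'}
```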
@@ -878,6 +895,55 @@ class SearchParametersModel(BaseModel):
     query_type: Optional[str] = "ANN"


+class RerankParametersModel(BaseModel):
+    """
+    Configuration for reranking retrieved documents using FlashRank.
+
+    FlashRank provides fast, local reranking without API calls using lightweight
+    cross-encoder models. Reranking improves retrieval quality by reordering results
+    based on semantic relevance to the query.
+
+    Typical workflow:
+    1. Retrieve more documents than needed (e.g., 50 via num_results)
+    2. Rerank all retrieved documents
+    3. Return top_n best matches (e.g., 5)
+
+    Example:
+        ```yaml
+        retriever:
+          search_parameters:
+            num_results: 50  # Retrieve more candidates
+          rerank:
+            model: ms-marco-MiniLM-L-12-v2
+            top_n: 5  # Return top 5 after reranking
+        ```
+
+    Available models (from fastest to most accurate):
+    - "ms-marco-TinyBERT-L-2-v2" (fastest, smallest)
+    - "ms-marco-MiniLM-L-6-v2"
+    - "ms-marco-MiniLM-L-12-v2" (default, good balance)
+    - "rank-T5-flan" (most accurate, slower)
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+
+    model: str = Field(
+        default="ms-marco-MiniLM-L-12-v2",
+        description="FlashRank model name. Default provides good balance of speed and accuracy.",
+    )
+    top_n: Optional[int] = Field(
+        default=None,
+        description="Number of documents to return after reranking. If None, uses search_parameters.num_results.",
+    )
+    cache_dir: Optional[str] = Field(
+        default="/tmp/flashrank_cache",
+        description="Directory to cache downloaded model weights.",
+    )
+    columns: Optional[list[str]] = Field(
+        default_factory=list, description="Columns to rerank using DatabricksReranker"
+    )
+
+
 class RetrieverModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     vector_store: VectorStoreModel

@@ -885,6 +951,10 @@ class RetrieverModel(BaseModel):
     search_parameters: SearchParametersModel = Field(
         default_factory=SearchParametersModel
     )
+    rerank: Optional[RerankParametersModel | bool] = Field(
+        default=None,
+        description="Optional reranking configuration. Set to true for defaults, or provide ReRankParametersModel for custom settings.",
+    )

     @model_validator(mode="after")
     def set_default_columns(self):

@@ -893,6 +963,13 @@ class RetrieverModel(BaseModel):
             self.columns = columns
         return self

+    @model_validator(mode="after")
+    def set_default_reranker(self):
+        """Convert bool to ReRankParametersModel with defaults."""
+        if isinstance(self.rerank, bool) and self.rerank:
+            self.rerank = RerankParametersModel()
+        return self
+

 class FunctionType(str, Enum):
     PYTHON = "python"

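With `rerank: true` (expanded to defaults by `set_default_reranker`) or a full block, the retriever can over-fetch and then rerank locally. Below is an illustrative sketch of the retrieve-then-rerank step using the FlashRank library the docstring names; the helper function, the document shape, and the retrieval call are assumptions, not the actual `dao_ai.tools.vector_search` code:

```python
# Illustrative retrieve-then-rerank flow; not the dao_ai implementation.
from flashrank import Ranker, RerankRequest


def rerank_documents(
    query: str,
    documents: list[dict],  # assumed shape: {"page_content": str, ...}
    model: str = "ms-marco-MiniLM-L-12-v2",
    cache_dir: str = "/tmp/flashrank_cache",
    top_n: int = 5,
) -> list[dict]:
    # Lightweight local cross-encoder; weights are cached under cache_dir.
    ranker = Ranker(model_name=model, cache_dir=cache_dir)
    passages = [
        {"id": i, "text": doc["page_content"], "meta": doc}
        for i, doc in enumerate(documents)
    ]
    ranked = ranker.rerank(RerankRequest(query=query, passages=passages))
    # FlashRank returns passages ordered by relevance score; keep the best top_n.
    return [item["meta"] for item in ranked[:top_n]]


# Typical use per the docstring: over-retrieve (num_results: 50), keep the top 5.
# candidates = vector_index.similarity_search(query, k=50)   # hypothetical retrieval call
# best = rerank_documents(query, candidates, top_n=5)
```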
@@ -988,22 +1065,120 @@ class TransportType(str, Enum):
 class McpFunctionModel(BaseFunctionModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     type: Literal[FunctionType.MCP] = FunctionType.MCP
-
     transport: TransportType = TransportType.STREAMABLE_HTTP
     command: Optional[str] = "python"
     url: Optional[AnyVariable] = None
-    connection: Optional[ConnectionModel] = None
     headers: dict[str, AnyVariable] = Field(default_factory=dict)
     args: list[str] = Field(default_factory=list)
     pat: Optional[AnyVariable] = None
     client_id: Optional[AnyVariable] = None
     client_secret: Optional[AnyVariable] = None
     workspace_host: Optional[AnyVariable] = None
+    connection: Optional[ConnectionModel] = None
+    functions: Optional[SchemaModel] = None
+    genie_room: Optional[GenieRoomModel] = None
+    sql: Optional[bool] = None
+    vector_search: Optional[VectorStoreModel] = None

     @property
     def full_name(self) -> str:
         return self.name

+    def _get_workspace_host(self) -> str:
+        """
+        Get the workspace host, either from config or from workspace client.
+
+        If connection is provided, uses its workspace client.
+        Otherwise, falls back to creating a new workspace client.
+
+        Returns:
+            str: The workspace host URL without trailing slash
+        """
+        from databricks.sdk import WorkspaceClient
+
+        # Try to get workspace_host from config
+        workspace_host: str | None = (
+            value_of(self.workspace_host) if self.workspace_host else None
+        )
+
+        # If no workspace_host in config, get it from workspace client
+        if not workspace_host:
+            # Use connection's workspace client if available
+            if self.connection:
+                workspace_host = self.connection.workspace_client.config.host
+            else:
+                # Create a default workspace client
+                w: WorkspaceClient = WorkspaceClient()
+                workspace_host = w.config.host
+
+        # Remove trailing slash
+        return workspace_host.rstrip("/")
+
+    @property
+    def mcp_url(self) -> str:
+        """
+        Get the MCP URL for this function.
+
+        Returns the URL based on the configured source:
+        - If url is set, returns it directly
+        - If connection is set, constructs URL from connection
+        - If genie_room is set, constructs Genie MCP URL
+        - If sql is set, constructs DBSQL MCP URL (serverless)
+        - If vector_search is set, constructs Vector Search MCP URL
+        - If functions is set, constructs UC Functions MCP URL
+
+        URL patterns (per https://docs.databricks.com/aws/en/generative-ai/mcp/managed-mcp):
+        - Genie: https://{host}/api/2.0/mcp/genie/{space_id}
+        - DBSQL: https://{host}/api/2.0/mcp/sql (serverless, workspace-level)
+        - Vector Search: https://{host}/api/2.0/mcp/vector-search/{catalog}/{schema}
+        - UC Functions: https://{host}/api/2.0/mcp/functions/{catalog}/{schema}
+        - Connection: https://{host}/api/2.0/mcp/external/{connection_name}
+        """
+        # Direct URL provided
+        if self.url:
+            return self.url
+
+        # Get workspace host (from config, connection, or default workspace client)
+        workspace_host: str = self._get_workspace_host()
+
+        # UC Connection
+        if self.connection:
+            connection_name: str = self.connection.name
+            return f"{workspace_host}/api/2.0/mcp/external/{connection_name}"
+
+        # Genie Room
+        if self.genie_room:
+            space_id: str = value_of(self.genie_room.space_id)
+            return f"{workspace_host}/api/2.0/mcp/genie/{space_id}"
+
+        # DBSQL MCP server (serverless, workspace-level)
+        if self.sql:
+            return f"{workspace_host}/api/2.0/mcp/sql"
+
+        # Vector Search
+        if self.vector_search:
+            if (
+                not self.vector_search.index
+                or not self.vector_search.index.schema_model
+            ):
+                raise ValueError(
+                    "vector_search must have an index with a schema (catalog/schema) configured"
+                )
+            catalog: str = self.vector_search.index.schema_model.catalog_name
+            schema: str = self.vector_search.index.schema_model.schema_name
+            return f"{workspace_host}/api/2.0/mcp/vector-search/{catalog}/{schema}"
+
+        # UC Functions MCP server
+        if self.functions:
+            catalog: str = self.functions.catalog_name
+            schema: str = self.functions.schema_name
+            return f"{workspace_host}/api/2.0/mcp/functions/{catalog}/{schema}"
+
+        raise ValueError(
+            "No URL source configured. Provide one of: url, connection, genie_room, "
+            "sql, vector_search, or functions"
+        )
+
     @field_serializer("transport")
     def serialize_transport(self, value) -> str:
         if isinstance(value, TransportType):

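Each of the new source fields maps to one managed MCP endpoint, and `mcp_url` picks whichever one is configured (after `validate_mutually_exclusive` below has ensured there is exactly one). The sketch spells out the resulting URL shapes with made-up workspace, space, catalog, and schema values:

```python
# Illustrative URL shapes produced by mcp_url; all identifiers here are hypothetical.
host = "https://my-workspace.cloud.databricks.com"  # trailing slash already stripped

urls = {
    "genie_room":    f"{host}/api/2.0/mcp/genie/01ef0a1b2c3d4e5f",
    "sql":           f"{host}/api/2.0/mcp/sql",                        # serverless, workspace-level
    "vector_search": f"{host}/api/2.0/mcp/vector-search/main/retail",  # {catalog}/{schema}
    "functions":     f"{host}/api/2.0/mcp/functions/main/retail",      # {catalog}/{schema}
    "connection":    f"{host}/api/2.0/mcp/external/my_connection",
}

for source, url in urls.items():
    print(f"{source:14s} -> {url}")
```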
@@ -1011,32 +1186,56 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
         return str(value)

     @model_validator(mode="after")
-    def validate_mutually_exclusive(self):
-
-
-
-
-        )
-
-
-
-
+    def validate_mutually_exclusive(self) -> "McpFunctionModel":
+        """Validate that exactly one URL source is provided."""
+        # Count how many URL sources are provided
+        url_sources: list[tuple[str, Any]] = [
+            ("url", self.url),
+            ("connection", self.connection),
+            ("genie_room", self.genie_room),
+            ("sql", self.sql),
+            ("vector_search", self.vector_search),
+            ("functions", self.functions),
+        ]
+
+        provided_sources: list[str] = [
+            name for name, value in url_sources if value is not None
+        ]
+
+        if self.transport == TransportType.STREAMABLE_HTTP:
+            if len(provided_sources) == 0:
+                raise ValueError(
+                    "For STREAMABLE_HTTP transport, exactly one of the following must be provided: "
+                    "url, connection, genie_room, sql, vector_search, or functions"
+                )
+            if len(provided_sources) > 1:
+                raise ValueError(
+                    f"For STREAMABLE_HTTP transport, only one URL source can be provided. "
+                    f"Found: {', '.join(provided_sources)}. "
+                    f"Please provide only one of: url, connection, genie_room, sql, vector_search, or functions"
+                )
+
+        if self.transport == TransportType.STDIO:
+            if not self.command:
+                raise ValueError("command must be provided for STDIO transport")
+            if not self.args:
+                raise ValueError("args must be provided for STDIO transport")
+
         return self

     @model_validator(mode="after")
-    def update_url(self):
+    def update_url(self) -> "McpFunctionModel":
         self.url = value_of(self.url)
         return self

     @model_validator(mode="after")
-    def update_headers(self):
+    def update_headers(self) -> "McpFunctionModel":
         for key, value in self.headers.items():
             self.headers[key] = value_of(value)
         return self

     @model_validator(mode="after")
-    def validate_auth_methods(self):
+    def validate_auth_methods(self) -> "McpFunctionModel":
         oauth_fields: Sequence[Any] = [
             self.client_id,
             self.client_secret,

@@ -1052,10 +1251,7 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
                 "Please provide either OAuth credentials or user credentials."
             )

-
-        raise ValueError(
-            "Workspace host must be provided when using OAuth or user credentials."
-        )
+        # Note: workspace_host is optional - it will be derived from workspace client if not provided

         return self

@@ -1181,17 +1377,32 @@ class PromptModel(BaseModel, HasFullName):
         from dao_ai.providers.databricks import DatabricksProvider

         provider: DatabricksProvider = DatabricksProvider()
-
-        return
+        prompt_version = provider.get_prompt(self)
+        return prompt_version.to_single_brace_format()

     @property
     def full_name(self) -> str:
+        prompt_name: str = self.name
         if self.schema_model:
-
-
-
-
-
+            prompt_name = f"{self.schema_model.full_name}.{prompt_name}"
+        return prompt_name
+
+    @property
+    def uri(self) -> str:
+        prompt_uri: str = f"prompts:/{self.full_name}"
+
+        if self.alias:
+            prompt_uri = f"prompts:/{self.full_name}@{self.alias}"
+        elif self.version:
+            prompt_uri = f"prompts:/{self.full_name}/{self.version}"
+        else:
+            prompt_uri = f"prompts:/{self.full_name}@latest"
+
+        return prompt_uri
+
+    def as_prompt(self) -> PromptVersion:
+        prompt_version: PromptVersion = load_prompt(self.uri)
+        return prompt_version

     @model_validator(mode="after")
     def validate_mutually_exclusive(self):

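`PromptModel.uri` resolves to an MLflow prompt-registry URI: an alias wins over a pinned version, and `@latest` is the fallback; `as_prompt()` then loads that URI with `mlflow.genai.prompts.load_prompt`, and the prompt text comes back via `to_single_brace_format()` (both calls appear in the diff). A sketch with a hypothetical Unity Catalog prompt name:

```python
# Hypothetical prompt coordinates; only the URI shapes and the load call are shown.
full_name = "main.retail.support_agent_prompt"

uri_by_alias = f"prompts:/{full_name}@production"  # alias set
uri_by_version = f"prompts:/{full_name}/3"         # no alias, version set
uri_latest = f"prompts:/{full_name}@latest"        # neither set

# Loading the resolved prompt (registry access required):
# from mlflow.genai.prompts import load_prompt
# prompt_version = load_prompt(uri_by_alias)
# system_prompt = prompt_version.to_single_brace_format()
```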
@@ -1213,6 +1424,17 @@ class AgentModel(BaseModel):
     pre_agent_hook: Optional[FunctionHook] = None
     post_agent_hook: Optional[FunctionHook] = None

+    def as_runnable(self) -> RunnableLike:
+        from dao_ai.nodes import create_agent_node
+
+        return create_agent_node(self)
+
+    def as_responses_agent(self) -> ResponsesAgent:
+        from dao_ai.models import create_responses_agent
+
+        graph: CompiledStateGraph = self.as_runnable()
+        return create_responses_agent(graph)
+

 class SupervisorModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")

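`AgentModel.as_runnable()` builds the agent's LangGraph node and `as_responses_agent()` wraps that graph for MLflow's Responses API. A hedged usage sketch; the helper and the agent key are illustrative, not part of the package:

```python
# Hypothetical usage of the two new AgentModel methods.
from mlflow.pyfunc import ResponsesAgent

from dao_ai.config import AgentModel


def serve_agent(agent_model: AgentModel) -> ResponsesAgent:
    """Build the MLflow-servable wrapper for a configured agent."""
    # as_runnable() compiles the LangGraph node via dao_ai.nodes.create_agent_node;
    # as_responses_agent() wraps that graph with dao_ai.models.create_responses_agent.
    return agent_model.as_responses_agent()


# e.g. responses_agent = serve_agent(config.agents["support"])  # "support" is a made-up key
```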
@@ -1330,6 +1552,19 @@ class ChatPayload(BaseModel):

         return self

+    def as_messages(self) -> Sequence[BaseMessage]:
+        return messages_from_dict(
+            [{"type": m.role, "content": m.content} for m in self.messages]
+        )
+
+    def as_agent_request(self) -> ResponsesAgentRequest:
+        from mlflow.types.responses_helpers import Message as _Message
+
+        return ResponsesAgentRequest(
+            input=[_Message(role=m.role, content=m.content) for m in self.messages],
+            custom_inputs=self.custom_inputs,
+        )
+

 class ChatHistoryModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")

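`ChatPayload` now converts itself either into LangChain messages (`as_messages()`) or into an MLflow `ResponsesAgentRequest` (`as_agent_request()`). A hedged usage sketch, assuming the payload's `messages` field accepts OpenAI-style role/content dicts and that `custom_inputs` is optional:

```python
# Hypothetical payload; coercion from plain dicts is an assumption about ChatPayload.
from dao_ai.config import ChatPayload

payload = ChatPayload(
    messages=[{"role": "user", "content": "What is the store's return policy?"}]
)

lc_messages = payload.as_messages()      # LangChain BaseMessage objects
request = payload.as_agent_request()     # mlflow.types.responses.ResponsesAgentRequest

# The first form feeds the LangGraph agents directly; the second matches the
# ResponsesAgent serving signature.
```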
@@ -1459,6 +1694,174 @@ class EvaluationModel(BaseModel):
     guidelines: list[GuidelineModel] = Field(default_factory=list)


+class EvaluationDatasetExpectationsModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    expected_response: Optional[str] = None
+    expected_facts: Optional[list[str]] = None
+
+    @model_validator(mode="after")
+    def validate_mutually_exclusive(self):
+        if self.expected_response is not None and self.expected_facts is not None:
+            raise ValueError("Cannot specify both expected_response and expected_facts")
+        return self
+
+
+class EvaluationDatasetEntryModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    inputs: ChatPayload
+    expectations: EvaluationDatasetExpectationsModel
+
+    def to_mlflow_format(self) -> dict[str, Any]:
+        """
+        Convert to MLflow evaluation dataset format.
+
+        Flattens the expectations fields to the top level alongside inputs,
+        which is the format expected by MLflow's Correctness scorer.
+
+        Returns:
+            dict: Flattened dictionary with inputs and expectation fields at top level
+        """
+        result: dict[str, Any] = {"inputs": self.inputs.model_dump()}
+
+        # Flatten expectations to top level for MLflow compatibility
+        if self.expectations.expected_response is not None:
+            result["expected_response"] = self.expectations.expected_response
+        if self.expectations.expected_facts is not None:
+            result["expected_facts"] = self.expectations.expected_facts
+
+        return result
+
+
+class EvaluationDatasetModel(BaseModel, HasFullName):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
+    name: str
+    data: Optional[list[EvaluationDatasetEntryModel]] = Field(default_factory=list)
+    overwrite: Optional[bool] = False
+
+    def as_dataset(self, w: WorkspaceClient | None = None) -> EvaluationDataset:
+        evaluation_dataset: EvaluationDataset
+        needs_creation: bool = False
+
+        try:
+            evaluation_dataset = get_dataset(name=self.full_name)
+            if self.overwrite:
+                logger.warning(f"Overwriting dataset {self.full_name}")
+                workspace_client: WorkspaceClient = w if w else WorkspaceClient()
+                logger.debug(f"Dropping table: {self.full_name}")
+                workspace_client.tables.delete(full_name=self.full_name)
+                needs_creation = True
+        except Exception:
+            logger.warning(
+                f"Dataset {self.full_name} not found, will create new dataset"
+            )
+            needs_creation = True
+
+        # Create dataset if needed (either new or after overwrite)
+        if needs_creation:
+            evaluation_dataset = create_dataset(name=self.full_name)
+            if self.data:
+                logger.debug(
+                    f"Merging {len(self.data)} entries into dataset {self.full_name}"
+                )
+                # Use to_mlflow_format() to flatten expectations for MLflow compatibility
+                evaluation_dataset.merge_records(
+                    [e.to_mlflow_format() for e in self.data]
+                )
+
+        return evaluation_dataset
+
+    @property
+    def full_name(self) -> str:
+        if self.schema_model:
+            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}.{self.name}"
+        return self.name
+
+
+class PromptOptimizationModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    prompt: Optional[PromptModel] = None
+    agent: AgentModel
+    dataset: (
+        EvaluationDatasetModel | str
+    )  # Reference to dataset name (looked up in OptimizationsModel.training_datasets or MLflow)
+    reflection_model: Optional[LLMModel | str] = None
+    num_candidates: Optional[int] = 50
+    scorer_model: Optional[LLMModel | str] = None
+
+    def optimize(self, w: WorkspaceClient | None = None) -> PromptModel:
+        """
+        Optimize the prompt using MLflow's prompt optimization.
+
+        Args:
+            w: Optional WorkspaceClient for Databricks operations
+
+        Returns:
+            PromptModel: The optimized prompt model with new URI
+        """
+        from dao_ai.providers.base import ServiceProvider
+        from dao_ai.providers.databricks import DatabricksProvider
+
+        provider: ServiceProvider = DatabricksProvider(w=w)
+        optimized_prompt: PromptModel = provider.optimize_prompt(self)
+        return optimized_prompt
+
+    @model_validator(mode="after")
+    def set_defaults(self):
+        # If no prompt is specified, try to use the agent's prompt
+        if self.prompt is None:
+            if isinstance(self.agent.prompt, PromptModel):
+                self.prompt = self.agent.prompt
+            else:
+                raise ValueError(
+                    f"Prompt optimization '{self.name}' requires either an explicit prompt "
+                    f"or an agent with a prompt configured"
+                )
+
+        if self.reflection_model is None:
+            self.reflection_model = self.agent.model
+
+        if self.scorer_model is None:
+            self.scorer_model = self.agent.model
+
+        return self
+
+
+class OptimizationsModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    training_datasets: dict[str, EvaluationDatasetModel] = Field(default_factory=dict)
+    prompt_optimizations: dict[str, PromptOptimizationModel] = Field(
+        default_factory=dict
+    )
+
+    def optimize(self, w: WorkspaceClient | None = None) -> dict[str, PromptModel]:
+        """
+        Optimize all prompts in this configuration.
+
+        This method:
+        1. Ensures all training datasets are created/registered in MLflow
+        2. Runs each prompt optimization
+
+        Args:
+            w: Optional WorkspaceClient for Databricks operations
+
+        Returns:
+            dict[str, PromptModel]: Dictionary mapping optimization names to optimized prompts
+        """
+        # First, ensure all training datasets are created/registered in MLflow
+        logger.info(f"Ensuring {len(self.training_datasets)} training datasets exist")
+        for dataset_name, dataset_model in self.training_datasets.items():
+            logger.debug(f"Creating/updating dataset: {dataset_name}")
+            dataset_model.as_dataset()
+
+        # Run optimizations
+        results: dict[str, PromptModel] = {}
+        for name, optimization in self.prompt_optimizations.items():
+            results[name] = optimization.optimize(w)
+        return results
+
+
 class DatasetFormat(str, Enum):
     CSV = "csv"
     DELTA = "delta"

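These models let training data and prompt optimizations live in the same config: dataset entries pair a `ChatPayload` with expectations, `as_dataset()` registers them as an MLflow `EvaluationDataset`, and `OptimizationsModel.optimize()` runs every configured `PromptOptimizationModel`. A hedged sketch of the dataset half (the names and the record are made up; the dict-to-`ChatPayload` coercion is an assumption):

```python
# Hypothetical dataset definition using the new config models.
from dao_ai.config import (
    EvaluationDatasetEntryModel,
    EvaluationDatasetExpectationsModel,
    EvaluationDatasetModel,
)

entry = EvaluationDatasetEntryModel(
    inputs={"messages": [{"role": "user", "content": "Do you ship to Canada?"}]},
    expectations=EvaluationDatasetExpectationsModel(
        expected_facts=["shipping to Canada is available"]
    ),
)

# Expectations are flattened next to inputs, the layout MLflow's Correctness scorer expects.
record = entry.to_mlflow_format()
# -> {"inputs": {...}, "expected_facts": ["shipping to Canada is available"]}

dataset = EvaluationDatasetModel(name="support_eval", data=[entry], overwrite=False)
# dataset.as_dataset()  # creates or merges into the MLflow evaluation dataset
```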
@@ -1537,6 +1940,7 @@ class AppConfig(BaseModel):
     agents: dict[str, AgentModel] = Field(default_factory=dict)
     app: Optional[AppModel] = None
     evaluation: Optional[EvaluationModel] = None
+    optimizations: Optional[OptimizationsModel] = None
     datasets: Optional[list[DatasetModel]] = Field(default_factory=list)
     unity_catalog_functions: Optional[list[UnityCatalogFunctionSqlModel]] = Field(
         default_factory=list

dao_ai/graph.py
CHANGED
@@ -62,11 +62,19 @@ def _handoffs_for_agent(agent: AgentModel, config: AppConfig) -> Sequence[BaseTo
         logger.debug(
             f"Creating handoff tool from agent {agent.name} to {handoff_to_agent.name}"
         )
+
+        # Use handoff_prompt if provided, otherwise create default description
+        handoff_description = handoff_to_agent.handoff_prompt or (
+            handoff_to_agent.description
+            if handoff_to_agent.description
+            else "general assistance and questions"
+        )
+
         handoff_tools.append(
             swarm_handoff_tool(
                 agent_name=handoff_to_agent.name,
                 description=f"Ask {handoff_to_agent.name} for help with: "
-                +
+                + handoff_description,
             )
         )
     return handoff_tools

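Handoff tools in both the swarm and supervisor builders now describe themselves with the target agent's `handoff_prompt` when present, falling back to its `description`, then to a generic phrase. A tiny stand-alone sketch of that precedence:

```python
# Stand-in for the fallback chain used when describing handoff tools.
def handoff_description(handoff_prompt: str | None, description: str | None) -> str:
    return handoff_prompt or (
        description if description else "general assistance and questions"
    )


assert handoff_description("Escalate billing disputes", None) == "Escalate billing disputes"
assert handoff_description(None, "Order status lookups") == "Order status lookups"
assert handoff_description(None, None) == "general assistance and questions"
```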
@@ -79,13 +87,25 @@ def _create_supervisor_graph(config: AppConfig) -> CompiledStateGraph:
     for registered_agent in config.app.agents:
         agents.append(
             create_agent_node(
-
+                agent=registered_agent,
+                memory=config.app.orchestration.memory
+                if config.app.orchestration
+                else None,
+                chat_history=config.app.chat_history,
+                additional_tools=[],
             )
         )
+        # Use handoff_prompt if provided, otherwise create default description
+        handoff_description = registered_agent.handoff_prompt or (
+            registered_agent.description
+            if registered_agent.description
+            else f"General assistance with {registered_agent.name} related tasks"
+        )
+
         tools.append(
             supervisor_handoff_tool(
                 agent_name=registered_agent.name,
-                description=
+                description=handoff_description,
             )
         )

@@ -169,7 +189,12 @@ def _create_swarm_graph(config: AppConfig) -> CompiledStateGraph:
         )
         agents.append(
             create_agent_node(
-
+                agent=registered_agent,
+                memory=config.app.orchestration.memory
+                if config.app.orchestration
+                else None,
+                chat_history=config.app.chat_history,
+                additional_tools=handoff_tools,
             )
         )
