dao-ai 0.0.25__py3-none-any.whl → 0.0.28__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between versions exactly as they appear in the public registry.
dao_ai/agent_as_code.py CHANGED
@@ -7,6 +7,9 @@ from mlflow.pyfunc import ResponsesAgent
 
 from dao_ai.config import AppConfig
 
+mlflow.set_registry_uri("databricks-uc")
+mlflow.set_tracking_uri("databricks")
+
 mlflow.langchain.autolog()
 
 model_config: ModelConfig = ModelConfig()
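The two new calls pin MLflow to Unity Catalog for the model registry and to the Databricks workspace for tracking, before autologging is enabled. A minimal sketch of what that implies downstream (the run URI and UC model name are hypothetical):

```python
import mlflow

# Same calls the module now makes at import time: models register to
# Unity Catalog, runs and traces go to the Databricks workspace.
mlflow.set_registry_uri("databricks-uc")
mlflow.set_tracking_uri("databricks")

# With "databricks-uc", registered model names must be three-level
# Unity Catalog names; the run URI and UC name below are placeholders.
mlflow.register_model(
    model_uri="runs:/<run_id>/model",
    name="main.agents.dao_ai_agent",
)
```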
dao_ai/config.py CHANGED
@@ -30,12 +30,15 @@ from databricks_langchain import (
     DatabricksFunctionClient,
 )
 from langchain_core.language_models import LanguageModelLike
+from langchain_core.messages import BaseMessage, messages_from_dict
 from langchain_core.runnables.base import RunnableLike
 from langchain_openai import ChatOpenAI
 from langgraph.checkpoint.base import BaseCheckpointSaver
 from langgraph.graph.state import CompiledStateGraph
 from langgraph.store.base import BaseStore
 from loguru import logger
+from mlflow.genai.datasets import EvaluationDataset, create_dataset, get_dataset
+from mlflow.genai.prompts import PromptVersion, load_prompt
 from mlflow.models import ModelConfig
 from mlflow.models.resources import (
     DatabricksFunction,
@@ -49,6 +52,9 @@ from mlflow.models.resources import (
     DatabricksVectorSearchIndex,
 )
 from mlflow.pyfunc import ChatModel, ResponsesAgent
+from mlflow.types.responses import (
+    ResponsesAgentRequest,
+)
 from pydantic import (
     BaseModel,
     ConfigDict,
@@ -324,6 +330,10 @@ class LLMModel(BaseModel, IsDatabricksResource):
         "serving.serving-endpoints",
     ]
 
+    @property
+    def uri(self) -> str:
+        return f"databricks:/{self.name}"
+
     def as_resources(self) -> Sequence[DatabricksResource]:
         return [
             DatabricksServingEndpoint(
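The new `uri` property renders the serving endpoint in the `databricks:/<endpoint>` scheme that MLflow's GenAI APIs accept for judge and scorer models. A quick sketch (endpoint name illustrative, and assuming `name` alone is enough to construct the model):

```python
# Hypothetical endpoint name; any serving endpoint works the same way.
llm = LLMModel(name="databricks-meta-llama-3-3-70b-instruct")
assert llm.uri == "databricks:/databricks-meta-llama-3-3-70b-instruct"
```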
@@ -387,6 +397,13 @@ class VectorSearchEndpoint(BaseModel):
     name: str
     type: VectorSearchEndpointType = VectorSearchEndpointType.STANDARD
 
+    @field_serializer("type")
+    def serialize_type(self, value: VectorSearchEndpointType) -> str:
+        """Ensure enum is serialized to string value."""
+        if isinstance(value, VectorSearchEndpointType):
+            return value.value
+        return str(value)
+
 
 class IndexModel(BaseModel, HasFullName, IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
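The `field_serializer` guarantees the enum leaves `model_dump()` as a plain string rather than an enum member, which keeps YAML/JSON dumps round-trippable. The expected behavior, sketched:

```python
endpoint = VectorSearchEndpoint(name="my-endpoint")
dumped = endpoint.model_dump()
# The serializer collapses the enum member to its plain string value:
assert dumped["type"] == VectorSearchEndpointType.STANDARD.value
```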
@@ -878,6 +895,55 @@ class SearchParametersModel(BaseModel):
     query_type: Optional[str] = "ANN"
 
 
+class RerankParametersModel(BaseModel):
+    """
+    Configuration for reranking retrieved documents using FlashRank.
+
+    FlashRank provides fast, local reranking without API calls using lightweight
+    cross-encoder models. Reranking improves retrieval quality by reordering results
+    based on semantic relevance to the query.
+
+    Typical workflow:
+    1. Retrieve more documents than needed (e.g., 50 via num_results)
+    2. Rerank all retrieved documents
+    3. Return top_n best matches (e.g., 5)
+
+    Example:
+        ```yaml
+        retriever:
+          search_parameters:
+            num_results: 50  # Retrieve more candidates
+          rerank:
+            model: ms-marco-MiniLM-L-12-v2
+            top_n: 5  # Return top 5 after reranking
+        ```
+
+    Available models (from fastest to most accurate):
+    - "ms-marco-TinyBERT-L-2-v2" (fastest, smallest)
+    - "ms-marco-MiniLM-L-6-v2"
+    - "ms-marco-MiniLM-L-12-v2" (default, good balance)
+    - "rank-T5-flan" (most accurate, slower)
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+
+    model: str = Field(
+        default="ms-marco-MiniLM-L-12-v2",
+        description="FlashRank model name. Default provides good balance of speed and accuracy.",
+    )
+    top_n: Optional[int] = Field(
+        default=None,
+        description="Number of documents to return after reranking. If None, uses search_parameters.num_results.",
+    )
+    cache_dir: Optional[str] = Field(
+        default="/tmp/flashrank_cache",
+        description="Directory to cache downloaded model weights.",
+    )
+    columns: Optional[list[str]] = Field(
+        default_factory=list, description="Columns to rerank using DatabricksReranker"
+    )
+
+
 class RetrieverModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     vector_store: VectorStoreModel
@@ -885,6 +951,10 @@ class RetrieverModel(BaseModel):
     search_parameters: SearchParametersModel = Field(
         default_factory=SearchParametersModel
     )
+    rerank: Optional[RerankParametersModel | bool] = Field(
+        default=None,
+        description="Optional reranking configuration. Set to true for defaults, or provide RerankParametersModel for custom settings.",
+    )
 
     @model_validator(mode="after")
     def set_default_columns(self):
@@ -893,6 +963,13 @@ class RetrieverModel(BaseModel):
         self.columns = columns
         return self
 
+    @model_validator(mode="after")
+    def set_default_reranker(self):
+        """Convert bool to RerankParametersModel with defaults."""
+        if isinstance(self.rerank, bool) and self.rerank:
+            self.rerank = RerankParametersModel()
+        return self
+
 
 class FunctionType(str, Enum):
     PYTHON = "python"
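The `RerankParametersModel | bool` union gives a shorthand and an explicit form; the validator converts the former into the latter. A sketch of both (`my_vector_store` stands in for a configured `VectorStoreModel`):

```python
# Shorthand: `rerank: true` in YAML (rerank=True here) is normalized
# by set_default_reranker into a RerankParametersModel with defaults.
retriever = RetrieverModel(vector_store=my_vector_store, rerank=True)
assert isinstance(retriever.rerank, RerankParametersModel)
assert retriever.rerank.model == "ms-marco-MiniLM-L-12-v2"

# Explicit form: trade speed for accuracy and cap the final result count.
retriever = RetrieverModel(
    vector_store=my_vector_store,
    rerank=RerankParametersModel(model="rank-T5-flan", top_n=5),
)
```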
@@ -988,22 +1065,120 @@ class TransportType(str, Enum):
 class McpFunctionModel(BaseFunctionModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     type: Literal[FunctionType.MCP] = FunctionType.MCP
-
     transport: TransportType = TransportType.STREAMABLE_HTTP
     command: Optional[str] = "python"
     url: Optional[AnyVariable] = None
-    connection: Optional[ConnectionModel] = None
     headers: dict[str, AnyVariable] = Field(default_factory=dict)
     args: list[str] = Field(default_factory=list)
     pat: Optional[AnyVariable] = None
     client_id: Optional[AnyVariable] = None
     client_secret: Optional[AnyVariable] = None
     workspace_host: Optional[AnyVariable] = None
+    connection: Optional[ConnectionModel] = None
+    functions: Optional[SchemaModel] = None
+    genie_room: Optional[GenieRoomModel] = None
+    sql: Optional[bool] = None
+    vector_search: Optional[VectorStoreModel] = None
 
     @property
     def full_name(self) -> str:
         return self.name
 
+    def _get_workspace_host(self) -> str:
+        """
+        Get the workspace host, either from config or from workspace client.
+
+        If connection is provided, uses its workspace client.
+        Otherwise, falls back to creating a new workspace client.
+
+        Returns:
+            str: The workspace host URL without trailing slash
+        """
+        from databricks.sdk import WorkspaceClient
+
+        # Try to get workspace_host from config
+        workspace_host: str | None = (
+            value_of(self.workspace_host) if self.workspace_host else None
+        )
+
+        # If no workspace_host in config, get it from workspace client
+        if not workspace_host:
+            # Use connection's workspace client if available
+            if self.connection:
+                workspace_host = self.connection.workspace_client.config.host
+            else:
+                # Create a default workspace client
+                w: WorkspaceClient = WorkspaceClient()
+                workspace_host = w.config.host
+
+        # Remove trailing slash
+        return workspace_host.rstrip("/")
+
+    @property
+    def mcp_url(self) -> str:
+        """
+        Get the MCP URL for this function.
+
+        Returns the URL based on the configured source:
+        - If url is set, returns it directly
+        - If connection is set, constructs URL from connection
+        - If genie_room is set, constructs Genie MCP URL
+        - If sql is set, constructs DBSQL MCP URL (serverless)
+        - If vector_search is set, constructs Vector Search MCP URL
+        - If functions is set, constructs UC Functions MCP URL
+
+        URL patterns (per https://docs.databricks.com/aws/en/generative-ai/mcp/managed-mcp):
+        - Genie: https://{host}/api/2.0/mcp/genie/{space_id}
+        - DBSQL: https://{host}/api/2.0/mcp/sql (serverless, workspace-level)
+        - Vector Search: https://{host}/api/2.0/mcp/vector-search/{catalog}/{schema}
+        - UC Functions: https://{host}/api/2.0/mcp/functions/{catalog}/{schema}
+        - Connection: https://{host}/api/2.0/mcp/external/{connection_name}
+        """
+        # Direct URL provided
+        if self.url:
+            return self.url
+
+        # Get workspace host (from config, connection, or default workspace client)
+        workspace_host: str = self._get_workspace_host()
+
+        # UC Connection
+        if self.connection:
+            connection_name: str = self.connection.name
+            return f"{workspace_host}/api/2.0/mcp/external/{connection_name}"
+
+        # Genie Room
+        if self.genie_room:
+            space_id: str = value_of(self.genie_room.space_id)
+            return f"{workspace_host}/api/2.0/mcp/genie/{space_id}"
+
+        # DBSQL MCP server (serverless, workspace-level)
+        if self.sql:
+            return f"{workspace_host}/api/2.0/mcp/sql"
+
+        # Vector Search
+        if self.vector_search:
+            if (
+                not self.vector_search.index
+                or not self.vector_search.index.schema_model
+            ):
+                raise ValueError(
+                    "vector_search must have an index with a schema (catalog/schema) configured"
+                )
+            catalog: str = self.vector_search.index.schema_model.catalog_name
+            schema: str = self.vector_search.index.schema_model.schema_name
+            return f"{workspace_host}/api/2.0/mcp/vector-search/{catalog}/{schema}"
+
+        # UC Functions MCP server
+        if self.functions:
+            catalog: str = self.functions.catalog_name
+            schema: str = self.functions.schema_name
+            return f"{workspace_host}/api/2.0/mcp/functions/{catalog}/{schema}"
+
+        raise ValueError(
+            "No URL source configured. Provide one of: url, connection, genie_room, "
+            "sql, vector_search, or functions"
+        )
+
     @field_serializer("transport")
     def serialize_transport(self, value) -> str:
         if isinstance(value, TransportType):
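Each configured source maps mechanically onto one of the documented URL patterns. A sketch of the Genie case (space ID and host are hypothetical, and `name` is assumed to satisfy `BaseFunctionModel` on its own):

```python
fn = McpFunctionModel(
    name="genie_mcp",
    genie_room=GenieRoomModel(space_id="01ef1234abcd"),
)
# For a workspace host of https://example.cloud.databricks.com, mcp_url
# follows the Genie pattern documented above:
#   https://example.cloud.databricks.com/api/2.0/mcp/genie/01ef1234abcd
print(fn.mcp_url)
```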
@@ -1011,32 +1186,56 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
         return str(value)
 
     @model_validator(mode="after")
-    def validate_mutually_exclusive(self):
-        if self.transport == TransportType.STREAMABLE_HTTP and not (
-            self.url or self.connection
-        ):
-            raise ValueError(
-                "url or connection must be provided for STREAMABLE_HTTP transport"
-            )
-        if self.transport == TransportType.STDIO and not self.command:
-            raise ValueError("command must not be provided for STDIO transport")
-        if self.transport == TransportType.STDIO and not self.args:
-            raise ValueError("args must not be provided for STDIO transport")
+    def validate_mutually_exclusive(self) -> "McpFunctionModel":
+        """Validate that exactly one URL source is provided."""
+        # Count how many URL sources are provided
+        url_sources: list[tuple[str, Any]] = [
+            ("url", self.url),
+            ("connection", self.connection),
+            ("genie_room", self.genie_room),
+            ("sql", self.sql),
+            ("vector_search", self.vector_search),
+            ("functions", self.functions),
+        ]
+
+        provided_sources: list[str] = [
+            name for name, value in url_sources if value is not None
+        ]
+
+        if self.transport == TransportType.STREAMABLE_HTTP:
+            if len(provided_sources) == 0:
+                raise ValueError(
+                    "For STREAMABLE_HTTP transport, exactly one of the following must be provided: "
+                    "url, connection, genie_room, sql, vector_search, or functions"
+                )
+            if len(provided_sources) > 1:
+                raise ValueError(
+                    f"For STREAMABLE_HTTP transport, only one URL source can be provided. "
+                    f"Found: {', '.join(provided_sources)}. "
+                    f"Please provide only one of: url, connection, genie_room, sql, vector_search, or functions"
+                )
+
+        if self.transport == TransportType.STDIO:
+            if not self.command:
+                raise ValueError("command must be provided for STDIO transport")
+            if not self.args:
+                raise ValueError("args must be provided for STDIO transport")
+
         return self
 
     @model_validator(mode="after")
-    def update_url(self):
+    def update_url(self) -> "McpFunctionModel":
         self.url = value_of(self.url)
         return self
 
     @model_validator(mode="after")
-    def update_headers(self):
+    def update_headers(self) -> "McpFunctionModel":
         for key, value in self.headers.items():
             self.headers[key] = value_of(value)
         return self
 
     @model_validator(mode="after")
-    def validate_auth_methods(self):
+    def validate_auth_methods(self) -> "McpFunctionModel":
         oauth_fields: Sequence[Any] = [
             self.client_id,
             self.client_secret,
@@ -1052,10 +1251,7 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
                 "Please provide either OAuth credentials or user credentials."
             )
 
-        if (has_oauth or has_user_auth) and not self.workspace_host:
-            raise ValueError(
-                "Workspace host must be provided when using OAuth or user credentials."
-            )
+        # Note: workspace_host is optional - it will be derived from workspace client if not provided
 
         return self
 
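Where the old validator only checked `url`/`connection`, the new one enforces exactly one source among all six. A sketch of the contract, written as a test (assuming `name` is the only required inherited field):

```python
import pytest

# One source: passes validation (serverless DBSQL MCP server).
McpFunctionModel(name="dbsql", sql=True)

# Two sources: the validator now rejects the ambiguity outright.
with pytest.raises(ValueError, match="only one URL source"):
    McpFunctionModel(
        name="ambiguous",
        sql=True,
        genie_room=GenieRoomModel(space_id="01ef1234abcd"),
    )
```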
@@ -1181,17 +1377,32 @@ class PromptModel(BaseModel, HasFullName):
         from dao_ai.providers.databricks import DatabricksProvider
 
         provider: DatabricksProvider = DatabricksProvider()
-        prompt: str = provider.get_prompt(self)
-        return prompt
+        prompt_version = provider.get_prompt(self)
+        return prompt_version.to_single_brace_format()
 
     @property
     def full_name(self) -> str:
+        prompt_name: str = self.name
         if self.schema_model:
-            name: str = ""
-            if self.name:
-                name = f".{self.name}"
-            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}{name}"
-        return self.name
+            prompt_name = f"{self.schema_model.full_name}.{prompt_name}"
+        return prompt_name
+
+    @property
+    def uri(self) -> str:
+        prompt_uri: str = f"prompts:/{self.full_name}"
+
+        if self.alias:
+            prompt_uri = f"prompts:/{self.full_name}@{self.alias}"
+        elif self.version:
+            prompt_uri = f"prompts:/{self.full_name}/{self.version}"
+        else:
+            prompt_uri = f"prompts:/{self.full_name}@latest"
+
+        return prompt_uri
+
+    def as_prompt(self) -> PromptVersion:
+        prompt_version: PromptVersion = load_prompt(self.uri)
+        return prompt_version
 
     @model_validator(mode="after")
     def validate_mutually_exclusive(self):
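The `uri` property picks one of three MLflow prompt URI forms: alias wins, then pinned version, then the `@latest` alias. A sketch (assuming the schema field is populated via a `schema` alias as elsewhere in this module, with `agents_schema` pointing at `main.agents`):

```python
base = dict(name="greeting", schema=agents_schema)  # hypothetical schema

assert PromptModel(**base, alias="production").uri == "prompts:/main.agents.greeting@production"
assert PromptModel(**base, version=3).uri == "prompts:/main.agents.greeting/3"
assert PromptModel(**base).uri == "prompts:/main.agents.greeting@latest"
```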
@@ -1213,6 +1424,17 @@ class AgentModel(BaseModel):
     pre_agent_hook: Optional[FunctionHook] = None
     post_agent_hook: Optional[FunctionHook] = None
 
+    def as_runnable(self) -> RunnableLike:
+        from dao_ai.nodes import create_agent_node
+
+        return create_agent_node(self)
+
+    def as_responses_agent(self) -> ResponsesAgent:
+        from dao_ai.models import create_responses_agent
+
+        graph: CompiledStateGraph = self.as_runnable()
+        return create_responses_agent(graph)
+
 
 class SupervisorModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
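These helpers lift a configured agent straight into an MLflow `ResponsesAgent` without manual graph wiring. A usage sketch (the loader and agent key are hypothetical):

```python
# from_file is an assumed loader and "retail" a hypothetical agent key;
# the point is that graph construction is now one call away from config.
config = AppConfig.from_file("config/app.yaml")
responses_agent = config.agents["retail"].as_responses_agent()
```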
@@ -1330,6 +1552,19 @@ class ChatPayload(BaseModel):
 
         return self
 
+    def as_messages(self) -> Sequence[BaseMessage]:
+        return messages_from_dict(
+            [{"type": m.role, "content": m.content} for m in self.messages]
+        )
+
+    def as_agent_request(self) -> ResponsesAgentRequest:
+        from mlflow.types.responses_helpers import Message as _Message
+
+        return ResponsesAgentRequest(
+            input=[_Message(role=m.role, content=m.content) for m in self.messages],
+            custom_inputs=self.custom_inputs,
+        )
+
 
 class ChatHistoryModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
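The two converters expose the same payload to both runtimes: LangChain graphs and the MLflow Responses interface. A sketch (assuming `ChatPayload` coerces dicts into its message type):

```python
payload = ChatPayload(messages=[{"role": "user", "content": "What's in stock?"}])

lc_messages = payload.as_messages()    # LangChain messages for graph input
request = payload.as_agent_request()   # ResponsesAgentRequest for agent.predict()
```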
@@ -1459,6 +1694,174 @@ class EvaluationModel(BaseModel):
     guidelines: list[GuidelineModel] = Field(default_factory=list)
 
 
+class EvaluationDatasetExpectationsModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    expected_response: Optional[str] = None
+    expected_facts: Optional[list[str]] = None
+
+    @model_validator(mode="after")
+    def validate_mutually_exclusive(self):
+        if self.expected_response is not None and self.expected_facts is not None:
+            raise ValueError("Cannot specify both expected_response and expected_facts")
+        return self
+
+
+class EvaluationDatasetEntryModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    inputs: ChatPayload
+    expectations: EvaluationDatasetExpectationsModel
+
+    def to_mlflow_format(self) -> dict[str, Any]:
+        """
+        Convert to MLflow evaluation dataset format.
+
+        Flattens the expectations fields to the top level alongside inputs,
+        which is the format expected by MLflow's Correctness scorer.
+
+        Returns:
+            dict: Flattened dictionary with inputs and expectation fields at top level
+        """
+        result: dict[str, Any] = {"inputs": self.inputs.model_dump()}
+
+        # Flatten expectations to top level for MLflow compatibility
+        if self.expectations.expected_response is not None:
+            result["expected_response"] = self.expectations.expected_response
+        if self.expectations.expected_facts is not None:
+            result["expected_facts"] = self.expectations.expected_facts
+
+        return result
+
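The flattening is easiest to see on a concrete record (values illustrative, and `ChatPayload` again assumed to coerce dict messages):

```python
entry = EvaluationDatasetEntryModel(
    inputs=ChatPayload(messages=[{"role": "user", "content": "Return policy?"}]),
    expectations=EvaluationDatasetExpectationsModel(
        expected_facts=["30-day window", "receipt required"],
    ),
)
# expected_facts is hoisted out of "expectations" to the top level:
# {"inputs": {...}, "expected_facts": ["30-day window", "receipt required"]}
entry.to_mlflow_format()
```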
+class EvaluationDatasetModel(BaseModel, HasFullName):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
+    name: str
+    data: Optional[list[EvaluationDatasetEntryModel]] = Field(default_factory=list)
+    overwrite: Optional[bool] = False
+
+    def as_dataset(self, w: WorkspaceClient | None = None) -> EvaluationDataset:
+        evaluation_dataset: EvaluationDataset
+        needs_creation: bool = False
+
+        try:
+            evaluation_dataset = get_dataset(name=self.full_name)
+            if self.overwrite:
+                logger.warning(f"Overwriting dataset {self.full_name}")
+                workspace_client: WorkspaceClient = w if w else WorkspaceClient()
+                logger.debug(f"Dropping table: {self.full_name}")
+                workspace_client.tables.delete(full_name=self.full_name)
+                needs_creation = True
+        except Exception:
+            logger.warning(
+                f"Dataset {self.full_name} not found, will create new dataset"
+            )
+            needs_creation = True
+
+        # Create dataset if needed (either new or after overwrite)
+        if needs_creation:
+            evaluation_dataset = create_dataset(name=self.full_name)
+        if self.data:
+            logger.debug(
+                f"Merging {len(self.data)} entries into dataset {self.full_name}"
+            )
+            # Use to_mlflow_format() to flatten expectations for MLflow compatibility
+            evaluation_dataset.merge_records(
+                [e.to_mlflow_format() for e in self.data]
+            )
+
+        return evaluation_dataset
+
+    @property
+    def full_name(self) -> str:
+        if self.schema_model:
+            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}.{self.name}"
+        return self.name
+
+
+class PromptOptimizationModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    prompt: Optional[PromptModel] = None
+    agent: AgentModel
+    dataset: (
+        EvaluationDatasetModel | str
+    )  # Reference to dataset name (looked up in OptimizationsModel.training_datasets or MLflow)
+    reflection_model: Optional[LLMModel | str] = None
+    num_candidates: Optional[int] = 50
+    scorer_model: Optional[LLMModel | str] = None
+
+    def optimize(self, w: WorkspaceClient | None = None) -> PromptModel:
+        """
+        Optimize the prompt using MLflow's prompt optimization.
+
+        Args:
+            w: Optional WorkspaceClient for Databricks operations
+
+        Returns:
+            PromptModel: The optimized prompt model with new URI
+        """
+        from dao_ai.providers.base import ServiceProvider
+        from dao_ai.providers.databricks import DatabricksProvider
+
+        provider: ServiceProvider = DatabricksProvider(w=w)
+        optimized_prompt: PromptModel = provider.optimize_prompt(self)
+        return optimized_prompt
+
+    @model_validator(mode="after")
+    def set_defaults(self):
+        # If no prompt is specified, try to use the agent's prompt
+        if self.prompt is None:
+            if isinstance(self.agent.prompt, PromptModel):
+                self.prompt = self.agent.prompt
+            else:
+                raise ValueError(
+                    f"Prompt optimization '{self.name}' requires either an explicit prompt "
+                    f"or an agent with a prompt configured"
+                )
+
+        if self.reflection_model is None:
+            self.reflection_model = self.agent.model
+
+        if self.scorer_model is None:
+            self.scorer_model = self.agent.model
+
+        return self
+
+
+class OptimizationsModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    training_datasets: dict[str, EvaluationDatasetModel] = Field(default_factory=dict)
+    prompt_optimizations: dict[str, PromptOptimizationModel] = Field(
+        default_factory=dict
+    )
+
+    def optimize(self, w: WorkspaceClient | None = None) -> dict[str, PromptModel]:
+        """
+        Optimize all prompts in this configuration.
+
+        This method:
+        1. Ensures all training datasets are created/registered in MLflow
+        2. Runs each prompt optimization
+
+        Args:
+            w: Optional WorkspaceClient for Databricks operations
+
+        Returns:
+            dict[str, PromptModel]: Dictionary mapping optimization names to optimized prompts
+        """
+        # First, ensure all training datasets are created/registered in MLflow
+        logger.info(f"Ensuring {len(self.training_datasets)} training datasets exist")
+        for dataset_name, dataset_model in self.training_datasets.items():
+            logger.debug(f"Creating/updating dataset: {dataset_name}")
+            dataset_model.as_dataset()
+
+        # Run optimizations
+        results: dict[str, PromptModel] = {}
+        for name, optimization in self.prompt_optimizations.items():
+            results[name] = optimization.optimize(w)
+        return results
+
+
 class DatasetFormat(str, Enum):
     CSV = "csv"
     DELTA = "delta"
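Putting the three models together, `OptimizationsModel.optimize()` materializes every training dataset and then runs each optimization in turn. A sketch (`returns_dataset` and `returns_opt` stand in for configured models):

```python
# A hypothetical end-to-end run: datasets are materialized first, then
# each optimization produces a new PromptModel.
optimizations = OptimizationsModel(
    training_datasets={"returns_qa": returns_dataset},      # EvaluationDatasetModel
    prompt_optimizations={"returns_prompt": returns_opt},   # PromptOptimizationModel
)
optimized: dict[str, PromptModel] = optimizations.optimize()
print(optimized["returns_prompt"].uri)
```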
@@ -1537,6 +1940,7 @@ class AppConfig(BaseModel):
     agents: dict[str, AgentModel] = Field(default_factory=dict)
     app: Optional[AppModel] = None
     evaluation: Optional[EvaluationModel] = None
+    optimizations: Optional[OptimizationsModel] = None
     datasets: Optional[list[DatasetModel]] = Field(default_factory=list)
     unity_catalog_functions: Optional[list[UnityCatalogFunctionSqlModel]] = Field(
         default_factory=list
dao_ai/graph.py CHANGED
@@ -62,11 +62,19 @@ def _handoffs_for_agent(agent: AgentModel, config: AppConfig) -> Sequence[BaseTo
         logger.debug(
             f"Creating handoff tool from agent {agent.name} to {handoff_to_agent.name}"
         )
+
+        # Use handoff_prompt if provided, otherwise create default description
+        handoff_description = handoff_to_agent.handoff_prompt or (
+            handoff_to_agent.description
+            if handoff_to_agent.description
+            else "general assistance and questions"
+        )
+
         handoff_tools.append(
             swarm_handoff_tool(
                 agent_name=handoff_to_agent.name,
                 description=f"Ask {handoff_to_agent.name} for help with: "
-                + handoff_to_agent.handoff_prompt,
+                + handoff_description,
             )
         )
     return handoff_tools
@@ -79,13 +87,25 @@ def _create_supervisor_graph(config: AppConfig) -> CompiledStateGraph:
     for registered_agent in config.app.agents:
         agents.append(
             create_agent_node(
-                app=config.app, agent=registered_agent, additional_tools=[]
+                agent=registered_agent,
+                memory=config.app.orchestration.memory
+                if config.app.orchestration
+                else None,
+                chat_history=config.app.chat_history,
+                additional_tools=[],
             )
         )
+        # Use handoff_prompt if provided, otherwise create default description
+        handoff_description = registered_agent.handoff_prompt or (
+            registered_agent.description
+            if registered_agent.description
+            else f"General assistance with {registered_agent.name} related tasks"
+        )
+
         tools.append(
             supervisor_handoff_tool(
                 agent_name=registered_agent.name,
-                description=registered_agent.handoff_prompt,
+                description=handoff_description,
             )
         )
 
@@ -169,7 +189,12 @@ def _create_swarm_graph(config: AppConfig) -> CompiledStateGraph:
         )
         agents.append(
             create_agent_node(
-                app=config.app, agent=registered_agent, additional_tools=handoff_tools
+                agent=registered_agent,
+                memory=config.app.orchestration.memory
+                if config.app.orchestration
+                else None,
+                chat_history=config.app.chat_history,
+                additional_tools=handoff_tools,
             )
         )
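Both graph builders now share the same fallback chain for handoff descriptions: `handoff_prompt`, then `description`, then a generic default, so agents without a handoff prompt no longer yield an empty tool description. A sketch of the resolution (agent fields illustrative):

```python
# Hypothetical agents exercising each branch; AgentModel is assumed to
# tolerate these fields alone for illustration.
for agent in (
    AgentModel(name="orders", handoff_prompt="order status and returns"),
    AgentModel(name="billing", description="billing and invoices"),
    AgentModel(name="misc"),
):
    description = agent.handoff_prompt or (
        agent.description
        if agent.description
        else "general assistance and questions"
    )
    print(agent.name, "->", description)
# orders -> order status and returns
# billing -> billing and invoices
# misc -> general assistance and questions
```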