dao-ai 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dao_ai/config.py CHANGED
@@ -30,12 +30,15 @@ from databricks_langchain import (
     DatabricksFunctionClient,
 )
 from langchain_core.language_models import LanguageModelLike
+from langchain_core.messages import BaseMessage, messages_from_dict
 from langchain_core.runnables.base import RunnableLike
 from langchain_openai import ChatOpenAI
 from langgraph.checkpoint.base import BaseCheckpointSaver
 from langgraph.graph.state import CompiledStateGraph
 from langgraph.store.base import BaseStore
 from loguru import logger
+from mlflow.genai.datasets import EvaluationDataset, create_dataset, get_dataset
+from mlflow.genai.prompts import PromptVersion, load_prompt
 from mlflow.models import ModelConfig
 from mlflow.models.resources import (
     DatabricksFunction,
@@ -49,6 +52,9 @@ from mlflow.models.resources import (
     DatabricksVectorSearchIndex,
 )
 from mlflow.pyfunc import ChatModel, ResponsesAgent
+from mlflow.types.responses import (
+    ResponsesAgentRequest,
+)
 from pydantic import (
     BaseModel,
     ConfigDict,
@@ -324,6 +330,10 @@ class LLMModel(BaseModel, IsDatabricksResource):
         "serving.serving-endpoints",
     ]
 
+    @property
+    def uri(self) -> str:
+        return f"databricks:/{self.name}"
+
     def as_resources(self) -> Sequence[DatabricksResource]:
         return [
             DatabricksServingEndpoint(
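Note: a minimal sketch of the new LLMModel.uri property (endpoint name hypothetical, and assuming name is the only field required to construct the model):

llm = LLMModel(name="databricks-meta-llama-3-3-70b-instruct")
assert llm.uri == "databricks:/databricks-meta-llama-3-3-70b-instruct"  # per the f-string above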
@@ -387,6 +397,13 @@ class VectorSearchEndpoint(BaseModel):
     name: str
     type: VectorSearchEndpointType = VectorSearchEndpointType.STANDARD
 
+    @field_serializer("type")
+    def serialize_type(self, value: VectorSearchEndpointType) -> str:
+        """Ensure enum is serialized to string value."""
+        if isinstance(value, VectorSearchEndpointType):
+            return value.value
+        return str(value)
+
 
 class IndexModel(BaseModel, HasFullName, IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
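Note: a minimal sketch of what the new serializer guarantees (endpoint name hypothetical): model_dump() now emits the enum's string value, so configs round-trip cleanly through JSON/YAML:

endpoint = VectorSearchEndpoint(name="my-endpoint")
dumped = endpoint.model_dump()
assert isinstance(dumped["type"], str)  # coerced via value.value in serialize_type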
@@ -988,22 +1005,120 @@ class TransportType(str, Enum):
 class McpFunctionModel(BaseFunctionModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     type: Literal[FunctionType.MCP] = FunctionType.MCP
-
     transport: TransportType = TransportType.STREAMABLE_HTTP
     command: Optional[str] = "python"
     url: Optional[AnyVariable] = None
-    connection: Optional[ConnectionModel] = None
     headers: dict[str, AnyVariable] = Field(default_factory=dict)
     args: list[str] = Field(default_factory=list)
     pat: Optional[AnyVariable] = None
     client_id: Optional[AnyVariable] = None
     client_secret: Optional[AnyVariable] = None
     workspace_host: Optional[AnyVariable] = None
+    connection: Optional[ConnectionModel] = None
+    functions: Optional[SchemaModel] = None
+    genie_room: Optional[GenieRoomModel] = None
+    sql: Optional[bool] = None
+    vector_search: Optional[VectorStoreModel] = None
 
     @property
     def full_name(self) -> str:
         return self.name
 
+    def _get_workspace_host(self) -> str:
+        """
+        Get the workspace host, either from config or from workspace client.
+
+        If connection is provided, uses its workspace client.
+        Otherwise, falls back to creating a new workspace client.
+
+        Returns:
+            str: The workspace host URL without trailing slash
+        """
+        from databricks.sdk import WorkspaceClient
+
+        # Try to get workspace_host from config
+        workspace_host: str | None = (
+            value_of(self.workspace_host) if self.workspace_host else None
+        )
+
+        # If no workspace_host in config, get it from workspace client
+        if not workspace_host:
+            # Use connection's workspace client if available
+            if self.connection:
+                workspace_host = self.connection.workspace_client.config.host
+            else:
+                # Create a default workspace client
+                w: WorkspaceClient = WorkspaceClient()
+                workspace_host = w.config.host
+
+        # Remove trailing slash
+        return workspace_host.rstrip("/")
+
+    @property
+    def mcp_url(self) -> str:
+        """
+        Get the MCP URL for this function.
+
+        Returns the URL based on the configured source:
+        - If url is set, returns it directly
+        - If connection is set, constructs URL from connection
+        - If genie_room is set, constructs Genie MCP URL
+        - If sql is set, constructs DBSQL MCP URL (serverless)
+        - If vector_search is set, constructs Vector Search MCP URL
+        - If functions is set, constructs UC Functions MCP URL
+
+        URL patterns (per https://docs.databricks.com/aws/en/generative-ai/mcp/managed-mcp):
+        - Genie: https://{host}/api/2.0/mcp/genie/{space_id}
+        - DBSQL: https://{host}/api/2.0/mcp/sql (serverless, workspace-level)
+        - Vector Search: https://{host}/api/2.0/mcp/vector-search/{catalog}/{schema}
+        - UC Functions: https://{host}/api/2.0/mcp/functions/{catalog}/{schema}
+        - Connection: https://{host}/api/2.0/mcp/external/{connection_name}
+        """
+        # Direct URL provided
+        if self.url:
+            return self.url
+
+        # Get workspace host (from config, connection, or default workspace client)
+        workspace_host: str = self._get_workspace_host()
+
+        # UC Connection
+        if self.connection:
+            connection_name: str = self.connection.name
+            return f"{workspace_host}/api/2.0/mcp/external/{connection_name}"
+
+        # Genie Room
+        if self.genie_room:
+            space_id: str = value_of(self.genie_room.space_id)
+            return f"{workspace_host}/api/2.0/mcp/genie/{space_id}"
+
+        # DBSQL MCP server (serverless, workspace-level)
+        if self.sql:
+            return f"{workspace_host}/api/2.0/mcp/sql"
+
+        # Vector Search
+        if self.vector_search:
+            if (
+                not self.vector_search.index
+                or not self.vector_search.index.schema_model
+            ):
+                raise ValueError(
+                    "vector_search must have an index with a schema (catalog/schema) configured"
+                )
+            catalog: str = self.vector_search.index.schema_model.catalog_name
+            schema: str = self.vector_search.index.schema_model.schema_name
+            return f"{workspace_host}/api/2.0/mcp/vector-search/{catalog}/{schema}"
+
+        # UC Functions MCP server
+        if self.functions:
+            catalog: str = self.functions.catalog_name
+            schema: str = self.functions.schema_name
+            return f"{workspace_host}/api/2.0/mcp/functions/{catalog}/{schema}"
+
+        raise ValueError(
+            "No URL source configured. Provide one of: url, connection, genie_room, "
+            "sql, vector_search, or functions"
+        )
+
     @field_serializer("transport")
     def serialize_transport(self, value) -> str:
         if isinstance(value, TransportType):
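Note: a sketch of how mcp_url resolves for two of the new sources (workspace host, space ID, and constructor fields hypothetical), following the managed-MCP URL patterns quoted in the docstring above:

# Genie room -> {host}/api/2.0/mcp/genie/{space_id}
genie_fn = McpFunctionModel(
    name="genie_mcp",
    genie_room=GenieRoomModel(space_id="01f0abc123"),
)
print(genie_fn.mcp_url)
# e.g. https://my-workspace.cloud.databricks.com/api/2.0/mcp/genie/01f0abc123

# DBSQL (serverless, workspace-level) -> {host}/api/2.0/mcp/sql
sql_fn = McpFunctionModel(name="dbsql_mcp", sql=True)
print(sql_fn.mcp_url)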
@@ -1011,32 +1126,56 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
         return str(value)
 
     @model_validator(mode="after")
-    def validate_mutually_exclusive(self):
-        if self.transport == TransportType.STREAMABLE_HTTP and not (
-            self.url or self.connection
-        ):
-            raise ValueError(
-                "url or connection must be provided for STREAMABLE_HTTP transport"
-            )
-        if self.transport == TransportType.STDIO and not self.command:
-            raise ValueError("command must not be provided for STDIO transport")
-        if self.transport == TransportType.STDIO and not self.args:
-            raise ValueError("args must not be provided for STDIO transport")
+    def validate_mutually_exclusive(self) -> "McpFunctionModel":
+        """Validate that exactly one URL source is provided."""
+        # Count how many URL sources are provided
+        url_sources: list[tuple[str, Any]] = [
+            ("url", self.url),
+            ("connection", self.connection),
+            ("genie_room", self.genie_room),
+            ("sql", self.sql),
+            ("vector_search", self.vector_search),
+            ("functions", self.functions),
+        ]
+
+        provided_sources: list[str] = [
+            name for name, value in url_sources if value is not None
+        ]
+
+        if self.transport == TransportType.STREAMABLE_HTTP:
+            if len(provided_sources) == 0:
+                raise ValueError(
+                    "For STREAMABLE_HTTP transport, exactly one of the following must be provided: "
+                    "url, connection, genie_room, sql, vector_search, or functions"
+                )
+            if len(provided_sources) > 1:
+                raise ValueError(
+                    f"For STREAMABLE_HTTP transport, only one URL source can be provided. "
+                    f"Found: {', '.join(provided_sources)}. "
+                    f"Please provide only one of: url, connection, genie_room, sql, vector_search, or functions"
+                )
+
+        if self.transport == TransportType.STDIO:
+            if not self.command:
+                raise ValueError("command must be provided for STDIO transport")
+            if not self.args:
+                raise ValueError("args must be provided for STDIO transport")
+
         return self
 
     @model_validator(mode="after")
-    def update_url(self):
+    def update_url(self) -> "McpFunctionModel":
         self.url = value_of(self.url)
         return self
 
     @model_validator(mode="after")
-    def update_headers(self):
+    def update_headers(self) -> "McpFunctionModel":
         for key, value in self.headers.items():
             self.headers[key] = value_of(value)
         return self
 
     @model_validator(mode="after")
-    def validate_auth_methods(self):
+    def validate_auth_methods(self) -> "McpFunctionModel":
         oauth_fields: Sequence[Any] = [
             self.client_id,
             self.client_secret,
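Note: a sketch of the stricter validation in practice (field values hypothetical); pydantic surfaces the ValueError raised inside the validator as a ValidationError:

from pydantic import ValidationError

try:
    # Two URL sources (url and sql) on a STREAMABLE_HTTP function: rejected.
    McpFunctionModel(name="ambiguous", url="https://example.com/mcp", sql=True)
except ValidationError as e:
    print(e)  # "...only one URL source can be provided. Found: url, sql..."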
@@ -1052,10 +1191,7 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
                 "Please provide either OAuth credentials or user credentials."
             )
 
-        if (has_oauth or has_user_auth) and not self.workspace_host:
-            raise ValueError(
-                "Workspace host must be provided when using OAuth or user credentials."
-            )
+        # Note: workspace_host is optional - it will be derived from workspace client if not provided
 
         return self
 
@@ -1181,17 +1317,32 @@ class PromptModel(BaseModel, HasFullName):
         from dao_ai.providers.databricks import DatabricksProvider
 
         provider: DatabricksProvider = DatabricksProvider()
-        prompt: str = provider.get_prompt(self)
-        return prompt
+        prompt_version = provider.get_prompt(self)
+        return prompt_version.to_single_brace_format()
 
     @property
     def full_name(self) -> str:
+        prompt_name: str = self.name
         if self.schema_model:
-            name: str = ""
-            if self.name:
-                name = f".{self.name}"
-            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}{name}"
-        return self.name
+            prompt_name = f"{self.schema_model.full_name}.{prompt_name}"
+        return prompt_name
+
+    @property
+    def uri(self) -> str:
+        prompt_uri: str = f"prompts:/{self.full_name}"
+
+        if self.alias:
+            prompt_uri = f"prompts:/{self.full_name}@{self.alias}"
+        elif self.version:
+            prompt_uri = f"prompts:/{self.full_name}/{self.version}"
+        else:
+            prompt_uri = f"prompts:/{self.full_name}@latest"
+
+        return prompt_uri
+
+    def as_prompt(self) -> PromptVersion:
+        prompt_version: PromptVersion = load_prompt(self.uri)
+        return prompt_version
 
     @model_validator(mode="after")
     def validate_mutually_exclusive(self):
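Note: the resulting URI scheme, sketched for the three branches (names and values hypothetical); alias takes precedence over version, and @latest is the fallback:

# PromptModel(name="tone", schema=..., alias="prod").uri -> "prompts:/main.agents.tone@prod"
# PromptModel(name="tone", schema=..., version=3).uri    -> "prompts:/main.agents.tone/3"
# PromptModel(name="tone", schema=...).uri               -> "prompts:/main.agents.tone@latest"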
@@ -1213,6 +1364,17 @@ class AgentModel(BaseModel):
     pre_agent_hook: Optional[FunctionHook] = None
     post_agent_hook: Optional[FunctionHook] = None
 
+    def as_runnable(self) -> RunnableLike:
+        from dao_ai.nodes import create_agent_node
+
+        return create_agent_node(self)
+
+    def as_responses_agent(self) -> ResponsesAgent:
+        from dao_ai.models import create_responses_agent
+
+        graph: CompiledStateGraph = self.as_runnable()
+        return create_responses_agent(graph)
+
 
 class SupervisorModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
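Note: sketch usage of the two new helpers (agent key hypothetical): an AgentModel can now be compiled into a graph and wrapped for MLflow serving directly:

agent_model = config.agents["orders"]               # hypothetical lookup
graph = agent_model.as_runnable()                   # CompiledStateGraph via create_agent_node
responses_agent = agent_model.as_responses_agent()  # MLflow ResponsesAgent wrapper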
@@ -1330,6 +1492,19 @@ class ChatPayload(BaseModel):
 
         return self
 
+    def as_messages(self) -> Sequence[BaseMessage]:
+        return messages_from_dict(
+            [{"type": m.role, "content": m.content} for m in self.messages]
+        )
+
+    def as_agent_request(self) -> ResponsesAgentRequest:
+        from mlflow.types.responses_helpers import Message as _Message
+
+        return ResponsesAgentRequest(
+            input=[_Message(role=m.role, content=m.content) for m in self.messages],
+            custom_inputs=self.custom_inputs,
+        )
+
 
 class ChatHistoryModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
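Note: a sketch (message content hypothetical) of the two new bridges out of a ChatPayload, assuming pydantic coerces the dicts into the payload's message models:

payload = ChatPayload(messages=[{"role": "user", "content": "What's my order status?"}])
lc_messages = payload.as_messages()   # LangChain BaseMessage objects
request = payload.as_agent_request()  # mlflow ResponsesAgentRequest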
@@ -1459,6 +1634,174 @@ class EvaluationModel(BaseModel):
     guidelines: list[GuidelineModel] = Field(default_factory=list)
 
 
+class EvaluationDatasetExpectationsModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    expected_response: Optional[str] = None
+    expected_facts: Optional[list[str]] = None
+
+    @model_validator(mode="after")
+    def validate_mutually_exclusive(self):
+        if self.expected_response is not None and self.expected_facts is not None:
+            raise ValueError("Cannot specify both expected_response and expected_facts")
+        return self
+
+
+class EvaluationDatasetEntryModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    inputs: ChatPayload
+    expectations: EvaluationDatasetExpectationsModel
+
+    def to_mlflow_format(self) -> dict[str, Any]:
+        """
+        Convert to MLflow evaluation dataset format.
+
+        Flattens the expectations fields to the top level alongside inputs,
+        which is the format expected by MLflow's Correctness scorer.
+
+        Returns:
+            dict: Flattened dictionary with inputs and expectation fields at top level
+        """
+        result: dict[str, Any] = {"inputs": self.inputs.model_dump()}
+
+        # Flatten expectations to top level for MLflow compatibility
+        if self.expectations.expected_response is not None:
+            result["expected_response"] = self.expectations.expected_response
+        if self.expectations.expected_facts is not None:
+            result["expected_facts"] = self.expectations.expected_facts
+
+        return result
+
+
+class EvaluationDatasetModel(BaseModel, HasFullName):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
+    name: str
+    data: Optional[list[EvaluationDatasetEntryModel]] = Field(default_factory=list)
+    overwrite: Optional[bool] = False
+
+    def as_dataset(self, w: WorkspaceClient | None = None) -> EvaluationDataset:
+        evaluation_dataset: EvaluationDataset
+        needs_creation: bool = False
+
+        try:
+            evaluation_dataset = get_dataset(name=self.full_name)
+            if self.overwrite:
+                logger.warning(f"Overwriting dataset {self.full_name}")
+                workspace_client: WorkspaceClient = w if w else WorkspaceClient()
+                logger.debug(f"Dropping table: {self.full_name}")
+                workspace_client.tables.delete(full_name=self.full_name)
+                needs_creation = True
+        except Exception:
+            logger.warning(
+                f"Dataset {self.full_name} not found, will create new dataset"
+            )
+            needs_creation = True
+
+        # Create dataset if needed (either new or after overwrite)
+        if needs_creation:
+            evaluation_dataset = create_dataset(name=self.full_name)
+        if self.data:
+            logger.debug(
+                f"Merging {len(self.data)} entries into dataset {self.full_name}"
+            )
+            # Use to_mlflow_format() to flatten expectations for MLflow compatibility
+            evaluation_dataset.merge_records(
+                [e.to_mlflow_format() for e in self.data]
+            )
+
+        return evaluation_dataset
+
+    @property
+    def full_name(self) -> str:
+        if self.schema_model:
+            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}.{self.name}"
+        return self.name
+
+
+class PromptOptimizationModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    prompt: Optional[PromptModel] = None
+    agent: AgentModel
+    dataset: (
+        EvaluationDatasetModel | str
+    )  # Reference to dataset name (looked up in OptimizationsModel.training_datasets or MLflow)
+    reflection_model: Optional[LLMModel | str] = None
+    num_candidates: Optional[int] = 50
+    scorer_model: Optional[LLMModel | str] = None
+
+    def optimize(self, w: WorkspaceClient | None = None) -> PromptModel:
+        """
+        Optimize the prompt using MLflow's prompt optimization.
+
+        Args:
+            w: Optional WorkspaceClient for Databricks operations
+
+        Returns:
+            PromptModel: The optimized prompt model with new URI
+        """
+        from dao_ai.providers.base import ServiceProvider
+        from dao_ai.providers.databricks import DatabricksProvider
+
+        provider: ServiceProvider = DatabricksProvider(w=w)
+        optimized_prompt: PromptModel = provider.optimize_prompt(self)
+        return optimized_prompt
+
+    @model_validator(mode="after")
+    def set_defaults(self):
+        # If no prompt is specified, try to use the agent's prompt
+        if self.prompt is None:
+            if isinstance(self.agent.prompt, PromptModel):
+                self.prompt = self.agent.prompt
+            else:
+                raise ValueError(
+                    f"Prompt optimization '{self.name}' requires either an explicit prompt "
+                    f"or an agent with a prompt configured"
+                )
+
+        if self.reflection_model is None:
+            self.reflection_model = self.agent.model
+
+        if self.scorer_model is None:
+            self.scorer_model = self.agent.model
+
+        return self
+
+
+class OptimizationsModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    training_datasets: dict[str, EvaluationDatasetModel] = Field(default_factory=dict)
+    prompt_optimizations: dict[str, PromptOptimizationModel] = Field(
+        default_factory=dict
+    )
+
+    def optimize(self, w: WorkspaceClient | None = None) -> dict[str, PromptModel]:
+        """
+        Optimize all prompts in this configuration.
+
+        This method:
+        1. Ensures all training datasets are created/registered in MLflow
+        2. Runs each prompt optimization
+
+        Args:
+            w: Optional WorkspaceClient for Databricks operations
+
+        Returns:
+            dict[str, PromptModel]: Dictionary mapping optimization names to optimized prompts
+        """
+        # First, ensure all training datasets are created/registered in MLflow
+        logger.info(f"Ensuring {len(self.training_datasets)} training datasets exist")
+        for dataset_name, dataset_model in self.training_datasets.items():
+            logger.debug(f"Creating/updating dataset: {dataset_name}")
+            dataset_model.as_dataset()
+
+        # Run optimizations
+        results: dict[str, PromptModel] = {}
+        for name, optimization in self.prompt_optimizations.items():
+            results[name] = optimization.optimize(w)
+        return results
+
+
 class DatasetFormat(str, Enum):
     CSV = "csv"
     DELTA = "delta"
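Note: a sketch of the flattening contract that to_mlflow_format() implements (values hypothetical): expectations are hoisted to the top level next to inputs, the shape MLflow's Correctness scorer expects:

entry = EvaluationDatasetEntryModel(
    inputs=ChatPayload(messages=[{"role": "user", "content": "Do you ship to Canada?"}]),
    expectations=EvaluationDatasetExpectationsModel(expected_facts=["ships to Canada"]),
)
print(entry.to_mlflow_format())
# -> {"inputs": {...}, "expected_facts": ["ships to Canada"]}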
@@ -1537,6 +1880,7 @@ class AppConfig(BaseModel):
     agents: dict[str, AgentModel] = Field(default_factory=dict)
     app: Optional[AppModel] = None
     evaluation: Optional[EvaluationModel] = None
+    optimizations: Optional[OptimizationsModel] = None
     datasets: Optional[list[DatasetModel]] = Field(default_factory=list)
     unity_catalog_functions: Optional[list[UnityCatalogFunctionSqlModel]] = Field(
         default_factory=list
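Note: a sketch of the end-to-end flow the new optimizations field enables (config loading elided): register the training datasets, then run each configured prompt optimization:

config: AppConfig = ...  # loaded from ModelConfig as elsewhere in dao_ai
if config.optimizations:
    optimized: dict[str, PromptModel] = config.optimizations.optimize()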
dao_ai/graph.py CHANGED
@@ -62,11 +62,19 @@ def _handoffs_for_agent(agent: AgentModel, config: AppConfig) -> Sequence[BaseTo
         logger.debug(
             f"Creating handoff tool from agent {agent.name} to {handoff_to_agent.name}"
         )
+
+        # Use handoff_prompt if provided, otherwise create default description
+        handoff_description = handoff_to_agent.handoff_prompt or (
+            handoff_to_agent.description
+            if handoff_to_agent.description
+            else "general assistance and questions"
+        )
+
         handoff_tools.append(
             swarm_handoff_tool(
                 agent_name=handoff_to_agent.name,
                 description=f"Ask {handoff_to_agent.name} for help with: "
-                + handoff_to_agent.handoff_prompt,
+                + handoff_description,
             )
         )
     return handoff_tools
@@ -79,13 +87,25 @@ def _create_supervisor_graph(config: AppConfig) -> CompiledStateGraph:
     for registered_agent in config.app.agents:
         agents.append(
             create_agent_node(
-                app=config.app, agent=registered_agent, additional_tools=[]
+                agent=registered_agent,
+                memory=config.app.orchestration.memory
+                if config.app.orchestration
+                else None,
+                chat_history=config.app.chat_history,
+                additional_tools=[],
             )
         )
+        # Use handoff_prompt if provided, otherwise create default description
+        handoff_description = registered_agent.handoff_prompt or (
+            registered_agent.description
+            if registered_agent.description
+            else f"General assistance with {registered_agent.name} related tasks"
+        )
+
         tools.append(
             supervisor_handoff_tool(
                 agent_name=registered_agent.name,
-                description=registered_agent.handoff_prompt,
+                description=handoff_description,
             )
         )
 
@@ -169,7 +189,12 @@ def _create_swarm_graph(config: AppConfig) -> CompiledStateGraph:
         )
         agents.append(
             create_agent_node(
-                app=config.app, agent=registered_agent, additional_tools=handoff_tools
+                agent=registered_agent,
+                memory=config.app.orchestration.memory
+                if config.app.orchestration
+                else None,
+                chat_history=config.app.chat_history,
+                additional_tools=handoff_tools,
             )
         )
 
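Note: the fallback chain now shared by all three call sites, in isolation (agent fields hypothetical): handoff_prompt wins, then description, then a generic default, so handoff tools no longer require handoff_prompt to be set:

handoff_description = agent.handoff_prompt or (
    agent.description if agent.description else "general assistance and questions"
)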