ibm-watsonx-orchestrate 1.5.0b0-py3-none-any.whl → 1.5.0b1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29):
  1. ibm_watsonx_orchestrate/__init__.py +1 -2
  2. ibm_watsonx_orchestrate/agent_builder/agents/types.py +10 -1
  3. ibm_watsonx_orchestrate/agent_builder/knowledge_bases/types.py +13 -0
  4. ibm_watsonx_orchestrate/agent_builder/model_policies/__init__.py +1 -0
  5. ibm_watsonx_orchestrate/{client → agent_builder}/model_policies/types.py +7 -8
  6. ibm_watsonx_orchestrate/agent_builder/models/__init__.py +1 -0
  7. ibm_watsonx_orchestrate/{client → agent_builder}/models/types.py +43 -7
  8. ibm_watsonx_orchestrate/agent_builder/tools/python_tool.py +46 -3
  9. ibm_watsonx_orchestrate/agent_builder/tools/types.py +47 -1
  10. ibm_watsonx_orchestrate/cli/commands/agents/agents_command.py +17 -0
  11. ibm_watsonx_orchestrate/cli/commands/agents/agents_controller.py +86 -39
  12. ibm_watsonx_orchestrate/cli/commands/models/model_provider_mapper.py +191 -0
  13. ibm_watsonx_orchestrate/cli/commands/models/models_command.py +136 -259
  14. ibm_watsonx_orchestrate/cli/commands/models/models_controller.py +437 -0
  15. ibm_watsonx_orchestrate/cli/commands/server/server_command.py +2 -1
  16. ibm_watsonx_orchestrate/cli/commands/tools/tools_controller.py +1 -1
  17. ibm_watsonx_orchestrate/client/connections/__init__.py +2 -1
  18. ibm_watsonx_orchestrate/client/connections/utils.py +30 -0
  19. ibm_watsonx_orchestrate/client/model_policies/model_policies_client.py +23 -4
  20. ibm_watsonx_orchestrate/client/models/models_client.py +23 -3
  21. ibm_watsonx_orchestrate/client/tools/tool_client.py +2 -1
  22. ibm_watsonx_orchestrate/docker/compose-lite.yml +2 -0
  23. ibm_watsonx_orchestrate/docker/default.env +10 -11
  24. {ibm_watsonx_orchestrate-1.5.0b0.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/METADATA +1 -1
  25. {ibm_watsonx_orchestrate-1.5.0b0.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/RECORD +28 -25
  26. ibm_watsonx_orchestrate/cli/commands/models/env_file_model_provider_mapper.py +0 -180
  27. {ibm_watsonx_orchestrate-1.5.0b0.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/WHEEL +0 -0
  28. {ibm_watsonx_orchestrate-1.5.0b0.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/entry_points.txt +0 -0
  29. {ibm_watsonx_orchestrate-1.5.0b0.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/licenses/LICENSE +0 -0
@@ -5,8 +5,7 @@
5
5
 
6
6
  pkg_name = "ibm-watsonx-orchestrate"
7
7
 
8
- __version__ = "1.5.0b0"
9
-
8
+ __version__ = "1.5.0b1"
10
9
 
11
10
 
12
11
 
@@ -3,12 +3,14 @@ import yaml
3
3
  from enum import Enum
4
4
  from typing import List, Optional, Dict
5
5
  from pydantic import BaseModel, model_validator, ConfigDict
6
- from ibm_watsonx_orchestrate.agent_builder.tools import BaseTool
6
+ from ibm_watsonx_orchestrate.agent_builder.tools import BaseTool, PythonTool
7
7
  from ibm_watsonx_orchestrate.agent_builder.knowledge_bases.types import KnowledgeBaseSpec
8
8
  from ibm_watsonx_orchestrate.agent_builder.knowledge_bases.knowledge_base import KnowledgeBase
9
9
  from pydantic import Field, AliasChoices
10
10
  from typing import Annotated
11
11
 
12
+ from ibm_watsonx_orchestrate.agent_builder.tools.types import JsonSchemaObject
13
+
12
14
  # TO-DO: this is just a placeholder. Will update this later to align with backend
13
15
  DEFAULT_LLM = "watsonx/meta-llama/llama-3-1-70b-instruct"
14
16
 
@@ -78,6 +80,8 @@ class AgentSpec(BaseAgentSpec):
78
80
  kind: AgentKind = AgentKind.NATIVE
79
81
  llm: str = DEFAULT_LLM
80
82
  style: AgentStyle = AgentStyle.DEFAULT
83
+ custom_join_tool: str | PythonTool | None = None
84
+ structured_output: Optional[JsonSchemaObject] = None
81
85
  instructions: Annotated[Optional[str], Field(json_schema_extra={"min_length_str":1})] = None
82
86
  collaborators: Optional[List[str]] | Optional[List['BaseAgentSpec']] = []
83
87
  tools: Optional[List[str]] | Optional[List['BaseTool']] = []
@@ -117,6 +121,11 @@ def validate_agent_fields(values: dict) -> dict:
117
121
  if collaborator == name:
118
122
  raise ValueError(f"Circular reference detected. The agent '{name}' cannot contain itself as a collaborator")
119
123
 
124
+ if values.get("style") == AgentStyle.PLANNER:
125
+ if not values.get("custom_join_tool") and not values.get("structured_output"):
126
+ raise ValueError("Either 'custom_join_tool' or 'structured_output' must be provided for planner style agents.")
127
+ if values.get("custom_join_tool") and values.get("structured_output"):
128
+ raise ValueError("Only one of 'custom_join_tool' or 'structured_output' can be provided for planner style agents.")
120
129
 
121
130
  return values
122
131
 
@@ -184,11 +184,24 @@ class ElasticSearchConnection(BaseModel):
184
184
  result_filter: Optional[list] = None
185
185
  field_mapping: Optional[FieldMapping] = None
186
186
 
187
+ class CustomSearchConnection(BaseModel):
188
+ """
189
+ example:
190
+ {
191
+ "url": "https://customsearch.xxxx.us-east.codeengine.appdomain.cloud",
192
+ "filter": "...",
193
+ "metadata": {...}
194
+ }
195
+ """
196
+ url: str
197
+ filter: Optional[str] = None
198
+ metadata: Optional[dict] = None
187
199
 
188
200
  class IndexConnection(BaseModel):
189
201
  connection_id: Optional[str] = None
190
202
  milvus: Optional[MilvusConnection] = None
191
203
  elastic_search: Optional[ElasticSearchConnection] = None
204
+ custom_search: Optional[CustomSearchConnection] = None
192
205
 
193
206
 
194
207
  class ConversationalSearchConfig(BaseModel):
@@ -0,0 +1 @@
1
+ from .types import ModelPolicy, ModelPolicyInner, ModelPolicyRetry, ModelPolicyStrategy, ModelPolicyStrategyMode, ModelPolicyTarget
@@ -1,7 +1,7 @@
1
1
  from enum import Enum
2
- from typing import Optional, List, Dict, Any, Union
2
+ from typing import List, Union
3
3
 
4
- from pydantic import Field, BaseModel, ConfigDict
4
+ from pydantic import BaseModel, ConfigDict
5
5
 
6
6
  class ModelPolicyStrategyMode(str, Enum):
7
7
  LOAD_BALANCED = "loadbalance"
@@ -17,14 +17,13 @@ class ModelPolicyRetry(BaseModel):
17
17
  on_status_codes: List[int] = None
18
18
 
19
19
  class ModelPolicyTarget(BaseModel):
20
- model_id: str = None
21
- weight: int = None
22
-
20
+ weight: float = None
21
+ model_name: str = None
23
22
 
24
23
  class ModelPolicyInner(BaseModel):
25
24
  strategy: ModelPolicyStrategy = None
26
25
  retry: ModelPolicyRetry = None
27
- targets: List[Union[ModelPolicyTarget, 'ModelPolicyInner']] = None
26
+ targets: List[Union['ModelPolicyInner', ModelPolicyTarget]] = None
28
27
 
29
28
 
30
29
  class ModelPolicy(BaseModel):
@@ -32,5 +31,5 @@ class ModelPolicy(BaseModel):
32
31
 
33
32
  name: str
34
33
  display_name: str
35
- policy: ModelPolicyInner
36
-
34
+ description: str
35
+ policy: ModelPolicyInner
@@ -0,0 +1 @@
1
+ from .types import VirtualModel
@@ -1,6 +1,6 @@
1
1
  from typing import Optional, List, Dict, Any, Union
2
2
  from enum import Enum
3
- from pydantic import Field, BaseModel, ConfigDict
3
+ from pydantic import Field, BaseModel, ConfigDict, model_validator
4
4
 
5
5
  class ModelProvider(str, Enum):
6
6
  OPENAI = 'openai'
@@ -29,6 +29,10 @@ class ModelProvider(str, Enum):
29
29
 
30
30
  def __repr__(self):
31
31
  return self.value
32
+
33
+ @classmethod
34
+ def has_value(cls, value):
35
+ return value in cls._value2member_map_
32
36
 
33
37
  class ModelType(str, Enum):
34
38
  CHAT = 'chat'
@@ -42,13 +46,24 @@ class ModelType(str, Enum):
42
46
  def __repr__(self):
43
47
  return self.value
44
48
 
49
+ class ModelType(str, Enum):
50
+ CHAT = 'chat'
51
+ CHAT_VISION = 'chat_vision'
52
+ COMPLETION = 'completion'
53
+ EMBEDDING = 'embedding'
54
+
55
+ def __str__(self):
56
+ return self.value
57
+
58
+ def __repr__(self):
59
+ return self.value
45
60
 
46
61
  class ProviderConfig(BaseModel):
47
62
  # Required fields
48
63
  provider: Optional[str]=''
49
64
 
50
65
 
51
- api_key: Optional[str] = '' # this is not optional
66
+ api_key: Optional[str] = None
52
67
  url_to_fetch: Optional[str] = Field(None, alias="urlToFetch")
53
68
 
54
69
  # Misc
@@ -135,7 +150,7 @@ class ProviderConfig(BaseModel):
135
150
  azure_extra_params: Optional[str] = Field(None, alias="azureExtraParams")
136
151
  azure_foundry_url: Optional[str] = Field(None, alias="azureFoundryUrl")
137
152
 
138
- strict_open_ai_compliance: Optional[bool] = Field(False, alias="strictOpenAiCompliance")
153
+ strict_open_ai_compliance: Optional[bool] = Field(None, alias="strictOpenAiCompliance")
139
154
  mistral_fim_completion: Optional[str] = Field(None, alias="mistralFimCompletion")
140
155
 
141
156
  # Anthropic
@@ -152,6 +167,10 @@ class ProviderConfig(BaseModel):
152
167
  watsonx_version: Optional[str] = Field(None, alias="watsonxVersion")
153
168
  watsonx_space_id: Optional[str] = Field(None, alias="watsonxSpaceId")
154
169
  watsonx_project_id: Optional[str] = Field(None, alias="watsonxProjectId")
170
+ watsonx_deployment_id: Optional[str] = Field(None, alias="watsonxDeploymentId")
171
+ watsonx_cpd_url:Optional[str] = Field(None, alias="watsonxCpdUrl")
172
+ watsonx_cpd_username:Optional[str] = Field(None, alias="watsonxCpdUsername")
173
+ watsonx_cpd_password:Optional[str] = Field(None, alias="watsonxCpdPassword")
155
174
 
156
175
  model_config = {
157
176
  "populate_by_name": True, # Replaces allow_population_by_field_name
@@ -159,19 +178,36 @@ class ProviderConfig(BaseModel):
159
178
  "json_schema_extra": lambda schema: schema.get("properties", {}).pop("provider", None)
160
179
  }
161
180
 
181
+ def update(self, new_config: "ProviderConfig") -> "ProviderConfig":
182
+ old_config_dict = dict(self)
183
+ new_config_dict = dict(new_config)
162
184
 
185
+ new_config_dict = {k:v for k, v in new_config_dict.items() if v is not None}
186
+ old_config_dict.update(new_config_dict)
163
187
 
164
- class CreateVirtualModel(BaseModel):
188
+ return ProviderConfig.model_validate(old_config_dict)
189
+
190
+
191
+ class VirtualModel(BaseModel):
165
192
  model_config = ConfigDict(extra='allow')
166
193
 
167
194
  name: str
168
195
  display_name: Optional[str]
169
196
  description: Optional[str]
170
197
  config: Optional[dict] = None
171
- provider_config: ProviderConfig
172
- tags: List[str]
198
+ provider_config: Optional[ProviderConfig] = None
199
+ tags: List[str] = []
173
200
  model_type: str = ModelType.CHAT
201
+ connection_id: Optional[str] = None
174
202
 
203
+ @model_validator(mode="before")
204
+ def validate_fields(cls, values):
205
+ if not values.get("display_name"):
206
+ values["display_name"] = values.get("name")
207
+ if not values.get("description"):
208
+ values["description"] = values.get("name")
209
+
210
+ return values
175
211
 
176
212
  class ListVirtualModel(BaseModel):
177
213
  model_config = ConfigDict(extra='allow')
@@ -186,4 +222,4 @@ class ListVirtualModel(BaseModel):
186
222
  tags: Optional[List[str]] = None
187
223
  model_type: Optional[str] = None
188
224
 
189
- ANTHROPIC_DEFAULT_MAX_TOKENS = 4096
225
+ ANTHROPIC_DEFAULT_MAX_TOKENS = 4096
@@ -2,7 +2,7 @@ import importlib
2
2
  import inspect
3
3
  import json
4
4
  import os
5
- from typing import Callable, List
5
+ from typing import Any, Callable, Dict, List, get_type_hints
6
6
  import logging
7
7
 
8
8
  import docstring_parser
@@ -13,12 +13,18 @@ from pydantic import TypeAdapter, BaseModel
13
13
  from ibm_watsonx_orchestrate.utils.utils import yaml_safe_load
14
14
  from ibm_watsonx_orchestrate.agent_builder.connections import ExpectedCredentials
15
15
  from .base_tool import BaseTool
16
- from .types import ToolSpec, ToolPermission, ToolRequestBody, ToolResponseBody, JsonSchemaObject, ToolBinding, \
16
+ from .types import PythonToolKind, ToolSpec, ToolPermission, ToolRequestBody, ToolResponseBody, JsonSchemaObject, ToolBinding, \
17
17
  PythonToolBinding
18
18
 
19
19
  _all_tools = []
20
20
  logger = logging.getLogger(__name__)
21
21
 
22
+ JOIN_TOOL_PARAMS = {
23
+ 'original_query': str,
24
+ 'task_results': Dict[str, Any],
25
+ 'messages': List[Dict[str, Any]],
26
+ }
27
+
22
28
  class PythonTool(BaseTool):
23
29
  def __init__(self, fn, spec: ToolSpec, expected_credentials: List[ExpectedCredentials]=None):
24
30
  BaseTool.__init__(self, spec=spec)
@@ -92,6 +98,31 @@ def _validate_input_schema(input_schema: ToolRequestBody) -> None:
92
98
  if not props.get(prop).type:
93
99
  logger.warning(f"Missing type hint for tool property '{prop}' defaulting to 'str'. To remove this warning add a type hint to the property in the tools signature. See Python docs for guidance: https://docs.python.org/3/library/typing.html")
94
100
 
101
+ def _validate_join_tool_func(fn: Callable, sig: inspect.Signature | None = None, name: str | None = None) -> None:
102
+ if sig is None:
103
+ sig = inspect.signature(fn)
104
+ if name is None:
105
+ name = fn.__name__
106
+
107
+ params = sig.parameters
108
+ type_hints = get_type_hints(fn)
109
+
110
+ # Validate parameter order
111
+ actual_param_names = list(params.keys())
112
+ expected_param_names = list(JOIN_TOOL_PARAMS.keys())
113
+ if actual_param_names[:len(expected_param_names)] != expected_param_names:
114
+ raise ValueError(
115
+ f"Join tool function '{name}' has incorrect parameter names or order. Expected: {expected_param_names}, got: {actual_param_names}"
116
+ )
117
+
118
+ # Validate the type hints
119
+ for param, expected_type in JOIN_TOOL_PARAMS.items():
120
+ if param not in type_hints:
121
+ raise ValueError(f"Join tool function '{name}' is missing type for parameter '{param}'")
122
+ actual_type = type_hints[param]
123
+ if actual_type != expected_type:
124
+ raise ValueError(f"Join tool function '{name}' has incorrect type for parameter '{param}'. Expected {expected_type}, got {actual_type}")
125
+
95
126
  def tool(
96
127
  *args,
97
128
  name: str = None,
@@ -100,7 +131,8 @@ def tool(
100
131
  output_schema: ToolResponseBody = None,
101
132
  permission: ToolPermission = ToolPermission.READ_ONLY,
102
133
  expected_credentials: List[ExpectedCredentials] = None,
103
- display_name: str = None
134
+ display_name: str = None,
135
+ kind: PythonToolKind = PythonToolKind.TOOL,
104
136
  ) -> Callable[[{__name__, __doc__}], PythonTool]:
105
137
  """
106
138
  Decorator to convert a python function into a callable tool.
@@ -152,6 +184,11 @@ def tool(
152
184
  spec.binding.python.function = function_binding
153
185
 
154
186
  sig = inspect.signature(fn)
187
+
188
+ # If the function is a join tool, validate its signature matches the expected parameters. If not, raise error with details.
189
+ if kind == PythonToolKind.JOIN_TOOL:
190
+ _validate_join_tool_func(fn, sig, spec.name)
191
+
155
192
  if not input_schema:
156
193
  try:
157
194
  input_schema_model: type[BaseModel] = create_schema_from_function(spec.name, fn, parse_docstring=True)
@@ -190,6 +227,12 @@ def tool(
190
227
 
191
228
  else:
192
229
  spec.output_schema = ToolResponseBody()
230
+
231
+ # Validate the generated schema still conforms to the requirement for a join tool
232
+ if kind == PythonToolKind.JOIN_TOOL:
233
+ if not spec.is_custom_join_tool():
234
+ raise ValueError(f"Join tool '{spec.name}' does not conform to the expected join tool schema. Please ensure the input schema has the required fields: {JOIN_TOOL_PARAMS.keys()} and the output schema is a string.")
235
+
193
236
  _all_tools.append(t)
194
237
  return t
195
238
 
@@ -10,6 +10,9 @@ class ToolPermission(str, Enum):
10
10
  READ_WRITE = 'read_write'
11
11
  ADMIN = 'admin'
12
12
 
13
+ class PythonToolKind(str, Enum):
14
+ JOIN_TOOL = 'join_tool'
15
+ TOOL = 'tool'
13
16
 
14
17
  class JsonSchemaObject(BaseModel):
15
18
  model_config = ConfigDict(extra='allow')
@@ -163,7 +166,6 @@ class ToolBinding(BaseModel):
163
166
  raise ValueError("Only one binding can be set")
164
167
  return self
165
168
 
166
-
167
169
  class ToolSpec(BaseModel):
168
170
  name: str
169
171
  display_name: str | None = None
@@ -174,3 +176,47 @@ class ToolSpec(BaseModel):
174
176
  binding: ToolBinding = None
175
177
  toolkit_id: str | None = None
176
178
 
179
+ def is_custom_join_tool(self) -> bool:
180
+ if self.binding.python is None:
181
+ return False
182
+
183
+ # The code below validates the input schema to have the following structure:
184
+ # {
185
+ # "type": "object",
186
+ # "properties": {
187
+ # "messages": {
188
+ # "type": "array",
189
+ # "items": {
190
+ # "type": "object",
191
+ # },
192
+ # },
193
+ # "task_results": {
194
+ # "type": "object",
195
+ # },
196
+ # "original_query": {
197
+ # "type": "string",
198
+ # },
199
+ # },
200
+ # "required": {"original_query", "task_results", "messages"},
201
+ # }
202
+
203
+ input_schema = self.input_schema
204
+ if input_schema.type != 'object':
205
+ return False
206
+
207
+ required_fields = {"original_query", "task_results", "messages"}
208
+ if input_schema.required is None or set(input_schema.required) != required_fields:
209
+ return False
210
+ if input_schema.properties is None or set(input_schema.properties.keys()) != required_fields:
211
+ return False
212
+
213
+ if input_schema.properties["messages"].type != "array":
214
+ return False
215
+ if not input_schema.properties["messages"].items or input_schema.properties["messages"].items.type != "object":
216
+ return False
217
+ if input_schema.properties["task_results"].type != "object":
218
+ return False
219
+ if input_schema.properties["original_query"].type != "string":
220
+ return False
221
+
222
+ return True
@@ -105,6 +105,20 @@ def agent_create(
105
105
  AgentStyle,
106
106
  typer.Option("--style", help="The style of agent you wish to create"),
107
107
  ] = AgentStyle.DEFAULT,
108
+ custom_join_tool: Annotated[
109
+ str | None,
110
+ typer.Option(
111
+ "--custom-join-tool",
112
+ help='The name of the python tool to be used by the agent to format and generate the final output. Only needed for "planner" style agents.',
113
+ ),
114
+ ] = None,
115
+ structured_output: Annotated[
116
+ str | None,
117
+ typer.Option(
118
+ "--structured-output",
119
+ help='A JSON Schema object that defines the desired structure of the agent\'s final output. Only needed for "planner" style agents.',
120
+ ),
121
+ ] = None,
108
122
  collaborators: Annotated[
109
123
  List[str],
110
124
  typer.Option(
@@ -138,6 +152,7 @@ def agent_create(
138
152
  chat_params_dict = json.loads(chat_params) if chat_params else {}
139
153
  config_dict = json.loads(config) if config else {}
140
154
  auth_config_dict = json.loads(auth_config) if auth_config else {}
155
+ structured_output_dict = json.loads(structured_output) if structured_output else None
141
156
 
142
157
  agents_controller = AgentsController()
143
158
  agent = agents_controller.generate_agent_spec(
@@ -151,6 +166,8 @@ def agent_create(
151
166
  provider=provider,
152
167
  llm=llm,
153
168
  style=style,
169
+ custom_join_tool=custom_join_tool,
170
+ structured_output=structured_output_dict,
154
171
  collaborators=collaborators,
155
172
  tools=tools,
156
173
  knowledge_base=knowledge_base,
@@ -11,9 +11,12 @@ import logging
11
11
  from pathlib import Path
12
12
  from copy import deepcopy
13
13
 
14
- from typing import Iterable, List
14
+ from typing import Iterable, List, TypeVar
15
+ from ibm_watsonx_orchestrate.agent_builder.agents.types import AgentStyle
16
+ from ibm_watsonx_orchestrate.agent_builder.tools.types import ToolSpec
15
17
  from ibm_watsonx_orchestrate.cli.commands.tools.tools_controller import import_python_tool, ToolsController
16
18
  from ibm_watsonx_orchestrate.cli.commands.knowledge_bases.knowledge_bases_controller import import_python_knowledge_base
19
+ from ibm_watsonx_orchestrate.cli.commands.models.models_controller import import_python_model
17
20
 
18
21
  from ibm_watsonx_orchestrate.agent_builder.agents import (
19
22
  Agent,
@@ -34,10 +37,14 @@ from ibm_watsonx_orchestrate.utils.utils import check_file_in_zip
34
37
 
35
38
  logger = logging.getLogger(__name__)
36
39
 
40
+ # Helper generic type for any agent
41
+ AnyAgentT = TypeVar("AnyAgentT", bound=Agent | ExternalAgent | AssistantAgent)
42
+
37
43
  def import_python_agent(file: str) -> List[Agent | ExternalAgent | AssistantAgent]:
38
44
  # Import tools
39
45
  import_python_tool(file)
40
46
  import_python_knowledge_base(file)
47
+ import_python_model(file)
41
48
 
42
49
  file_path = Path(file)
43
50
  file_directory = file_path.parent
@@ -90,6 +97,8 @@ def parse_create_native_args(name: str, kind: AgentKind, description: str | None
90
97
  "description": description,
91
98
  "llm": args.get("llm"),
92
99
  "style": args.get("style"),
100
+ "custom_join_tool": args.get("custom_join_tool"),
101
+ "structured_output": args.get("structured_output"),
93
102
  }
94
103
 
95
104
  collaborators = args.get("collaborators", [])
@@ -202,7 +211,7 @@ class AgentsController:
202
211
  return self.knowledge_base_client
203
212
 
204
213
  @staticmethod
205
- def import_agent(file: str, app_id: str) -> Iterable:
214
+ def import_agent(file: str, app_id: str) -> List[Agent | ExternalAgent | AssistantAgent]:
206
215
  agents = parse_file(file)
207
216
  for agent in agents:
208
217
  if app_id and agent.kind != AgentKind.NATIVE and agent.kind != AgentKind.ASSISTANT:
@@ -216,7 +225,9 @@ class AgentsController:
216
225
  ) -> Agent | ExternalAgent | AssistantAgent:
217
226
  match kind:
218
227
  case AgentKind.NATIVE:
219
- agent_details = parse_create_native_args(name, kind=kind, description=description, **kwargs)
228
+ agent_details = parse_create_native_args(
229
+ name, kind=kind, description=description, **kwargs
230
+ )
220
231
  agent = Agent.model_validate(agent_details)
221
232
  AgentsController().persist_record(agent=agent, **kwargs)
222
233
  case AgentKind.EXTERNAL:
@@ -296,12 +307,17 @@ class AgentsController:
296
307
  ref_agent.collaborators = ref_collaborators
297
308
 
298
309
  return ref_agent
299
-
310
+
300
311
  def dereference_tools(self, agent: Agent) -> Agent:
301
312
  tool_client = self.get_tool_client()
302
313
 
303
314
  deref_agent = deepcopy(agent)
304
- matching_tools = tool_client.get_drafts_by_names(deref_agent.tools)
315
+
316
+ # If agent has style set to "planner" and have join_tool defined, then we need to include that tool as well
317
+ if agent.style == AgentStyle.PLANNER and agent.custom_join_tool:
318
+ matching_tools = tool_client.get_drafts_by_names(deref_agent.tools + [deref_agent.custom_join_tool])
319
+ else:
320
+ matching_tools = tool_client.get_drafts_by_names(deref_agent.tools)
305
321
 
306
322
  name_id_lookup = {}
307
323
  for tool in matching_tools:
@@ -318,6 +334,13 @@ class AgentsController:
318
334
  sys.exit(1)
319
335
  deref_tools.append(id)
320
336
  deref_agent.tools = deref_tools
337
+
338
+ if agent.style == AgentStyle.PLANNER and agent.custom_join_tool:
339
+ join_tool_id = name_id_lookup.get(agent.custom_join_tool)
340
+ if not join_tool_id:
341
+ logger.error(f"Failed to find custom join tool. No tools found with the name '{agent.custom_join_tool}'")
342
+ sys.exit(1)
343
+ deref_agent.custom_join_tool = join_tool_id
321
344
 
322
345
  return deref_agent
323
346
 
@@ -325,7 +348,12 @@ class AgentsController:
325
348
  tool_client = self.get_tool_client()
326
349
 
327
350
  ref_agent = deepcopy(agent)
328
- matching_tools = tool_client.get_drafts_by_ids(ref_agent.tools)
351
+
352
+ # If agent has style set to "planner" and have join_tool defined, then we need to include that tool as well
353
+ if agent.style == AgentStyle.PLANNER and agent.custom_join_tool:
354
+ matching_tools = tool_client.get_drafts_by_ids(ref_agent.tools + [ref_agent.custom_join_tool])
355
+ else:
356
+ matching_tools = tool_client.get_drafts_by_ids(ref_agent.tools)
329
357
 
330
358
  id_name_lookup = {}
331
359
  for tool in matching_tools:
@@ -342,6 +370,13 @@ class AgentsController:
342
370
  sys.exit(1)
343
371
  ref_tools.append(name)
344
372
  ref_agent.tools = ref_tools
373
+
374
+ if agent.style == AgentStyle.PLANNER and agent.custom_join_tool:
375
+ join_tool_name = id_name_lookup.get(agent.custom_join_tool)
376
+ if not join_tool_name:
377
+ logger.error(f"Failed to find custom join tool. No tools found with the id '{agent.custom_join_tool}'")
378
+ sys.exit(1)
379
+ ref_agent.custom_join_tool = join_tool_name
345
380
 
346
381
  return ref_agent
347
382
 
@@ -409,7 +444,7 @@ class AgentsController:
409
444
  def dereference_native_agent_dependencies(self, agent: Agent) -> Agent:
410
445
  if agent.collaborators and len(agent.collaborators):
411
446
  agent = self.dereference_collaborators(agent)
412
- if agent.tools and len(agent.tools):
447
+ if (agent.tools and len(agent.tools)) or (agent.style == AgentStyle.PLANNER and agent.custom_join_tool):
413
448
  agent = self.dereference_tools(agent)
414
449
  if agent.knowledge_base and len(agent.knowledge_base):
415
450
  agent = self.dereference_knowledge_bases(agent)
@@ -419,7 +454,7 @@ class AgentsController:
419
454
  def reference_native_agent_dependencies(self, agent: Agent) -> Agent:
420
455
  if agent.collaborators and len(agent.collaborators):
421
456
  agent = self.reference_collaborators(agent)
422
- if agent.tools and len(agent.tools):
457
+ if (agent.tools and len(agent.tools)) or (agent.style == AgentStyle.PLANNER and agent.custom_join_tool):
423
458
  agent = self.reference_tools(agent)
424
459
  if agent.knowledge_base and len(agent.knowledge_base):
425
460
  agent = self.reference_knowledge_bases(agent)
@@ -443,21 +478,21 @@ class AgentsController:
443
478
  return agent
444
479
 
445
480
  # Convert all names used in an agent to the corresponding ids
446
- def dereference_agent_dependencies(self, agent: Agent | ExternalAgent | AssistantAgent ) -> Agent | ExternalAgent | AssistantAgent:
481
+ def dereference_agent_dependencies(self, agent: AnyAgentT) -> AnyAgentT:
447
482
  if isinstance(agent, Agent):
448
483
  return self.dereference_native_agent_dependencies(agent)
449
484
  if isinstance(agent, ExternalAgent) or isinstance(agent, AssistantAgent):
450
485
  return self.dereference_external_or_assistant_agent_dependencies(agent)
451
-
486
+
452
487
  # Convert all ids used in an agent to the corresponding names
453
- def reference_agent_dependencies(self, agent: Agent | ExternalAgent | AssistantAgent ) -> Agent | ExternalAgent | AssistantAgent:
488
+ def reference_agent_dependencies(self, agent: AnyAgentT) -> AnyAgentT:
454
489
  if isinstance(agent, Agent):
455
490
  return self.reference_native_agent_dependencies(agent)
456
491
  if isinstance(agent, ExternalAgent) or isinstance(agent, AssistantAgent):
457
492
  return self.reference_external_or_assistant_agent_dependencies(agent)
458
493
 
459
494
  def publish_or_update_agents(
460
- self, agents: Iterable[Agent]
495
+ self, agents: Iterable[Agent | ExternalAgent | AssistantAgent]
461
496
  ):
462
497
  for agent in agents:
463
498
  agent_name = agent.name
@@ -476,6 +511,18 @@ class AgentsController:
476
511
  all_existing_agents = existing_external_clients + existing_native_agents + existing_assistant_clients
477
512
  agent = self.dereference_agent_dependencies(agent)
478
513
 
514
+ if isinstance(agent, Agent) and agent.style == AgentStyle.PLANNER and isinstance(agent.custom_join_tool, str):
515
+ tool_client = self.get_tool_client()
516
+
517
+ join_tool_spec = ToolSpec.model_validate(
518
+ tool_client.get_draft_by_id(agent.custom_join_tool)
519
+ )
520
+ if not join_tool_spec.is_custom_join_tool():
521
+ logger.error(
522
+ f"Tool '{join_tool_spec.name}' configured as the custom join tool is not a valid join tool. A custom join tool must be a Python tool with specific input and output schema."
523
+ )
524
+ sys.exit(1)
525
+
479
526
  agent_kind = agent.kind
480
527
 
481
528
  if len(all_existing_agents) > 1:
@@ -762,32 +809,32 @@ class AgentsController:
762
809
  rich.print(assistants_table)
763
810
 
764
811
  def remove_agent(self, name: str, kind: AgentKind):
765
- try:
766
- if kind == AgentKind.NATIVE:
767
- client = self.get_native_client()
768
- elif kind == AgentKind.EXTERNAL:
769
- client = self.get_external_client()
770
- elif kind == AgentKind.ASSISTANT:
771
- client = self.get_assistant_client()
772
- else:
773
- raise ValueError("'kind' must be 'native'")
774
-
775
- draft_agents = client.get_draft_by_name(name)
776
- if len(draft_agents) > 1:
777
- logger.error(f"Multiple '{kind}' agents found with name '{name}'. Failed to delete agent")
778
- sys.exit(1)
779
- if len(draft_agents) > 0:
780
- draft_agent = draft_agents[0]
781
- agent_id = draft_agent.get("id")
782
- client.delete(agent_id=agent_id)
783
-
784
- logger.info(f"Successfully removed agent {name}")
785
- else:
786
- logger.warning(f"No agent named '{name}' found")
787
- except requests.HTTPError as e:
788
- logger.error(e.response.text)
789
- exit(1)
790
-
812
+ try:
813
+ if kind == AgentKind.NATIVE:
814
+ client = self.get_native_client()
815
+ elif kind == AgentKind.EXTERNAL:
816
+ client = self.get_external_client()
817
+ elif kind == AgentKind.ASSISTANT:
818
+ client = self.get_assistant_client()
819
+ else:
820
+ raise ValueError("'kind' must be 'native'")
821
+
822
+ draft_agents = client.get_draft_by_name(name)
823
+ if len(draft_agents) > 1:
824
+ logger.error(f"Multiple '{kind}' agents found with name '{name}'. Failed to delete agent")
825
+ sys.exit(1)
826
+ if len(draft_agents) > 0:
827
+ draft_agent = draft_agents[0]
828
+ agent_id = draft_agent.get("id")
829
+ client.delete(agent_id=agent_id)
830
+
831
+ logger.info(f"Successfully removed agent {name}")
832
+ else:
833
+ logger.warning(f"No agent named '{name}' found")
834
+ except requests.HTTPError as e:
835
+ logger.error(e.response.text)
836
+ exit(1)
837
+
791
838
  def get_spec_file_content(self, agent: Agent | ExternalAgent | AssistantAgent):
792
839
  ref_agent = self.reference_agent_dependencies(agent)
793
840
  agent_spec = ref_agent.model_dump(mode='json', exclude_none=True)
@@ -810,7 +857,7 @@ class AgentsController:
810
857
 
811
858
  return agent
812
859
 
813
- def get_agent_by_id(self, id: str) -> Agent | ExternalAgent | AssistantAgent:
860
+ def get_agent_by_id(self, id: str) -> Agent | ExternalAgent | AssistantAgent | None:
814
861
  native_client = self.get_native_client()
815
862
  external_client = self.get_external_client()
816
863
  assistant_client = self.get_assistant_client()