ibm-watsonx-orchestrate 1.5.0b0__py3-none-any.whl → 1.5.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ibm_watsonx_orchestrate/__init__.py +1 -2
- ibm_watsonx_orchestrate/agent_builder/agents/types.py +10 -1
- ibm_watsonx_orchestrate/agent_builder/knowledge_bases/types.py +13 -0
- ibm_watsonx_orchestrate/agent_builder/model_policies/__init__.py +1 -0
- ibm_watsonx_orchestrate/{client → agent_builder}/model_policies/types.py +7 -8
- ibm_watsonx_orchestrate/agent_builder/models/__init__.py +1 -0
- ibm_watsonx_orchestrate/{client → agent_builder}/models/types.py +43 -7
- ibm_watsonx_orchestrate/agent_builder/tools/python_tool.py +46 -3
- ibm_watsonx_orchestrate/agent_builder/tools/types.py +47 -1
- ibm_watsonx_orchestrate/cli/commands/agents/agents_command.py +17 -0
- ibm_watsonx_orchestrate/cli/commands/agents/agents_controller.py +86 -39
- ibm_watsonx_orchestrate/cli/commands/models/model_provider_mapper.py +191 -0
- ibm_watsonx_orchestrate/cli/commands/models/models_command.py +136 -259
- ibm_watsonx_orchestrate/cli/commands/models/models_controller.py +437 -0
- ibm_watsonx_orchestrate/cli/commands/server/server_command.py +2 -1
- ibm_watsonx_orchestrate/cli/commands/tools/tools_controller.py +1 -1
- ibm_watsonx_orchestrate/client/connections/__init__.py +2 -1
- ibm_watsonx_orchestrate/client/connections/utils.py +30 -0
- ibm_watsonx_orchestrate/client/model_policies/model_policies_client.py +23 -4
- ibm_watsonx_orchestrate/client/models/models_client.py +23 -3
- ibm_watsonx_orchestrate/client/tools/tool_client.py +2 -1
- ibm_watsonx_orchestrate/docker/compose-lite.yml +2 -0
- ibm_watsonx_orchestrate/docker/default.env +10 -11
- {ibm_watsonx_orchestrate-1.5.0b0.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/METADATA +1 -1
- {ibm_watsonx_orchestrate-1.5.0b0.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/RECORD +28 -25
- ibm_watsonx_orchestrate/cli/commands/models/env_file_model_provider_mapper.py +0 -180
- {ibm_watsonx_orchestrate-1.5.0b0.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/WHEEL +0 -0
- {ibm_watsonx_orchestrate-1.5.0b0.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/entry_points.txt +0 -0
- {ibm_watsonx_orchestrate-1.5.0b0.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/licenses/LICENSE +0 -0
@@ -3,12 +3,14 @@ import yaml
|
|
3
3
|
from enum import Enum
|
4
4
|
from typing import List, Optional, Dict
|
5
5
|
from pydantic import BaseModel, model_validator, ConfigDict
|
6
|
-
from ibm_watsonx_orchestrate.agent_builder.tools import BaseTool
|
6
|
+
from ibm_watsonx_orchestrate.agent_builder.tools import BaseTool, PythonTool
|
7
7
|
from ibm_watsonx_orchestrate.agent_builder.knowledge_bases.types import KnowledgeBaseSpec
|
8
8
|
from ibm_watsonx_orchestrate.agent_builder.knowledge_bases.knowledge_base import KnowledgeBase
|
9
9
|
from pydantic import Field, AliasChoices
|
10
10
|
from typing import Annotated
|
11
11
|
|
12
|
+
from ibm_watsonx_orchestrate.agent_builder.tools.types import JsonSchemaObject
|
13
|
+
|
12
14
|
# TO-DO: this is just a placeholder. Will update this later to align with backend
|
13
15
|
DEFAULT_LLM = "watsonx/meta-llama/llama-3-1-70b-instruct"
|
14
16
|
|
@@ -78,6 +80,8 @@ class AgentSpec(BaseAgentSpec):
|
|
78
80
|
kind: AgentKind = AgentKind.NATIVE
|
79
81
|
llm: str = DEFAULT_LLM
|
80
82
|
style: AgentStyle = AgentStyle.DEFAULT
|
83
|
+
custom_join_tool: str | PythonTool | None = None
|
84
|
+
structured_output: Optional[JsonSchemaObject] = None
|
81
85
|
instructions: Annotated[Optional[str], Field(json_schema_extra={"min_length_str":1})] = None
|
82
86
|
collaborators: Optional[List[str]] | Optional[List['BaseAgentSpec']] = []
|
83
87
|
tools: Optional[List[str]] | Optional[List['BaseTool']] = []
|
@@ -117,6 +121,11 @@ def validate_agent_fields(values: dict) -> dict:
|
|
117
121
|
if collaborator == name:
|
118
122
|
raise ValueError(f"Circular reference detected. The agent '{name}' cannot contain itself as a collaborator")
|
119
123
|
|
124
|
+
if values.get("style") == AgentStyle.PLANNER:
|
125
|
+
if not values.get("custom_join_tool") and not values.get("structured_output"):
|
126
|
+
raise ValueError("Either 'custom_join_tool' or 'structured_output' must be provided for planner style agents.")
|
127
|
+
if values.get("custom_join_tool") and values.get("structured_output"):
|
128
|
+
raise ValueError("Only one of 'custom_join_tool' or 'structured_output' can be provided for planner style agents.")
|
120
129
|
|
121
130
|
return values
|
122
131
|
|
@@ -184,11 +184,24 @@ class ElasticSearchConnection(BaseModel):
|
|
184
184
|
result_filter: Optional[list] = None
|
185
185
|
field_mapping: Optional[FieldMapping] = None
|
186
186
|
|
187
|
+
class CustomSearchConnection(BaseModel):
|
188
|
+
"""
|
189
|
+
example:
|
190
|
+
{
|
191
|
+
"url": "https://customsearch.xxxx.us-east.codeengine.appdomain.cloud",
|
192
|
+
"filter": "...",
|
193
|
+
"metadata": {...}
|
194
|
+
}
|
195
|
+
"""
|
196
|
+
url: str
|
197
|
+
filter: Optional[str] = None
|
198
|
+
metadata: Optional[dict] = None
|
187
199
|
|
188
200
|
class IndexConnection(BaseModel):
|
189
201
|
connection_id: Optional[str] = None
|
190
202
|
milvus: Optional[MilvusConnection] = None
|
191
203
|
elastic_search: Optional[ElasticSearchConnection] = None
|
204
|
+
custom_search: Optional[CustomSearchConnection] = None
|
192
205
|
|
193
206
|
|
194
207
|
class ConversationalSearchConfig(BaseModel):
|
@@ -0,0 +1 @@
|
|
1
|
+
from .types import ModelPolicy, ModelPolicyInner, ModelPolicyRetry, ModelPolicyStrategy, ModelPolicyStrategyMode, ModelPolicyTarget
|
@@ -1,7 +1,7 @@
|
|
1
1
|
from enum import Enum
|
2
|
-
from typing import
|
2
|
+
from typing import List, Union
|
3
3
|
|
4
|
-
from pydantic import
|
4
|
+
from pydantic import BaseModel, ConfigDict
|
5
5
|
|
6
6
|
class ModelPolicyStrategyMode(str, Enum):
|
7
7
|
LOAD_BALANCED = "loadbalance"
|
@@ -17,14 +17,13 @@ class ModelPolicyRetry(BaseModel):
|
|
17
17
|
on_status_codes: List[int] = None
|
18
18
|
|
19
19
|
class ModelPolicyTarget(BaseModel):
|
20
|
-
|
21
|
-
|
22
|
-
|
20
|
+
weight: float = None
|
21
|
+
model_name: str = None
|
23
22
|
|
24
23
|
class ModelPolicyInner(BaseModel):
|
25
24
|
strategy: ModelPolicyStrategy = None
|
26
25
|
retry: ModelPolicyRetry = None
|
27
|
-
targets: List[Union[
|
26
|
+
targets: List[Union['ModelPolicyInner', ModelPolicyTarget]] = None
|
28
27
|
|
29
28
|
|
30
29
|
class ModelPolicy(BaseModel):
|
@@ -32,5 +31,5 @@ class ModelPolicy(BaseModel):
|
|
32
31
|
|
33
32
|
name: str
|
34
33
|
display_name: str
|
35
|
-
|
36
|
-
|
34
|
+
description: str
|
35
|
+
policy: ModelPolicyInner
|
@@ -0,0 +1 @@
|
|
1
|
+
from .types import VirtualModel
|
@@ -1,6 +1,6 @@
|
|
1
1
|
from typing import Optional, List, Dict, Any, Union
|
2
2
|
from enum import Enum
|
3
|
-
from pydantic import Field, BaseModel, ConfigDict
|
3
|
+
from pydantic import Field, BaseModel, ConfigDict, model_validator
|
4
4
|
|
5
5
|
class ModelProvider(str, Enum):
|
6
6
|
OPENAI = 'openai'
|
@@ -29,6 +29,10 @@ class ModelProvider(str, Enum):
|
|
29
29
|
|
30
30
|
def __repr__(self):
|
31
31
|
return self.value
|
32
|
+
|
33
|
+
@classmethod
|
34
|
+
def has_value(cls, value):
|
35
|
+
return value in cls._value2member_map_
|
32
36
|
|
33
37
|
class ModelType(str, Enum):
|
34
38
|
CHAT = 'chat'
|
@@ -42,13 +46,24 @@ class ModelType(str, Enum):
|
|
42
46
|
def __repr__(self):
|
43
47
|
return self.value
|
44
48
|
|
49
|
+
class ModelType(str, Enum):
|
50
|
+
CHAT = 'chat'
|
51
|
+
CHAT_VISION = 'chat_vision'
|
52
|
+
COMPLETION = 'completion'
|
53
|
+
EMBEDDING = 'embedding'
|
54
|
+
|
55
|
+
def __str__(self):
|
56
|
+
return self.value
|
57
|
+
|
58
|
+
def __repr__(self):
|
59
|
+
return self.value
|
45
60
|
|
46
61
|
class ProviderConfig(BaseModel):
|
47
62
|
# Required fields
|
48
63
|
provider: Optional[str]=''
|
49
64
|
|
50
65
|
|
51
|
-
api_key: Optional[str] =
|
66
|
+
api_key: Optional[str] = None
|
52
67
|
url_to_fetch: Optional[str] = Field(None, alias="urlToFetch")
|
53
68
|
|
54
69
|
# Misc
|
@@ -135,7 +150,7 @@ class ProviderConfig(BaseModel):
|
|
135
150
|
azure_extra_params: Optional[str] = Field(None, alias="azureExtraParams")
|
136
151
|
azure_foundry_url: Optional[str] = Field(None, alias="azureFoundryUrl")
|
137
152
|
|
138
|
-
strict_open_ai_compliance: Optional[bool] = Field(
|
153
|
+
strict_open_ai_compliance: Optional[bool] = Field(None, alias="strictOpenAiCompliance")
|
139
154
|
mistral_fim_completion: Optional[str] = Field(None, alias="mistralFimCompletion")
|
140
155
|
|
141
156
|
# Anthropic
|
@@ -152,6 +167,10 @@ class ProviderConfig(BaseModel):
|
|
152
167
|
watsonx_version: Optional[str] = Field(None, alias="watsonxVersion")
|
153
168
|
watsonx_space_id: Optional[str] = Field(None, alias="watsonxSpaceId")
|
154
169
|
watsonx_project_id: Optional[str] = Field(None, alias="watsonxProjectId")
|
170
|
+
watsonx_deployment_id: Optional[str] = Field(None, alias="watsonxDeploymentId")
|
171
|
+
watsonx_cpd_url:Optional[str] = Field(None, alias="watsonxCpdUrl")
|
172
|
+
watsonx_cpd_username:Optional[str] = Field(None, alias="watsonxCpdUsername")
|
173
|
+
watsonx_cpd_password:Optional[str] = Field(None, alias="watsonxCpdPassword")
|
155
174
|
|
156
175
|
model_config = {
|
157
176
|
"populate_by_name": True, # Replaces allow_population_by_field_name
|
@@ -159,19 +178,36 @@ class ProviderConfig(BaseModel):
|
|
159
178
|
"json_schema_extra": lambda schema: schema.get("properties", {}).pop("provider", None)
|
160
179
|
}
|
161
180
|
|
181
|
+
def update(self, new_config: "ProviderConfig") -> "ProviderConfig":
|
182
|
+
old_config_dict = dict(self)
|
183
|
+
new_config_dict = dict(new_config)
|
162
184
|
|
185
|
+
new_config_dict = {k:v for k, v in new_config_dict.items() if v is not None}
|
186
|
+
old_config_dict.update(new_config_dict)
|
163
187
|
|
164
|
-
|
188
|
+
return ProviderConfig.model_validate(old_config_dict)
|
189
|
+
|
190
|
+
|
191
|
+
class VirtualModel(BaseModel):
|
165
192
|
model_config = ConfigDict(extra='allow')
|
166
193
|
|
167
194
|
name: str
|
168
195
|
display_name: Optional[str]
|
169
196
|
description: Optional[str]
|
170
197
|
config: Optional[dict] = None
|
171
|
-
provider_config: ProviderConfig
|
172
|
-
tags: List[str]
|
198
|
+
provider_config: Optional[ProviderConfig] = None
|
199
|
+
tags: List[str] = []
|
173
200
|
model_type: str = ModelType.CHAT
|
201
|
+
connection_id: Optional[str] = None
|
174
202
|
|
203
|
+
@model_validator(mode="before")
|
204
|
+
def validate_fields(cls, values):
|
205
|
+
if not values.get("display_name"):
|
206
|
+
values["display_name"] = values.get("name")
|
207
|
+
if not values.get("description"):
|
208
|
+
values["description"] = values.get("name")
|
209
|
+
|
210
|
+
return values
|
175
211
|
|
176
212
|
class ListVirtualModel(BaseModel):
|
177
213
|
model_config = ConfigDict(extra='allow')
|
@@ -186,4 +222,4 @@ class ListVirtualModel(BaseModel):
|
|
186
222
|
tags: Optional[List[str]] = None
|
187
223
|
model_type: Optional[str] = None
|
188
224
|
|
189
|
-
ANTHROPIC_DEFAULT_MAX_TOKENS = 4096
|
225
|
+
ANTHROPIC_DEFAULT_MAX_TOKENS = 4096
|
@@ -2,7 +2,7 @@ import importlib
|
|
2
2
|
import inspect
|
3
3
|
import json
|
4
4
|
import os
|
5
|
-
from typing import Callable, List
|
5
|
+
from typing import Any, Callable, Dict, List, get_type_hints
|
6
6
|
import logging
|
7
7
|
|
8
8
|
import docstring_parser
|
@@ -13,12 +13,18 @@ from pydantic import TypeAdapter, BaseModel
|
|
13
13
|
from ibm_watsonx_orchestrate.utils.utils import yaml_safe_load
|
14
14
|
from ibm_watsonx_orchestrate.agent_builder.connections import ExpectedCredentials
|
15
15
|
from .base_tool import BaseTool
|
16
|
-
from .types import ToolSpec, ToolPermission, ToolRequestBody, ToolResponseBody, JsonSchemaObject, ToolBinding, \
|
16
|
+
from .types import PythonToolKind, ToolSpec, ToolPermission, ToolRequestBody, ToolResponseBody, JsonSchemaObject, ToolBinding, \
|
17
17
|
PythonToolBinding
|
18
18
|
|
19
19
|
_all_tools = []
|
20
20
|
logger = logging.getLogger(__name__)
|
21
21
|
|
22
|
+
JOIN_TOOL_PARAMS = {
|
23
|
+
'original_query': str,
|
24
|
+
'task_results': Dict[str, Any],
|
25
|
+
'messages': List[Dict[str, Any]],
|
26
|
+
}
|
27
|
+
|
22
28
|
class PythonTool(BaseTool):
|
23
29
|
def __init__(self, fn, spec: ToolSpec, expected_credentials: List[ExpectedCredentials]=None):
|
24
30
|
BaseTool.__init__(self, spec=spec)
|
@@ -92,6 +98,31 @@ def _validate_input_schema(input_schema: ToolRequestBody) -> None:
|
|
92
98
|
if not props.get(prop).type:
|
93
99
|
logger.warning(f"Missing type hint for tool property '{prop}' defaulting to 'str'. To remove this warning add a type hint to the property in the tools signature. See Python docs for guidance: https://docs.python.org/3/library/typing.html")
|
94
100
|
|
101
|
+
def _validate_join_tool_func(fn: Callable, sig: inspect.Signature | None = None, name: str | None = None) -> None:
|
102
|
+
if sig is None:
|
103
|
+
sig = inspect.signature(fn)
|
104
|
+
if name is None:
|
105
|
+
name = fn.__name__
|
106
|
+
|
107
|
+
params = sig.parameters
|
108
|
+
type_hints = get_type_hints(fn)
|
109
|
+
|
110
|
+
# Validate parameter order
|
111
|
+
actual_param_names = list(params.keys())
|
112
|
+
expected_param_names = list(JOIN_TOOL_PARAMS.keys())
|
113
|
+
if actual_param_names[:len(expected_param_names)] != expected_param_names:
|
114
|
+
raise ValueError(
|
115
|
+
f"Join tool function '{name}' has incorrect parameter names or order. Expected: {expected_param_names}, got: {actual_param_names}"
|
116
|
+
)
|
117
|
+
|
118
|
+
# Validate the type hints
|
119
|
+
for param, expected_type in JOIN_TOOL_PARAMS.items():
|
120
|
+
if param not in type_hints:
|
121
|
+
raise ValueError(f"Join tool function '{name}' is missing type for parameter '{param}'")
|
122
|
+
actual_type = type_hints[param]
|
123
|
+
if actual_type != expected_type:
|
124
|
+
raise ValueError(f"Join tool function '{name}' has incorrect type for parameter '{param}'. Expected {expected_type}, got {actual_type}")
|
125
|
+
|
95
126
|
def tool(
|
96
127
|
*args,
|
97
128
|
name: str = None,
|
@@ -100,7 +131,8 @@ def tool(
|
|
100
131
|
output_schema: ToolResponseBody = None,
|
101
132
|
permission: ToolPermission = ToolPermission.READ_ONLY,
|
102
133
|
expected_credentials: List[ExpectedCredentials] = None,
|
103
|
-
display_name: str = None
|
134
|
+
display_name: str = None,
|
135
|
+
kind: PythonToolKind = PythonToolKind.TOOL,
|
104
136
|
) -> Callable[[{__name__, __doc__}], PythonTool]:
|
105
137
|
"""
|
106
138
|
Decorator to convert a python function into a callable tool.
|
@@ -152,6 +184,11 @@ def tool(
|
|
152
184
|
spec.binding.python.function = function_binding
|
153
185
|
|
154
186
|
sig = inspect.signature(fn)
|
187
|
+
|
188
|
+
# If the function is a join tool, validate its signature matches the expected parameters. If not, raise error with details.
|
189
|
+
if kind == PythonToolKind.JOIN_TOOL:
|
190
|
+
_validate_join_tool_func(fn, sig, spec.name)
|
191
|
+
|
155
192
|
if not input_schema:
|
156
193
|
try:
|
157
194
|
input_schema_model: type[BaseModel] = create_schema_from_function(spec.name, fn, parse_docstring=True)
|
@@ -190,6 +227,12 @@ def tool(
|
|
190
227
|
|
191
228
|
else:
|
192
229
|
spec.output_schema = ToolResponseBody()
|
230
|
+
|
231
|
+
# Validate the generated schema still conforms to the requirement for a join tool
|
232
|
+
if kind == PythonToolKind.JOIN_TOOL:
|
233
|
+
if not spec.is_custom_join_tool():
|
234
|
+
raise ValueError(f"Join tool '{spec.name}' does not conform to the expected join tool schema. Please ensure the input schema has the required fields: {JOIN_TOOL_PARAMS.keys()} and the output schema is a string.")
|
235
|
+
|
193
236
|
_all_tools.append(t)
|
194
237
|
return t
|
195
238
|
|
@@ -10,6 +10,9 @@ class ToolPermission(str, Enum):
|
|
10
10
|
READ_WRITE = 'read_write'
|
11
11
|
ADMIN = 'admin'
|
12
12
|
|
13
|
+
class PythonToolKind(str, Enum):
|
14
|
+
JOIN_TOOL = 'join_tool'
|
15
|
+
TOOL = 'tool'
|
13
16
|
|
14
17
|
class JsonSchemaObject(BaseModel):
|
15
18
|
model_config = ConfigDict(extra='allow')
|
@@ -163,7 +166,6 @@ class ToolBinding(BaseModel):
|
|
163
166
|
raise ValueError("Only one binding can be set")
|
164
167
|
return self
|
165
168
|
|
166
|
-
|
167
169
|
class ToolSpec(BaseModel):
|
168
170
|
name: str
|
169
171
|
display_name: str | None = None
|
@@ -174,3 +176,47 @@ class ToolSpec(BaseModel):
|
|
174
176
|
binding: ToolBinding = None
|
175
177
|
toolkit_id: str | None = None
|
176
178
|
|
179
|
+
def is_custom_join_tool(self) -> bool:
|
180
|
+
if self.binding.python is None:
|
181
|
+
return False
|
182
|
+
|
183
|
+
# The code below validates the input schema to have the following structure:
|
184
|
+
# {
|
185
|
+
# "type": "object",
|
186
|
+
# "properties": {
|
187
|
+
# "messages": {
|
188
|
+
# "type": "array",
|
189
|
+
# "items": {
|
190
|
+
# "type": "object",
|
191
|
+
# },
|
192
|
+
# },
|
193
|
+
# "task_results": {
|
194
|
+
# "type": "object",
|
195
|
+
# },
|
196
|
+
# "original_query": {
|
197
|
+
# "type": "string",
|
198
|
+
# },
|
199
|
+
# },
|
200
|
+
# "required": {"original_query", "task_results", "messages"},
|
201
|
+
# }
|
202
|
+
|
203
|
+
input_schema = self.input_schema
|
204
|
+
if input_schema.type != 'object':
|
205
|
+
return False
|
206
|
+
|
207
|
+
required_fields = {"original_query", "task_results", "messages"}
|
208
|
+
if input_schema.required is None or set(input_schema.required) != required_fields:
|
209
|
+
return False
|
210
|
+
if input_schema.properties is None or set(input_schema.properties.keys()) != required_fields:
|
211
|
+
return False
|
212
|
+
|
213
|
+
if input_schema.properties["messages"].type != "array":
|
214
|
+
return False
|
215
|
+
if not input_schema.properties["messages"].items or input_schema.properties["messages"].items.type != "object":
|
216
|
+
return False
|
217
|
+
if input_schema.properties["task_results"].type != "object":
|
218
|
+
return False
|
219
|
+
if input_schema.properties["original_query"].type != "string":
|
220
|
+
return False
|
221
|
+
|
222
|
+
return True
|
@@ -105,6 +105,20 @@ def agent_create(
|
|
105
105
|
AgentStyle,
|
106
106
|
typer.Option("--style", help="The style of agent you wish to create"),
|
107
107
|
] = AgentStyle.DEFAULT,
|
108
|
+
custom_join_tool: Annotated[
|
109
|
+
str | None,
|
110
|
+
typer.Option(
|
111
|
+
"--custom-join-tool",
|
112
|
+
help='The name of the python tool to be used by the agent to format and generate the final output. Only needed for "planner" style agents.',
|
113
|
+
),
|
114
|
+
] = None,
|
115
|
+
structured_output: Annotated[
|
116
|
+
str | None,
|
117
|
+
typer.Option(
|
118
|
+
"--structured-output",
|
119
|
+
help='A JSON Schema object that defines the desired structure of the agent\'s final output. Only needed for "planner" style agents.',
|
120
|
+
),
|
121
|
+
] = None,
|
108
122
|
collaborators: Annotated[
|
109
123
|
List[str],
|
110
124
|
typer.Option(
|
@@ -138,6 +152,7 @@ def agent_create(
|
|
138
152
|
chat_params_dict = json.loads(chat_params) if chat_params else {}
|
139
153
|
config_dict = json.loads(config) if config else {}
|
140
154
|
auth_config_dict = json.loads(auth_config) if auth_config else {}
|
155
|
+
structured_output_dict = json.loads(structured_output) if structured_output else None
|
141
156
|
|
142
157
|
agents_controller = AgentsController()
|
143
158
|
agent = agents_controller.generate_agent_spec(
|
@@ -151,6 +166,8 @@ def agent_create(
|
|
151
166
|
provider=provider,
|
152
167
|
llm=llm,
|
153
168
|
style=style,
|
169
|
+
custom_join_tool=custom_join_tool,
|
170
|
+
structured_output=structured_output_dict,
|
154
171
|
collaborators=collaborators,
|
155
172
|
tools=tools,
|
156
173
|
knowledge_base=knowledge_base,
|
@@ -11,9 +11,12 @@ import logging
|
|
11
11
|
from pathlib import Path
|
12
12
|
from copy import deepcopy
|
13
13
|
|
14
|
-
from typing import Iterable, List
|
14
|
+
from typing import Iterable, List, TypeVar
|
15
|
+
from ibm_watsonx_orchestrate.agent_builder.agents.types import AgentStyle
|
16
|
+
from ibm_watsonx_orchestrate.agent_builder.tools.types import ToolSpec
|
15
17
|
from ibm_watsonx_orchestrate.cli.commands.tools.tools_controller import import_python_tool, ToolsController
|
16
18
|
from ibm_watsonx_orchestrate.cli.commands.knowledge_bases.knowledge_bases_controller import import_python_knowledge_base
|
19
|
+
from ibm_watsonx_orchestrate.cli.commands.models.models_controller import import_python_model
|
17
20
|
|
18
21
|
from ibm_watsonx_orchestrate.agent_builder.agents import (
|
19
22
|
Agent,
|
@@ -34,10 +37,14 @@ from ibm_watsonx_orchestrate.utils.utils import check_file_in_zip
|
|
34
37
|
|
35
38
|
logger = logging.getLogger(__name__)
|
36
39
|
|
40
|
+
# Helper generic type for any agent
|
41
|
+
AnyAgentT = TypeVar("AnyAgentT", bound=Agent | ExternalAgent | AssistantAgent)
|
42
|
+
|
37
43
|
def import_python_agent(file: str) -> List[Agent | ExternalAgent | AssistantAgent]:
|
38
44
|
# Import tools
|
39
45
|
import_python_tool(file)
|
40
46
|
import_python_knowledge_base(file)
|
47
|
+
import_python_model(file)
|
41
48
|
|
42
49
|
file_path = Path(file)
|
43
50
|
file_directory = file_path.parent
|
@@ -90,6 +97,8 @@ def parse_create_native_args(name: str, kind: AgentKind, description: str | None
|
|
90
97
|
"description": description,
|
91
98
|
"llm": args.get("llm"),
|
92
99
|
"style": args.get("style"),
|
100
|
+
"custom_join_tool": args.get("custom_join_tool"),
|
101
|
+
"structured_output": args.get("structured_output"),
|
93
102
|
}
|
94
103
|
|
95
104
|
collaborators = args.get("collaborators", [])
|
@@ -202,7 +211,7 @@ class AgentsController:
|
|
202
211
|
return self.knowledge_base_client
|
203
212
|
|
204
213
|
@staticmethod
|
205
|
-
def import_agent(file: str, app_id: str) ->
|
214
|
+
def import_agent(file: str, app_id: str) -> List[Agent | ExternalAgent | AssistantAgent]:
|
206
215
|
agents = parse_file(file)
|
207
216
|
for agent in agents:
|
208
217
|
if app_id and agent.kind != AgentKind.NATIVE and agent.kind != AgentKind.ASSISTANT:
|
@@ -216,7 +225,9 @@ class AgentsController:
|
|
216
225
|
) -> Agent | ExternalAgent | AssistantAgent:
|
217
226
|
match kind:
|
218
227
|
case AgentKind.NATIVE:
|
219
|
-
agent_details = parse_create_native_args(
|
228
|
+
agent_details = parse_create_native_args(
|
229
|
+
name, kind=kind, description=description, **kwargs
|
230
|
+
)
|
220
231
|
agent = Agent.model_validate(agent_details)
|
221
232
|
AgentsController().persist_record(agent=agent, **kwargs)
|
222
233
|
case AgentKind.EXTERNAL:
|
@@ -296,12 +307,17 @@ class AgentsController:
|
|
296
307
|
ref_agent.collaborators = ref_collaborators
|
297
308
|
|
298
309
|
return ref_agent
|
299
|
-
|
310
|
+
|
300
311
|
def dereference_tools(self, agent: Agent) -> Agent:
|
301
312
|
tool_client = self.get_tool_client()
|
302
313
|
|
303
314
|
deref_agent = deepcopy(agent)
|
304
|
-
|
315
|
+
|
316
|
+
# If agent has style set to "planner" and have join_tool defined, then we need to include that tool as well
|
317
|
+
if agent.style == AgentStyle.PLANNER and agent.custom_join_tool:
|
318
|
+
matching_tools = tool_client.get_drafts_by_names(deref_agent.tools + [deref_agent.custom_join_tool])
|
319
|
+
else:
|
320
|
+
matching_tools = tool_client.get_drafts_by_names(deref_agent.tools)
|
305
321
|
|
306
322
|
name_id_lookup = {}
|
307
323
|
for tool in matching_tools:
|
@@ -318,6 +334,13 @@ class AgentsController:
|
|
318
334
|
sys.exit(1)
|
319
335
|
deref_tools.append(id)
|
320
336
|
deref_agent.tools = deref_tools
|
337
|
+
|
338
|
+
if agent.style == AgentStyle.PLANNER and agent.custom_join_tool:
|
339
|
+
join_tool_id = name_id_lookup.get(agent.custom_join_tool)
|
340
|
+
if not join_tool_id:
|
341
|
+
logger.error(f"Failed to find custom join tool. No tools found with the name '{agent.custom_join_tool}'")
|
342
|
+
sys.exit(1)
|
343
|
+
deref_agent.custom_join_tool = join_tool_id
|
321
344
|
|
322
345
|
return deref_agent
|
323
346
|
|
@@ -325,7 +348,12 @@ class AgentsController:
|
|
325
348
|
tool_client = self.get_tool_client()
|
326
349
|
|
327
350
|
ref_agent = deepcopy(agent)
|
328
|
-
|
351
|
+
|
352
|
+
# If agent has style set to "planner" and have join_tool defined, then we need to include that tool as well
|
353
|
+
if agent.style == AgentStyle.PLANNER and agent.custom_join_tool:
|
354
|
+
matching_tools = tool_client.get_drafts_by_ids(ref_agent.tools + [ref_agent.custom_join_tool])
|
355
|
+
else:
|
356
|
+
matching_tools = tool_client.get_drafts_by_ids(ref_agent.tools)
|
329
357
|
|
330
358
|
id_name_lookup = {}
|
331
359
|
for tool in matching_tools:
|
@@ -342,6 +370,13 @@ class AgentsController:
|
|
342
370
|
sys.exit(1)
|
343
371
|
ref_tools.append(name)
|
344
372
|
ref_agent.tools = ref_tools
|
373
|
+
|
374
|
+
if agent.style == AgentStyle.PLANNER and agent.custom_join_tool:
|
375
|
+
join_tool_name = id_name_lookup.get(agent.custom_join_tool)
|
376
|
+
if not join_tool_name:
|
377
|
+
logger.error(f"Failed to find custom join tool. No tools found with the id '{agent.custom_join_tool}'")
|
378
|
+
sys.exit(1)
|
379
|
+
ref_agent.custom_join_tool = join_tool_name
|
345
380
|
|
346
381
|
return ref_agent
|
347
382
|
|
@@ -409,7 +444,7 @@ class AgentsController:
|
|
409
444
|
def dereference_native_agent_dependencies(self, agent: Agent) -> Agent:
|
410
445
|
if agent.collaborators and len(agent.collaborators):
|
411
446
|
agent = self.dereference_collaborators(agent)
|
412
|
-
if agent.tools and len(agent.tools):
|
447
|
+
if (agent.tools and len(agent.tools)) or (agent.style == AgentStyle.PLANNER and agent.custom_join_tool):
|
413
448
|
agent = self.dereference_tools(agent)
|
414
449
|
if agent.knowledge_base and len(agent.knowledge_base):
|
415
450
|
agent = self.dereference_knowledge_bases(agent)
|
@@ -419,7 +454,7 @@ class AgentsController:
|
|
419
454
|
def reference_native_agent_dependencies(self, agent: Agent) -> Agent:
|
420
455
|
if agent.collaborators and len(agent.collaborators):
|
421
456
|
agent = self.reference_collaborators(agent)
|
422
|
-
if agent.tools and len(agent.tools):
|
457
|
+
if (agent.tools and len(agent.tools)) or (agent.style == AgentStyle.PLANNER and agent.custom_join_tool):
|
423
458
|
agent = self.reference_tools(agent)
|
424
459
|
if agent.knowledge_base and len(agent.knowledge_base):
|
425
460
|
agent = self.reference_knowledge_bases(agent)
|
@@ -443,21 +478,21 @@ class AgentsController:
|
|
443
478
|
return agent
|
444
479
|
|
445
480
|
# Convert all names used in an agent to the corresponding ids
|
446
|
-
def dereference_agent_dependencies(self, agent:
|
481
|
+
def dereference_agent_dependencies(self, agent: AnyAgentT) -> AnyAgentT:
|
447
482
|
if isinstance(agent, Agent):
|
448
483
|
return self.dereference_native_agent_dependencies(agent)
|
449
484
|
if isinstance(agent, ExternalAgent) or isinstance(agent, AssistantAgent):
|
450
485
|
return self.dereference_external_or_assistant_agent_dependencies(agent)
|
451
|
-
|
486
|
+
|
452
487
|
# Convert all ids used in an agent to the corresponding names
|
453
|
-
def reference_agent_dependencies(self, agent:
|
488
|
+
def reference_agent_dependencies(self, agent: AnyAgentT) -> AnyAgentT:
|
454
489
|
if isinstance(agent, Agent):
|
455
490
|
return self.reference_native_agent_dependencies(agent)
|
456
491
|
if isinstance(agent, ExternalAgent) or isinstance(agent, AssistantAgent):
|
457
492
|
return self.reference_external_or_assistant_agent_dependencies(agent)
|
458
493
|
|
459
494
|
def publish_or_update_agents(
|
460
|
-
self, agents: Iterable[Agent]
|
495
|
+
self, agents: Iterable[Agent | ExternalAgent | AssistantAgent]
|
461
496
|
):
|
462
497
|
for agent in agents:
|
463
498
|
agent_name = agent.name
|
@@ -476,6 +511,18 @@ class AgentsController:
|
|
476
511
|
all_existing_agents = existing_external_clients + existing_native_agents + existing_assistant_clients
|
477
512
|
agent = self.dereference_agent_dependencies(agent)
|
478
513
|
|
514
|
+
if isinstance(agent, Agent) and agent.style == AgentStyle.PLANNER and isinstance(agent.custom_join_tool, str):
|
515
|
+
tool_client = self.get_tool_client()
|
516
|
+
|
517
|
+
join_tool_spec = ToolSpec.model_validate(
|
518
|
+
tool_client.get_draft_by_id(agent.custom_join_tool)
|
519
|
+
)
|
520
|
+
if not join_tool_spec.is_custom_join_tool():
|
521
|
+
logger.error(
|
522
|
+
f"Tool '{join_tool_spec.name}' configured as the custom join tool is not a valid join tool. A custom join tool must be a Python tool with specific input and output schema."
|
523
|
+
)
|
524
|
+
sys.exit(1)
|
525
|
+
|
479
526
|
agent_kind = agent.kind
|
480
527
|
|
481
528
|
if len(all_existing_agents) > 1:
|
@@ -762,32 +809,32 @@ class AgentsController:
|
|
762
809
|
rich.print(assistants_table)
|
763
810
|
|
764
811
|
def remove_agent(self, name: str, kind: AgentKind):
|
765
|
-
|
766
|
-
|
767
|
-
|
768
|
-
|
769
|
-
|
770
|
-
|
771
|
-
|
772
|
-
|
773
|
-
|
774
|
-
|
775
|
-
|
776
|
-
|
777
|
-
|
778
|
-
|
779
|
-
|
780
|
-
|
781
|
-
|
782
|
-
|
783
|
-
|
784
|
-
|
785
|
-
|
786
|
-
|
787
|
-
|
788
|
-
|
789
|
-
|
790
|
-
|
812
|
+
try:
|
813
|
+
if kind == AgentKind.NATIVE:
|
814
|
+
client = self.get_native_client()
|
815
|
+
elif kind == AgentKind.EXTERNAL:
|
816
|
+
client = self.get_external_client()
|
817
|
+
elif kind == AgentKind.ASSISTANT:
|
818
|
+
client = self.get_assistant_client()
|
819
|
+
else:
|
820
|
+
raise ValueError("'kind' must be 'native'")
|
821
|
+
|
822
|
+
draft_agents = client.get_draft_by_name(name)
|
823
|
+
if len(draft_agents) > 1:
|
824
|
+
logger.error(f"Multiple '{kind}' agents found with name '{name}'. Failed to delete agent")
|
825
|
+
sys.exit(1)
|
826
|
+
if len(draft_agents) > 0:
|
827
|
+
draft_agent = draft_agents[0]
|
828
|
+
agent_id = draft_agent.get("id")
|
829
|
+
client.delete(agent_id=agent_id)
|
830
|
+
|
831
|
+
logger.info(f"Successfully removed agent {name}")
|
832
|
+
else:
|
833
|
+
logger.warning(f"No agent named '{name}' found")
|
834
|
+
except requests.HTTPError as e:
|
835
|
+
logger.error(e.response.text)
|
836
|
+
exit(1)
|
837
|
+
|
791
838
|
def get_spec_file_content(self, agent: Agent | ExternalAgent | AssistantAgent):
|
792
839
|
ref_agent = self.reference_agent_dependencies(agent)
|
793
840
|
agent_spec = ref_agent.model_dump(mode='json', exclude_none=True)
|
@@ -810,7 +857,7 @@ class AgentsController:
|
|
810
857
|
|
811
858
|
return agent
|
812
859
|
|
813
|
-
def get_agent_by_id(self, id: str) -> Agent | ExternalAgent | AssistantAgent:
|
860
|
+
def get_agent_by_id(self, id: str) -> Agent | ExternalAgent | AssistantAgent | None:
|
814
861
|
native_client = self.get_native_client()
|
815
862
|
external_client = self.get_external_client()
|
816
863
|
assistant_client = self.get_assistant_client()
|