ibm-watsonx-orchestrate 1.6.1__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ibm_watsonx_orchestrate/__init__.py +2 -1
- ibm_watsonx_orchestrate/agent_builder/agents/agent.py +3 -3
- ibm_watsonx_orchestrate/agent_builder/agents/assistant_agent.py +3 -2
- ibm_watsonx_orchestrate/agent_builder/agents/external_agent.py +3 -2
- ibm_watsonx_orchestrate/agent_builder/agents/types.py +9 -8
- ibm_watsonx_orchestrate/agent_builder/connections/connections.py +4 -3
- ibm_watsonx_orchestrate/agent_builder/knowledge_bases/knowledge_base_requests.py +1 -22
- ibm_watsonx_orchestrate/agent_builder/knowledge_bases/types.py +1 -17
- ibm_watsonx_orchestrate/agent_builder/tools/base_tool.py +2 -1
- ibm_watsonx_orchestrate/agent_builder/tools/openapi_tool.py +14 -13
- ibm_watsonx_orchestrate/agent_builder/tools/python_tool.py +136 -92
- ibm_watsonx_orchestrate/agent_builder/tools/types.py +10 -9
- ibm_watsonx_orchestrate/cli/commands/agents/agents_command.py +7 -7
- ibm_watsonx_orchestrate/cli/commands/agents/agents_controller.py +4 -3
- ibm_watsonx_orchestrate/cli/commands/environment/environment_controller.py +5 -5
- ibm_watsonx_orchestrate/cli/commands/environment/types.py +2 -0
- ibm_watsonx_orchestrate/cli/commands/knowledge_bases/knowledge_bases_command.py +0 -18
- ibm_watsonx_orchestrate/cli/commands/knowledge_bases/knowledge_bases_controller.py +33 -19
- ibm_watsonx_orchestrate/cli/commands/models/models_command.py +1 -1
- ibm_watsonx_orchestrate/cli/commands/server/server_command.py +100 -36
- ibm_watsonx_orchestrate/cli/commands/toolkit/toolkit_command.py +1 -1
- ibm_watsonx_orchestrate/cli/commands/tools/tools_controller.py +11 -4
- ibm_watsonx_orchestrate/cli/config.py +3 -3
- ibm_watsonx_orchestrate/cli/init_helper.py +10 -1
- ibm_watsonx_orchestrate/cli/main.py +2 -0
- ibm_watsonx_orchestrate/client/knowledge_bases/knowledge_base_client.py +1 -1
- ibm_watsonx_orchestrate/client/local_service_instance.py +3 -1
- ibm_watsonx_orchestrate/client/service_instance.py +33 -7
- ibm_watsonx_orchestrate/docker/compose-lite.yml +177 -2
- ibm_watsonx_orchestrate/docker/default.env +22 -2
- ibm_watsonx_orchestrate/flow_builder/flows/__init__.py +3 -1
- ibm_watsonx_orchestrate/flow_builder/flows/decorators.py +4 -2
- ibm_watsonx_orchestrate/flow_builder/flows/events.py +10 -9
- ibm_watsonx_orchestrate/flow_builder/flows/flow.py +91 -20
- ibm_watsonx_orchestrate/flow_builder/node.py +12 -1
- ibm_watsonx_orchestrate/flow_builder/types.py +169 -16
- ibm_watsonx_orchestrate/flow_builder/utils.py +120 -5
- ibm_watsonx_orchestrate/utils/exceptions.py +23 -0
- {ibm_watsonx_orchestrate-1.6.1.dist-info → ibm_watsonx_orchestrate-1.7.0.dist-info}/METADATA +4 -4
- {ibm_watsonx_orchestrate-1.6.1.dist-info → ibm_watsonx_orchestrate-1.7.0.dist-info}/RECORD +43 -43
- ibm_watsonx_orchestrate/flow_builder/resources/flow_status.openapi.yml +0 -66
- {ibm_watsonx_orchestrate-1.6.1.dist-info → ibm_watsonx_orchestrate-1.7.0.dist-info}/WHEEL +0 -0
- {ibm_watsonx_orchestrate-1.6.1.dist-info → ibm_watsonx_orchestrate-1.7.0.dist-info}/entry_points.txt +0 -0
- {ibm_watsonx_orchestrate-1.6.1.dist-info → ibm_watsonx_orchestrate-1.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,6 +1,7 @@
|
|
1
1
|
import json
|
2
2
|
from ibm_watsonx_orchestrate.utils.utils import yaml_safe_load
|
3
3
|
from .types import AgentSpec
|
4
|
+
from ibm_watsonx_orchestrate.utils.exceptions import BadRequest
|
4
5
|
|
5
6
|
|
6
7
|
class Agent(AgentSpec):
|
@@ -13,10 +14,9 @@ class Agent(AgentSpec):
|
|
13
14
|
elif file.endswith('.json'):
|
14
15
|
content = json.load(f)
|
15
16
|
else:
|
16
|
-
raise
|
17
|
-
|
17
|
+
raise BadRequest('file must end in .json, .yaml, or .yml')
|
18
18
|
if not content.get("spec_version"):
|
19
|
-
raise
|
19
|
+
raise BadRequest(f"Field 'spec_version' not provided. Please ensure provided spec conforms to a valid spec format")
|
20
20
|
agent = Agent.model_validate(content)
|
21
21
|
|
22
22
|
return agent
|
@@ -1,6 +1,7 @@
|
|
1
1
|
import json
|
2
2
|
from ibm_watsonx_orchestrate.utils.utils import yaml_safe_load
|
3
3
|
from .types import AssistantAgentSpec
|
4
|
+
from ibm_watsonx_orchestrate.utils.exceptions import BadRequest
|
4
5
|
|
5
6
|
|
6
7
|
class AssistantAgent(AssistantAgentSpec):
|
@@ -13,10 +14,10 @@ class AssistantAgent(AssistantAgentSpec):
|
|
13
14
|
elif file.endswith('.json'):
|
14
15
|
content = json.load(f)
|
15
16
|
else:
|
16
|
-
raise
|
17
|
+
raise BadRequest('file must end in .json, .yaml, or .yml')
|
17
18
|
|
18
19
|
if not content.get("spec_version"):
|
19
|
-
raise
|
20
|
+
raise BadRequest(f"Field 'spec_version' not provided. Please ensure provided spec conforms to a valid spec format")
|
20
21
|
agent = AssistantAgent.model_validate(content)
|
21
22
|
|
22
23
|
return agent
|
@@ -1,6 +1,7 @@
|
|
1
1
|
import json
|
2
2
|
from ibm_watsonx_orchestrate.utils.utils import yaml_safe_load
|
3
3
|
from .types import ExternalAgentSpec
|
4
|
+
from ibm_watsonx_orchestrate.utils.exceptions import BadRequest
|
4
5
|
|
5
6
|
|
6
7
|
class ExternalAgent(ExternalAgentSpec):
|
@@ -13,10 +14,10 @@ class ExternalAgent(ExternalAgentSpec):
|
|
13
14
|
elif file.endswith('.json'):
|
14
15
|
content = json.load(f)
|
15
16
|
else:
|
16
|
-
raise
|
17
|
+
raise BadRequest('file must end in .json, .yaml, or .yml')
|
17
18
|
|
18
19
|
if not content.get("spec_version"):
|
19
|
-
raise
|
20
|
+
raise BadRequest(f"Field 'spec_version' not provided. Please ensure provided spec conforms to a valid spec format")
|
20
21
|
agent = ExternalAgent.model_validate(content)
|
21
22
|
|
22
23
|
return agent
|
@@ -9,6 +9,7 @@ from ibm_watsonx_orchestrate.agent_builder.knowledge_bases.knowledge_base import
|
|
9
9
|
from ibm_watsonx_orchestrate.agent_builder.agents.webchat_customizations import StarterPrompts, WelcomeContent
|
10
10
|
from pydantic import Field, AliasChoices
|
11
11
|
from typing import Annotated
|
12
|
+
from ibm_watsonx_orchestrate.utils.exceptions import BadRequest
|
12
13
|
|
13
14
|
from ibm_watsonx_orchestrate.agent_builder.tools.types import JsonSchemaObject
|
14
15
|
|
@@ -63,7 +64,7 @@ class BaseAgentSpec(BaseModel):
|
|
63
64
|
elif file.endswith('.json'):
|
64
65
|
json.dump(dumped, f, indent=2)
|
65
66
|
else:
|
66
|
-
raise
|
67
|
+
raise BadRequest('file must end in .json, .yaml, or .yml')
|
67
68
|
|
68
69
|
def dumps_spec(self) -> str:
|
69
70
|
dumped = self.model_dump(mode='json', exclude_none=True)
|
@@ -136,7 +137,7 @@ class AgentSpec(BaseAgentSpec):
|
|
136
137
|
@model_validator(mode="after")
|
137
138
|
def validate_kind(self):
|
138
139
|
if self.kind != AgentKind.NATIVE:
|
139
|
-
raise
|
140
|
+
raise BadRequest(f"The specified kind '{self.kind}' cannot be used to create a native agent.")
|
140
141
|
return self
|
141
142
|
|
142
143
|
def validate_agent_fields(values: dict) -> dict:
|
@@ -144,13 +145,13 @@ def validate_agent_fields(values: dict) -> dict:
|
|
144
145
|
for field in ["id", "name", "kind", "description", "collaborators", "tools", "knowledge_base"]:
|
145
146
|
value = values.get(field)
|
146
147
|
if value and not str(value).strip():
|
147
|
-
raise
|
148
|
+
raise BadRequest(f"{field} cannot be empty or just whitespace")
|
148
149
|
|
149
150
|
name = values.get("name")
|
150
151
|
collaborators = values.get("collaborators", []) if values.get("collaborators", []) else []
|
151
152
|
for collaborator in collaborators:
|
152
153
|
if collaborator == name:
|
153
|
-
raise
|
154
|
+
raise BadRequest(f"Circular reference detected. The agent '{name}' cannot contain itself as a collaborator")
|
154
155
|
|
155
156
|
if values.get("style") == AgentStyle.PLANNER:
|
156
157
|
if values.get("custom_join_tool") and values.get("structured_output"):
|
@@ -197,7 +198,7 @@ class ExternalAgentSpec(BaseAgentSpec):
|
|
197
198
|
@model_validator(mode="after")
|
198
199
|
def validate_kind_for_external(self):
|
199
200
|
if self.kind != AgentKind.EXTERNAL:
|
200
|
-
raise
|
201
|
+
raise BadRequest(f"The specified kind '{self.kind}' cannot be used to create an external agent.")
|
201
202
|
return self
|
202
203
|
|
203
204
|
def validate_external_agent_fields(values: dict) -> dict:
|
@@ -205,7 +206,7 @@ def validate_external_agent_fields(values: dict) -> dict:
|
|
205
206
|
for field in ["name", "kind", "description", "title", "tags", "api_url", "chat_params", "nickname", "app_id"]:
|
206
207
|
value = values.get(field)
|
207
208
|
if value and not str(value).strip():
|
208
|
-
raise
|
209
|
+
raise BadRequest(f"{field} cannot be empty or just whitespace")
|
209
210
|
|
210
211
|
context_variables = values.get("context_variables")
|
211
212
|
if context_variables is not None:
|
@@ -250,7 +251,7 @@ class AssistantAgentSpec(BaseAgentSpec):
|
|
250
251
|
@model_validator(mode="after")
|
251
252
|
def validate_kind_for_external(self):
|
252
253
|
if self.kind != AgentKind.ASSISTANT:
|
253
|
-
raise
|
254
|
+
raise BadRequest(f"The specified kind '{self.kind}' cannot be used to create an assistant agent.")
|
254
255
|
return self
|
255
256
|
|
256
257
|
def validate_assistant_agent_fields(values: dict) -> dict:
|
@@ -258,7 +259,7 @@ def validate_assistant_agent_fields(values: dict) -> dict:
|
|
258
259
|
for field in ["name", "kind", "description", "title", "tags", "nickname", "app_id"]:
|
259
260
|
value = values.get(field)
|
260
261
|
if value and not str(value).strip():
|
261
|
-
raise
|
262
|
+
raise BadRequest(f"{field} cannot be empty or just whitespace")
|
262
263
|
|
263
264
|
# Validate context_variables if provided
|
264
265
|
context_variables = values.get("context_variables")
|
@@ -14,6 +14,7 @@ from ibm_watsonx_orchestrate.agent_builder.connections.types import (
|
|
14
14
|
)
|
15
15
|
|
16
16
|
from ibm_watsonx_orchestrate.utils.utils import sanatize_app_id
|
17
|
+
from ibm_watsonx_orchestrate.utils.exceptions import BadRequest
|
17
18
|
|
18
19
|
logger = logging.getLogger(__name__)
|
19
20
|
|
@@ -61,7 +62,7 @@ def _clean_env_vars(vars: dict[str:str], requirements: List[str], app_id: str) -
|
|
61
62
|
missing_requirements_str = ", ".join(missing_requirements)
|
62
63
|
message = f"Missing requirement environment variables '{missing_requirements_str}' for connection '{app_id}'"
|
63
64
|
logger.error(message)
|
64
|
-
raise
|
65
|
+
raise BadRequest(message)
|
65
66
|
|
66
67
|
return required_env_vars
|
67
68
|
|
@@ -114,7 +115,7 @@ def get_connection_type(app_id: str) -> ConnectionSecurityScheme:
|
|
114
115
|
if not expected_schema:
|
115
116
|
message = f"No credentials found for connections '{app_id}'"
|
116
117
|
logger.error(message)
|
117
|
-
raise
|
118
|
+
raise BadRequest(message)
|
118
119
|
|
119
120
|
auth_types = {e.value for e in ConnectionSecurityScheme}
|
120
121
|
if expected_schema not in auth_types:
|
@@ -132,6 +133,6 @@ def get_application_connection_credentials(type: ConnectionType, app_id: str) ->
|
|
132
133
|
if not _validate_schema_type(requested_type=requested_schema, expected_type=expected_schema):
|
133
134
|
message = f"The requested type '{requested_schema}' does not match the type '{expected_schema}' for the connection '{app_id}'"
|
134
135
|
logger.error(message)
|
135
|
-
raise
|
136
|
+
raise BadRequest(message)
|
136
137
|
|
137
138
|
return _get_credentials_model(connection_type=requested_schema, app_id=sanitized_app_id)
|
@@ -1,7 +1,6 @@
|
|
1
1
|
import json
|
2
2
|
from ibm_watsonx_orchestrate.utils.utils import yaml_safe_load
|
3
|
-
from .types import KnowledgeBaseSpec
|
4
|
-
|
3
|
+
from .types import KnowledgeBaseSpec
|
5
4
|
|
6
5
|
class KnowledgeBaseCreateRequest(KnowledgeBaseSpec):
|
7
6
|
|
@@ -21,23 +20,3 @@ class KnowledgeBaseCreateRequest(KnowledgeBaseSpec):
|
|
21
20
|
knowledge_base = KnowledgeBaseSpec.model_validate(content)
|
22
21
|
|
23
22
|
return knowledge_base
|
24
|
-
|
25
|
-
|
26
|
-
class KnowledgeBaseUpdateRequest(PatchKnowledgeBase):
|
27
|
-
|
28
|
-
@staticmethod
|
29
|
-
def from_spec(file: str) -> 'PatchKnowledgeBase':
|
30
|
-
with open(file, 'r') as f:
|
31
|
-
if file.endswith('.yaml') or file.endswith('.yml'):
|
32
|
-
content = yaml_safe_load(f)
|
33
|
-
elif file.endswith('.json'):
|
34
|
-
content = json.load(f)
|
35
|
-
else:
|
36
|
-
raise ValueError('file must end in .json, .yaml, or .yml')
|
37
|
-
|
38
|
-
if not content.get("spec_version"):
|
39
|
-
raise ValueError(f"Field 'spec_version' not provided. Please ensure provided spec conforms to a valid spec format")
|
40
|
-
|
41
|
-
patch = PatchKnowledgeBase.model_validate(content)
|
42
|
-
|
43
|
-
return patch
|
@@ -3,7 +3,7 @@ from datetime import datetime
|
|
3
3
|
from uuid import UUID
|
4
4
|
from enum import Enum
|
5
5
|
|
6
|
-
from pydantic import BaseModel
|
6
|
+
from pydantic import BaseModel
|
7
7
|
|
8
8
|
class SpecVersion(str, Enum):
|
9
9
|
V1 = "v1"
|
@@ -219,22 +219,6 @@ class KnowledgeBaseBuiltInVectorIndexConfig(BaseModel):
|
|
219
219
|
chunk_overlap: Optional[int] = None
|
220
220
|
limit: Optional[int] = None
|
221
221
|
|
222
|
-
class PatchKnowledgeBase(BaseModel):
|
223
|
-
"""request payload schema"""
|
224
|
-
description: Optional[str] = None
|
225
|
-
documents: list[str] = None
|
226
|
-
conversational_search_tool: Optional[ConversationalSearchConfig] = None
|
227
|
-
prioritize_built_in_index: Optional[bool] = None
|
228
|
-
representation: Optional[KnowledgeBaseRepresentation] = None
|
229
|
-
|
230
|
-
@model_validator(mode="after")
|
231
|
-
def validate_fields(self):
|
232
|
-
if self.documents and self.conversational_search_tool and self.conversational_search_tool.index_config:
|
233
|
-
raise ValueError("Must not provide both \"documents\" or \"conversational_search_tool.index_config\"")
|
234
|
-
if self.conversational_search_tool and self.conversational_search_tool.index_config and len(self.conversational_search_tool.index_config) != 1:
|
235
|
-
raise ValueError(f"Must provide exactly one conversational_search_tool.index_config. Provided {len(self.conversational_search_tool.index_config)}.")
|
236
|
-
return self
|
237
|
-
|
238
222
|
class KnowledgeBaseSpec(BaseModel):
|
239
223
|
"""Schema for a complete knowledge-base."""
|
240
224
|
spec_version: SpecVersion = None
|
@@ -4,6 +4,7 @@ import yaml
|
|
4
4
|
|
5
5
|
from .types import ToolSpec
|
6
6
|
|
7
|
+
from ibm_watsonx_orchestrate.utils.exceptions import BadRequest
|
7
8
|
|
8
9
|
class BaseTool:
|
9
10
|
__tool_spec__: ToolSpec
|
@@ -22,7 +23,7 @@ class BaseTool:
|
|
22
23
|
elif file.endswith('.json'):
|
23
24
|
json.dump(dumped, f, indent=2)
|
24
25
|
else:
|
25
|
-
raise
|
26
|
+
raise BadRequest('file must end in .json, .yaml, or .yml')
|
26
27
|
|
27
28
|
def dumps_spec(self) -> str:
|
28
29
|
dumped = self.__tool_spec__.model_dump(mode='json', exclude_unset=True, exclude_none=True, by_alias=True)
|
@@ -3,6 +3,7 @@ import json
|
|
3
3
|
import os.path
|
4
4
|
import logging
|
5
5
|
from typing import Dict, Any, List
|
6
|
+
from ibm_watsonx_orchestrate.utils.exceptions import BadRequest
|
6
7
|
|
7
8
|
import yaml
|
8
9
|
import yaml.constructor
|
@@ -40,7 +41,7 @@ class OpenAPITool(BaseTool):
|
|
40
41
|
BaseTool.__init__(self, spec=spec)
|
41
42
|
|
42
43
|
if self.__tool_spec__.binding.openapi is None:
|
43
|
-
raise
|
44
|
+
raise BadRequest('Missing openapi binding')
|
44
45
|
|
45
46
|
async def __call__(self, **kwargs):
|
46
47
|
raise RuntimeError('OpenAPI Tools are only available when deployed onto watson orchestrate or the watson '
|
@@ -54,10 +55,10 @@ class OpenAPITool(BaseTool):
|
|
54
55
|
elif file.endswith('.json'):
|
55
56
|
spec = ToolSpec.model_validate(json.load(f))
|
56
57
|
else:
|
57
|
-
raise
|
58
|
+
raise BadRequest('file must end in .json, .yaml, or .yml')
|
58
59
|
|
59
60
|
if spec.binding.openapi is None or spec.binding.openapi is None:
|
60
|
-
raise
|
61
|
+
raise BadRequest('failed to load python tool as the tool had no openapi binding')
|
61
62
|
|
62
63
|
return OpenAPITool(spec=spec)
|
63
64
|
|
@@ -108,11 +109,11 @@ def create_openapi_json_tool(
|
|
108
109
|
paths = openapi_contents.get('paths', {})
|
109
110
|
route = paths.get(http_path)
|
110
111
|
if route is None:
|
111
|
-
raise
|
112
|
+
raise BadRequest(f"Path {http_path} not found in paths. Available endpoints are: {list(paths.keys())}")
|
112
113
|
|
113
114
|
route_spec = route.get(http_method.lower(), route.get(http_method.upper()))
|
114
115
|
if route_spec is None:
|
115
|
-
raise
|
116
|
+
raise BadRequest(
|
116
117
|
f"Path {http_path} did not have an http_method {http_method}. Available methods are {list(route.keys())}")
|
117
118
|
|
118
119
|
operation_id = re.sub( r'(\W|_)+', '_', route_spec.get('operationId') ) \
|
@@ -121,12 +122,12 @@ def create_openapi_json_tool(
|
|
121
122
|
spec_name = name or operation_id
|
122
123
|
spec_permission = permission or _action_to_perm(route_spec.get('x-ibm-operation', {}).get('action'))
|
123
124
|
if spec_name is None:
|
124
|
-
raise
|
125
|
-
f"
|
125
|
+
raise BadRequest(
|
126
|
+
f"Failed to import tool from endpoint {http_method}: {http_path} as no operationId was provided. An operationId must be provided to generate the name of the tool.")
|
126
127
|
|
127
128
|
spec_description = description or route_spec.get('description')
|
128
129
|
if spec_description is None:
|
129
|
-
raise
|
130
|
+
raise BadRequest(
|
130
131
|
f"No description provided for tool. {http_method}: {http_path} did not specify a description field, and no description was provided")
|
131
132
|
|
132
133
|
spec = ToolSpec(
|
@@ -199,7 +200,7 @@ def create_openapi_json_tool(
|
|
199
200
|
for needed_security in route_spec.get('security', []) + openapi_spec.get('security', []):
|
200
201
|
name = next(iter(needed_security.keys()), None)
|
201
202
|
if name is None or name not in security_schemes_map:
|
202
|
-
raise
|
203
|
+
raise BadRequest(f"Invalid openapi spec, {HTTP_METHOD} {http_path} asks for a security scheme of {name}, "
|
203
204
|
f"but no such security scheme was configured in the .security section of the spec")
|
204
205
|
|
205
206
|
security.append(security_schemes_map[name])
|
@@ -260,23 +261,23 @@ async def _get_openapi_spec_from_uri(openapi_uri: str) -> Dict[str, Any]:
|
|
260
261
|
elif openapi_uri.endswith('.yaml') or openapi_uri.endswith('.yml'):
|
261
262
|
openapi_contents = yaml_safe_load(fp)
|
262
263
|
else:
|
263
|
-
raise
|
264
|
+
raise BadRequest(
|
264
265
|
f"Unexpected file extension for file {openapi_uri}, expected one of [.json, .yaml, .yml]")
|
265
266
|
elif openapi_uri.endswith('.json'):
|
266
267
|
async with httpx.AsyncClient() as client:
|
267
268
|
r = await client.get(openapi_uri)
|
268
269
|
if r.status_code != 200:
|
269
|
-
raise
|
270
|
+
raise BadRequest(f"Failed to fetch an openapi spec from {openapi_uri}, status code: {r.status_code}")
|
270
271
|
openapi_contents = r.json()
|
271
272
|
elif openapi_uri.endswith('.yaml'):
|
272
273
|
async with httpx.AsyncClient() as client:
|
273
274
|
r = await client.get(openapi_uri)
|
274
275
|
if r.status_code != 200:
|
275
|
-
raise
|
276
|
+
raise BadRequest(f"Failed to fetch an openapi spec from {openapi_uri}, status code: {r.status_code}")
|
276
277
|
openapi_contents = yaml_safe_load(r.text)
|
277
278
|
|
278
279
|
if openapi_contents is None:
|
279
|
-
raise
|
280
|
+
raise BadRequest(f"Unrecognized path or uri {openapi_uri}")
|
280
281
|
|
281
282
|
return openapi_contents
|
282
283
|
|
@@ -5,9 +5,6 @@ import os
|
|
5
5
|
from typing import Any, Callable, Dict, List, get_type_hints
|
6
6
|
import logging
|
7
7
|
|
8
|
-
import docstring_parser
|
9
|
-
from langchain_core.tools.base import create_schema_from_function
|
10
|
-
from langchain_core.utils.json_schema import dereference_refs
|
11
8
|
from pydantic import TypeAdapter, BaseModel
|
12
9
|
|
13
10
|
from ibm_watsonx_orchestrate.utils.utils import yaml_safe_load
|
@@ -15,6 +12,7 @@ from ibm_watsonx_orchestrate.agent_builder.connections import ExpectedCredential
|
|
15
12
|
from .base_tool import BaseTool
|
16
13
|
from .types import PythonToolKind, ToolSpec, ToolPermission, ToolRequestBody, ToolResponseBody, JsonSchemaObject, ToolBinding, \
|
17
14
|
PythonToolBinding
|
15
|
+
from ibm_watsonx_orchestrate.utils.exceptions import BadRequest
|
18
16
|
|
19
17
|
_all_tools = []
|
20
18
|
logger = logging.getLogger(__name__)
|
@@ -25,15 +23,136 @@ JOIN_TOOL_PARAMS = {
|
|
25
23
|
'messages': List[Dict[str, Any]],
|
26
24
|
}
|
27
25
|
|
26
|
+
def _parse_expected_credentials(expected_credentials: ExpectedCredentials | dict):
|
27
|
+
parsed_expected_credentials = []
|
28
|
+
if expected_credentials:
|
29
|
+
for credential in expected_credentials:
|
30
|
+
if isinstance(credential, ExpectedCredentials):
|
31
|
+
parsed_expected_credentials.append(credential)
|
32
|
+
else:
|
33
|
+
parsed_expected_credentials.append(ExpectedCredentials.model_validate(credential))
|
34
|
+
|
35
|
+
return parsed_expected_credentials
|
36
|
+
|
28
37
|
class PythonTool(BaseTool):
|
29
|
-
def __init__(self,
|
30
|
-
|
38
|
+
def __init__(self,
|
39
|
+
fn,
|
40
|
+
name: str = None,
|
41
|
+
description: str = None,
|
42
|
+
input_schema: ToolRequestBody = None,
|
43
|
+
output_schema: ToolResponseBody = None,
|
44
|
+
permission: ToolPermission = ToolPermission.READ_ONLY,
|
45
|
+
expected_credentials: List[ExpectedCredentials] = None,
|
46
|
+
display_name: str = None,
|
47
|
+
kind: PythonToolKind = PythonToolKind.TOOL,
|
48
|
+
spec=None
|
49
|
+
):
|
31
50
|
self.fn = fn
|
32
|
-
self.
|
51
|
+
self.name = name
|
52
|
+
self.description = description
|
53
|
+
self.input_schema = input_schema
|
54
|
+
self.output_schema = output_schema
|
55
|
+
self.permission = permission
|
56
|
+
self.display_name = display_name
|
57
|
+
self.kind = kind
|
58
|
+
self.expected_credentials=_parse_expected_credentials(expected_credentials)
|
59
|
+
self._spec = None
|
60
|
+
if spec:
|
61
|
+
self._spec = spec
|
33
62
|
|
34
63
|
def __call__(self, *args, **kwargs):
|
35
64
|
return self.fn(*args, **kwargs)
|
65
|
+
|
66
|
+
@property
|
67
|
+
def __tool_spec__(self):
|
68
|
+
if self._spec:
|
69
|
+
return self._spec
|
70
|
+
|
71
|
+
import docstring_parser
|
72
|
+
from langchain_core.tools.base import create_schema_from_function
|
73
|
+
from langchain_core.utils.json_schema import dereference_refs
|
74
|
+
|
75
|
+
if self.fn.__doc__ is not None:
|
76
|
+
doc = docstring_parser.parse(self.fn.__doc__)
|
77
|
+
else:
|
78
|
+
doc = None
|
79
|
+
|
80
|
+
_desc = self.description
|
81
|
+
if self.description is None and doc is not None:
|
82
|
+
_desc = doc.description
|
83
|
+
|
84
|
+
|
85
|
+
spec = ToolSpec(
|
86
|
+
name=self.name or self.fn.__name__,
|
87
|
+
display_name=self.display_name,
|
88
|
+
description=_desc,
|
89
|
+
permission=self.permission
|
90
|
+
)
|
91
|
+
|
92
|
+
spec.binding = ToolBinding(python=PythonToolBinding(function=''))
|
93
|
+
|
94
|
+
linux_friendly_os_cwd = os.getcwd().replace("\\", "/")
|
95
|
+
function_binding = (inspect.getsourcefile(self.fn)
|
96
|
+
.replace("\\", "/")
|
97
|
+
.replace(linux_friendly_os_cwd+'/', '')
|
98
|
+
.replace('.py', '')
|
99
|
+
.replace('/','.') +
|
100
|
+
f":{self.fn.__name__}")
|
101
|
+
spec.binding.python.function = function_binding
|
102
|
+
|
103
|
+
sig = inspect.signature(self.fn)
|
104
|
+
|
105
|
+
# If the function is a join tool, validate its signature matches the expected parameters. If not, raise error with details.
|
106
|
+
if self.kind == PythonToolKind.JOIN_TOOL:
|
107
|
+
_validate_join_tool_func(self.fn, sig, spec.name)
|
108
|
+
|
109
|
+
if not self.input_schema:
|
110
|
+
try:
|
111
|
+
input_schema_model: type[BaseModel] = create_schema_from_function(spec.name, self.fn, parse_docstring=True)
|
112
|
+
except:
|
113
|
+
logger.warning("Unable to properly parse parameter descriptions due to incorrectly formatted docstring. This may result in degraded agent performance. To fix this, please ensure the docstring conforms to Google's docstring format.")
|
114
|
+
input_schema_model: type[BaseModel] = create_schema_from_function(spec.name, self.fn, parse_docstring=False)
|
115
|
+
input_schema_json = input_schema_model.model_json_schema()
|
116
|
+
input_schema_json = dereference_refs(input_schema_json)
|
117
|
+
|
118
|
+
# Convert the input schema to a JsonSchemaObject
|
119
|
+
input_schema_obj = JsonSchemaObject(**input_schema_json)
|
120
|
+
input_schema_obj = _fix_optional(input_schema_obj)
|
121
|
+
|
122
|
+
spec.input_schema = ToolRequestBody(
|
123
|
+
type='object',
|
124
|
+
properties=input_schema_obj.properties or {},
|
125
|
+
required=input_schema_obj.required or []
|
126
|
+
)
|
127
|
+
else:
|
128
|
+
spec.input_schema = self.input_schema
|
129
|
+
|
130
|
+
_validate_input_schema(spec.input_schema)
|
131
|
+
|
132
|
+
if not self.output_schema:
|
133
|
+
ret = sig.return_annotation
|
134
|
+
if ret != sig.empty:
|
135
|
+
_schema = dereference_refs(TypeAdapter(ret).json_schema())
|
136
|
+
if '$defs' in _schema:
|
137
|
+
_schema.pop('$defs')
|
138
|
+
spec.output_schema = _fix_optional(ToolResponseBody(**_schema))
|
139
|
+
else:
|
140
|
+
spec.output_schema = ToolResponseBody()
|
141
|
+
|
142
|
+
if doc is not None and doc.returns is not None and doc.returns.description is not None:
|
143
|
+
spec.output_schema.description = doc.returns.description
|
144
|
+
|
145
|
+
else:
|
146
|
+
spec.output_schema = ToolResponseBody()
|
147
|
+
|
148
|
+
# Validate the generated schema still conforms to the requirement for a join tool
|
149
|
+
if self.kind == PythonToolKind.JOIN_TOOL:
|
150
|
+
if not spec.is_custom_join_tool():
|
151
|
+
raise ValueError(f"Join tool '{spec.name}' does not conform to the expected join tool schema. Please ensure the input schema has the required fields: {JOIN_TOOL_PARAMS.keys()} and the output schema is a string.")
|
36
152
|
|
153
|
+
self._spec = spec
|
154
|
+
return spec
|
155
|
+
|
37
156
|
@staticmethod
|
38
157
|
def from_spec(file: str) -> 'PythonTool':
|
39
158
|
with open(file, 'r') as f:
|
@@ -42,10 +161,10 @@ class PythonTool(BaseTool):
|
|
42
161
|
elif file.endswith('.json'):
|
43
162
|
spec = ToolSpec.model_validate(json.load(f))
|
44
163
|
else:
|
45
|
-
raise
|
164
|
+
raise BadRequest('file must end in .json, .yaml, or .yml')
|
46
165
|
|
47
166
|
if spec.binding.python is None:
|
48
|
-
raise
|
167
|
+
raise BadRequest('failed to load python tool as the tool had no python binding')
|
49
168
|
|
50
169
|
[module, fn_name] = spec.binding.python.function.split(':')
|
51
170
|
fn = getattr(importlib.import_module(module), fn_name)
|
@@ -147,92 +266,17 @@ def tool(
|
|
147
266
|
"""
|
148
267
|
# inspiration: https://github.com/pydantic/pydantic/blob/main/pydantic/validate_call_decorator.py
|
149
268
|
def _tool_decorator(fn):
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
spec = ToolSpec(
|
161
|
-
name=name or fn.__name__,
|
269
|
+
t = PythonTool(
|
270
|
+
fn=fn,
|
271
|
+
name=name,
|
272
|
+
description=description,
|
273
|
+
input_schema=input_schema,
|
274
|
+
output_schema=output_schema,
|
275
|
+
permission=permission,
|
276
|
+
expected_credentials=expected_credentials,
|
162
277
|
display_name=display_name,
|
163
|
-
|
164
|
-
permission=permission
|
278
|
+
kind=kind
|
165
279
|
)
|
166
|
-
|
167
|
-
parsed_expected_credentials = []
|
168
|
-
if expected_credentials:
|
169
|
-
for credential in expected_credentials:
|
170
|
-
if isinstance(credential, ExpectedCredentials):
|
171
|
-
parsed_expected_credentials.append(credential)
|
172
|
-
else:
|
173
|
-
parsed_expected_credentials.append(ExpectedCredentials.model_validate(credential))
|
174
|
-
|
175
|
-
t = PythonTool(fn=fn, spec=spec, expected_credentials=parsed_expected_credentials)
|
176
|
-
spec.binding = ToolBinding(python=PythonToolBinding(function=''))
|
177
|
-
|
178
|
-
linux_friendly_os_cwd = os.getcwd().replace("\\", "/")
|
179
|
-
function_binding = (inspect.getsourcefile(fn)
|
180
|
-
.replace("\\", "/")
|
181
|
-
.replace(linux_friendly_os_cwd+'/', '')
|
182
|
-
.replace('.py', '')
|
183
|
-
.replace('/','.') +
|
184
|
-
f":{fn.__name__}")
|
185
|
-
spec.binding.python.function = function_binding
|
186
|
-
|
187
|
-
sig = inspect.signature(fn)
|
188
|
-
|
189
|
-
# If the function is a join tool, validate its signature matches the expected parameters. If not, raise error with details.
|
190
|
-
if kind == PythonToolKind.JOIN_TOOL:
|
191
|
-
_validate_join_tool_func(fn, sig, spec.name)
|
192
|
-
|
193
|
-
if not input_schema:
|
194
|
-
try:
|
195
|
-
input_schema_model: type[BaseModel] = create_schema_from_function(spec.name, fn, parse_docstring=True)
|
196
|
-
except:
|
197
|
-
logger.warning("Unable to properly parse parameter descriptions due to incorrectly formatted docstring. This may result in degraded agent performance. To fix this, please ensure the docstring conforms to Google's docstring format.")
|
198
|
-
input_schema_model: type[BaseModel] = create_schema_from_function(spec.name, fn, parse_docstring=False)
|
199
|
-
input_schema_json = input_schema_model.model_json_schema()
|
200
|
-
input_schema_json = dereference_refs(input_schema_json)
|
201
|
-
|
202
|
-
# Convert the input schema to a JsonSchemaObject
|
203
|
-
input_schema_obj = JsonSchemaObject(**input_schema_json)
|
204
|
-
input_schema_obj = _fix_optional(input_schema_obj)
|
205
|
-
|
206
|
-
spec.input_schema = ToolRequestBody(
|
207
|
-
type='object',
|
208
|
-
properties=input_schema_obj.properties or {},
|
209
|
-
required=input_schema_obj.required or []
|
210
|
-
)
|
211
|
-
else:
|
212
|
-
spec.input_schema = input_schema
|
213
|
-
|
214
|
-
_validate_input_schema(spec.input_schema)
|
215
|
-
|
216
|
-
if not output_schema:
|
217
|
-
ret = sig.return_annotation
|
218
|
-
if ret != sig.empty:
|
219
|
-
_schema = dereference_refs(TypeAdapter(ret).json_schema())
|
220
|
-
if '$defs' in _schema:
|
221
|
-
_schema.pop('$defs')
|
222
|
-
spec.output_schema = _fix_optional(ToolResponseBody(**_schema))
|
223
|
-
else:
|
224
|
-
spec.output_schema = ToolResponseBody()
|
225
|
-
|
226
|
-
if doc is not None and doc.returns is not None and doc.returns.description is not None:
|
227
|
-
spec.output_schema.description = doc.returns.description
|
228
|
-
|
229
|
-
else:
|
230
|
-
spec.output_schema = ToolResponseBody()
|
231
|
-
|
232
|
-
# Validate the generated schema still conforms to the requirement for a join tool
|
233
|
-
if kind == PythonToolKind.JOIN_TOOL:
|
234
|
-
if not spec.is_custom_join_tool():
|
235
|
-
raise ValueError(f"Join tool '{spec.name}' does not conform to the expected join tool schema. Please ensure the input schema has the required fields: {JOIN_TOOL_PARAMS.keys()} and the output schema is a string.")
|
236
280
|
|
237
281
|
_all_tools.append(t)
|
238
282
|
return t
|