ibm-watsonx-orchestrate 1.8.0b1__py3-none-any.whl → 1.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ibm_watsonx_orchestrate/__init__.py +1 -1
- ibm_watsonx_orchestrate/agent_builder/knowledge_bases/types.py +2 -2
- ibm_watsonx_orchestrate/agent_builder/models/types.py +5 -0
- ibm_watsonx_orchestrate/agent_builder/tools/openapi_tool.py +61 -11
- ibm_watsonx_orchestrate/agent_builder/tools/python_tool.py +18 -6
- ibm_watsonx_orchestrate/agent_builder/tools/types.py +12 -5
- ibm_watsonx_orchestrate/cli/commands/agents/agents_controller.py +15 -3
- ibm_watsonx_orchestrate/cli/commands/channels/types.py +2 -2
- ibm_watsonx_orchestrate/cli/commands/channels/webchat/channels_webchat_controller.py +2 -3
- ibm_watsonx_orchestrate/cli/commands/connections/connections_controller.py +6 -3
- ibm_watsonx_orchestrate/cli/commands/copilot/copilot_controller.py +103 -23
- ibm_watsonx_orchestrate/cli/commands/evaluations/evaluations_command.py +86 -31
- ibm_watsonx_orchestrate/cli/commands/models/model_provider_mapper.py +17 -13
- ibm_watsonx_orchestrate/cli/commands/models/models_controller.py +5 -8
- ibm_watsonx_orchestrate/cli/commands/server/server_command.py +147 -37
- ibm_watsonx_orchestrate/cli/commands/toolkit/toolkit_command.py +4 -2
- ibm_watsonx_orchestrate/cli/commands/toolkit/toolkit_controller.py +9 -1
- ibm_watsonx_orchestrate/cli/commands/tools/tools_controller.py +1 -1
- ibm_watsonx_orchestrate/client/connections/connections_client.py +19 -32
- ibm_watsonx_orchestrate/client/copilot/cpe/copilot_cpe_client.py +5 -3
- ibm_watsonx_orchestrate/client/utils.py +17 -16
- ibm_watsonx_orchestrate/docker/compose-lite.yml +127 -12
- ibm_watsonx_orchestrate/docker/default.env +26 -21
- ibm_watsonx_orchestrate/flow_builder/flows/__init__.py +2 -2
- ibm_watsonx_orchestrate/flow_builder/flows/constants.py +2 -0
- ibm_watsonx_orchestrate/flow_builder/flows/flow.py +52 -10
- ibm_watsonx_orchestrate/flow_builder/node.py +34 -3
- ibm_watsonx_orchestrate/flow_builder/types.py +144 -26
- ibm_watsonx_orchestrate/flow_builder/utils.py +7 -5
- {ibm_watsonx_orchestrate-1.8.0b1.dist-info → ibm_watsonx_orchestrate-1.9.0.dist-info}/METADATA +1 -3
- {ibm_watsonx_orchestrate-1.8.0b1.dist-info → ibm_watsonx_orchestrate-1.9.0.dist-info}/RECORD +34 -35
- ibm_watsonx_orchestrate/agent_builder/utils/pydantic_utils.py +0 -149
- {ibm_watsonx_orchestrate-1.8.0b1.dist-info → ibm_watsonx_orchestrate-1.9.0.dist-info}/WHEEL +0 -0
- {ibm_watsonx_orchestrate-1.8.0b1.dist-info → ibm_watsonx_orchestrate-1.9.0.dist-info}/entry_points.txt +0 -0
- {ibm_watsonx_orchestrate-1.8.0b1.dist-info → ibm_watsonx_orchestrate-1.9.0.dist-info}/licenses/LICENSE +0 -0
@@ -86,11 +86,10 @@ class GenerationConfiguration(BaseModel):
|
|
86
86
|
{
|
87
87
|
"model_id": "meta-llama/llama-3-1-70b-instruct",
|
88
88
|
"prompt_instruction": "When the documents are in different languages, you should respond in english.",
|
89
|
-
"retrieval_confidence_threshold": "Lowest",
|
90
89
|
"generated_response_length": "Moderate",
|
91
|
-
"response_confidence_threshold": "Low",
|
92
90
|
"display_text_no_results_found": "no docs found",
|
93
91
|
"display_text_connectivity_issue": "conn failed",
|
92
|
+
"idk_message": "I dont know",
|
94
93
|
}
|
95
94
|
"""
|
96
95
|
|
@@ -99,6 +98,7 @@ class GenerationConfiguration(BaseModel):
|
|
99
98
|
generated_response_length: Optional[GeneratedResponseLength] = None
|
100
99
|
display_text_no_results_found: Optional[str] = None
|
101
100
|
display_text_connectivity_issue: Optional[str] = None
|
101
|
+
idk_message: Optional[str] = None
|
102
102
|
|
103
103
|
class FieldMapping(BaseModel):
|
104
104
|
"""
|
@@ -87,6 +87,11 @@ class ProviderConfig(BaseModel):
|
|
87
87
|
azure_entra_tenant_id: Optional[str] = Field(None, alias="azureEntraTenantId")
|
88
88
|
azure_ad_token: Optional[str] = Field(None, alias="azureAdToken")
|
89
89
|
azure_model_name: Optional[str] = Field(None, alias="azureModelName")
|
90
|
+
azure_inference_deployment_name: Optional[str] = Field(None, alias="azureDeploymentName")
|
91
|
+
azure_inference_api_version: Optional[str] = Field(None, alias="azureApiVersion")
|
92
|
+
azure_inference_extra_params: Optional[str] = Field(None, alias="azureExtraParams")
|
93
|
+
azure_inference_foundry_url: Optional[str] = Field(None, alias="azureFoundryUrl")
|
94
|
+
|
90
95
|
|
91
96
|
# Workers AI specific
|
92
97
|
workers_ai_account_id: Optional[str] = Field(None, alias="workersAiAccountId")
|
@@ -14,7 +14,7 @@ from ibm_watsonx_orchestrate.utils.utils import yaml_safe_load
|
|
14
14
|
from .types import ToolSpec
|
15
15
|
from .base_tool import BaseTool
|
16
16
|
from .types import HTTP_METHOD, ToolPermission, ToolRequestBody, ToolResponseBody, \
|
17
|
-
OpenApiToolBinding, \
|
17
|
+
OpenApiToolBinding, AcknowledgementBinding, \
|
18
18
|
JsonSchemaObject, ToolBinding, OpenApiSecurityScheme, CallbackBinding
|
19
19
|
|
20
20
|
import json
|
@@ -207,35 +207,82 @@ def create_openapi_json_tool(
|
|
207
207
|
|
208
208
|
# If it's an async tool, add callback binding
|
209
209
|
if spec.is_async:
|
210
|
-
|
211
|
-
|
212
210
|
callbacks = route_spec.get('callbacks', {})
|
213
211
|
callback_name = next(iter(callbacks.keys()))
|
214
212
|
callback_spec = callbacks[callback_name]
|
215
213
|
callback_path = next(iter(callback_spec.keys()))
|
216
214
|
callback_method = next(iter(callback_spec[callback_path].keys()))
|
215
|
+
callback_operation = callback_spec[callback_path][callback_method]
|
217
216
|
|
218
|
-
#
|
219
|
-
# Note: Currently assuming the callback URL parameter will be named 'callbackUrl' in the OpenAPI spec
|
220
|
-
# Future phases will handle other naming conventions
|
217
|
+
# Extract callback input schema from the callback requestBody
|
221
218
|
callback_input_schema = ToolRequestBody(
|
219
|
+
type='object',
|
220
|
+
properties={},
|
221
|
+
required=[]
|
222
|
+
)
|
223
|
+
|
224
|
+
# Handle callback parameters (query, path, header params)
|
225
|
+
callback_parameters = callback_operation.get('parameters') or []
|
226
|
+
for parameter in callback_parameters:
|
227
|
+
name = f"{parameter['in']}_{parameter['name']}"
|
228
|
+
if parameter.get('required'):
|
229
|
+
callback_input_schema.required.append(name)
|
230
|
+
parameter['schema']['title'] = parameter['name']
|
231
|
+
parameter['schema']['description'] = parameter.get('description', None)
|
232
|
+
callback_input_schema.properties[name] = JsonSchemaObject.model_validate(parameter['schema'])
|
233
|
+
callback_input_schema.properties[name].in_field = parameter['in']
|
234
|
+
callback_input_schema.properties[name].aliasName = parameter['name']
|
235
|
+
|
236
|
+
# Handle callback request body
|
237
|
+
callback_request_body_params = callback_operation.get('requestBody', {}).get('content', {}).get(http_response_content_type, {}).get('schema', None)
|
238
|
+
if callback_request_body_params is not None:
|
239
|
+
callback_input_schema.required.append('__requestBody__')
|
240
|
+
callback_request_body_params = copy.deepcopy(callback_request_body_params)
|
241
|
+
callback_request_body_params['in'] = 'body'
|
242
|
+
if callback_request_body_params.get('title') is None:
|
243
|
+
callback_request_body_params['title'] = 'CallbackRequestBody'
|
244
|
+
if callback_request_body_params.get('description') is None:
|
245
|
+
callback_request_body_params['description'] = 'The callback request body used for this async operation.'
|
246
|
+
|
247
|
+
callback_input_schema.properties['__requestBody__'] = JsonSchemaObject.model_validate(callback_request_body_params)
|
248
|
+
|
249
|
+
# Extract callback output schema
|
250
|
+
callback_responses = callback_operation.get('responses', {})
|
251
|
+
callback_response = callback_responses.get(str(http_success_response_code), {})
|
252
|
+
callback_response_description = callback_response.get('description')
|
253
|
+
callback_response_schema = callback_response.get('content', {}).get(http_response_content_type, {}).get('schema', {})
|
254
|
+
|
255
|
+
callback_response_schema['required'] = []
|
256
|
+
callback_output_schema = ToolResponseBody.model_validate(callback_response_schema)
|
257
|
+
callback_output_schema.description = callback_response_description
|
258
|
+
|
259
|
+
# Remove callbackUrl parameter from main tool input schema
|
260
|
+
original_input_schema = ToolRequestBody(
|
222
261
|
type='object',
|
223
262
|
properties={k: v for k, v in spec.input_schema.properties.items() if not k.endswith('_callbackUrl')},
|
224
263
|
required=[r for r in spec.input_schema.required if not r.endswith('_callbackUrl')]
|
225
264
|
)
|
265
|
+
spec.input_schema = original_input_schema
|
226
266
|
|
227
|
-
|
228
|
-
|
229
|
-
|
267
|
+
original_response_schema = spec.output_schema
|
268
|
+
|
230
269
|
callback_binding = CallbackBinding(
|
231
270
|
callback_url=callback_path,
|
232
271
|
method=callback_method.upper(),
|
233
|
-
|
234
|
-
output_schema=spec.output_schema
|
272
|
+
output_schema=callback_output_schema
|
235
273
|
)
|
236
274
|
|
275
|
+
# Create acknowledgement binding with the original response schema
|
276
|
+
acknowledgement_binding = AcknowledgementBinding(
|
277
|
+
output_schema=original_response_schema
|
278
|
+
)
|
279
|
+
|
280
|
+
# For async tools, set the main tool's output_schema to the callback's input_schema
|
281
|
+
spec.output_schema = callback_input_schema
|
282
|
+
|
237
283
|
else:
|
238
284
|
callback_binding = None
|
285
|
+
acknowledgement_binding = None
|
239
286
|
|
240
287
|
openapi_binding = OpenApiToolBinding(
|
241
288
|
http_path=http_path,
|
@@ -248,6 +295,9 @@ def create_openapi_json_tool(
|
|
248
295
|
if callback_binding is not None:
|
249
296
|
openapi_binding.callback = callback_binding
|
250
297
|
|
298
|
+
if acknowledgement_binding is not None:
|
299
|
+
openapi_binding.acknowledgement = acknowledgement_binding
|
300
|
+
|
251
301
|
spec.binding = ToolBinding(openapi=openapi_binding)
|
252
302
|
|
253
303
|
return OpenAPITool(spec=spec)
|
@@ -187,15 +187,27 @@ def _fix_optional(schema):
|
|
187
187
|
replacements = {}
|
188
188
|
if schema.required is None:
|
189
189
|
schema.required = []
|
190
|
-
|
191
190
|
for k, v in schema.properties.items():
|
191
|
+
# Simple null type & required -> not required
|
192
192
|
if v.type == 'null' and k in schema.required:
|
193
193
|
not_required.append(k)
|
194
|
-
|
195
|
-
|
196
|
-
if
|
197
|
-
|
198
|
-
|
194
|
+
# Optional with null & required
|
195
|
+
if v.anyOf is not None and [x for x in v.anyOf if x.type == 'null']:
|
196
|
+
if k in schema.required:
|
197
|
+
# required with default -> not required
|
198
|
+
# required without default -> required & remove null from union
|
199
|
+
if v.default:
|
200
|
+
not_required.append(k)
|
201
|
+
else:
|
202
|
+
v.anyOf = list(filter(lambda x: x.type != 'null', v.anyOf))
|
203
|
+
if len(v.anyOf) == 1:
|
204
|
+
replacements[k] = v.anyOf[0]
|
205
|
+
else:
|
206
|
+
# not required with default -> no change
|
207
|
+
# not required without default -> means default input is 'None'
|
208
|
+
v.default = v.default if v.default else 'null'
|
209
|
+
|
210
|
+
|
199
211
|
schema.required = list(filter(lambda x: x not in not_required, schema.required if schema.required is not None else []))
|
200
212
|
for k, v in replacements.items():
|
201
213
|
combined = {
|
@@ -36,6 +36,7 @@ class JsonSchemaObject(BaseModel):
|
|
36
36
|
anyOf: Optional[List['JsonSchemaObject']] = None
|
37
37
|
in_field: Optional[Literal['query', 'header', 'path', 'body']] = Field(None, alias='in')
|
38
38
|
aliasName: str | None = None
|
39
|
+
wrap_data: Optional[bool] = True
|
39
40
|
"Runtime feature where the sdk can provide the original name of a field before prefixing"
|
40
41
|
|
41
42
|
@model_validator(mode='after')
|
@@ -48,9 +49,9 @@ class JsonSchemaObject(BaseModel):
|
|
48
49
|
class ToolRequestBody(BaseModel):
|
49
50
|
model_config = ConfigDict(extra='allow')
|
50
51
|
|
51
|
-
type: Literal['object']
|
52
|
-
properties: Dict[str, JsonSchemaObject]
|
53
|
-
required: Optional[List[str]] =
|
52
|
+
type: Literal['object', 'string']
|
53
|
+
properties: Optional[Dict[str, JsonSchemaObject]] = {}
|
54
|
+
required: Optional[List[str]] = []
|
54
55
|
|
55
56
|
|
56
57
|
class ToolResponseBody(BaseModel):
|
@@ -97,9 +98,13 @@ HTTP_METHOD = Literal['GET', 'POST', 'PUT', 'PATCH', 'DELETE']
|
|
97
98
|
class CallbackBinding(BaseModel):
|
98
99
|
callback_url: str
|
99
100
|
method: HTTP_METHOD
|
100
|
-
input_schema: ToolRequestBody
|
101
|
+
input_schema: Optional[ToolRequestBody] = None
|
102
|
+
output_schema: ToolResponseBody
|
103
|
+
|
104
|
+
class AcknowledgementBinding(BaseModel):
|
101
105
|
output_schema: ToolResponseBody
|
102
106
|
|
107
|
+
|
103
108
|
class OpenApiToolBinding(BaseModel):
|
104
109
|
http_method: HTTP_METHOD
|
105
110
|
http_path: str
|
@@ -107,7 +112,8 @@ class OpenApiToolBinding(BaseModel):
|
|
107
112
|
security: Optional[List[OpenApiSecurityScheme]] = None
|
108
113
|
servers: Optional[List[str]] = None
|
109
114
|
connection_id: str | None = None
|
110
|
-
callback: CallbackBinding = None
|
115
|
+
callback: Optional[CallbackBinding] = None
|
116
|
+
acknowledgement: Optional[AcknowledgementBinding] = None
|
111
117
|
|
112
118
|
@model_validator(mode='after')
|
113
119
|
def validate_openapi_tool_binding(self):
|
@@ -181,6 +187,7 @@ class ToolBinding(BaseModel):
|
|
181
187
|
|
182
188
|
class ToolSpec(BaseModel):
|
183
189
|
name: str
|
190
|
+
id: str | None = None
|
184
191
|
display_name: str | None = None
|
185
192
|
description: str
|
186
193
|
permission: ToolPermission
|
@@ -750,6 +750,13 @@ class AgentsController:
|
|
750
750
|
def list_agents(self, kind: AgentKind=None, verbose: bool=False):
|
751
751
|
parse_errors = []
|
752
752
|
|
753
|
+
if verbose:
|
754
|
+
verbose_output_dictionary = {
|
755
|
+
"native":[],
|
756
|
+
"assistant":[],
|
757
|
+
"external":[]
|
758
|
+
}
|
759
|
+
|
753
760
|
if kind == AgentKind.NATIVE or kind is None:
|
754
761
|
response = self.get_native_client().get()
|
755
762
|
native_agents = []
|
@@ -769,7 +776,7 @@ class AgentsController:
|
|
769
776
|
for agent in native_agents:
|
770
777
|
agents_list.append(json.loads(agent.dumps_spec()))
|
771
778
|
|
772
|
-
|
779
|
+
verbose_output_dictionary["native"] = agents_list
|
773
780
|
else:
|
774
781
|
native_table = rich.table.Table(
|
775
782
|
show_header=True,
|
@@ -832,7 +839,7 @@ class AgentsController:
|
|
832
839
|
if verbose:
|
833
840
|
for agent in external_agents:
|
834
841
|
external_agents_list.append(json.loads(agent.dumps_spec()))
|
835
|
-
|
842
|
+
verbose_output_dictionary["external"] = external_agents_list
|
836
843
|
else:
|
837
844
|
external_table = rich.table.Table(
|
838
845
|
show_header=True,
|
@@ -899,8 +906,10 @@ class AgentsController:
|
|
899
906
|
assistant_agent.config.authorization_url = response_data.get("authorization_url", assistant_agent.config.authorization_url)
|
900
907
|
|
901
908
|
if verbose:
|
909
|
+
assistant_agent_specs = []
|
902
910
|
for agent in assistant_agents:
|
903
|
-
|
911
|
+
assistant_agent_specs.append(json.loads(agent.dumps_spec()))
|
912
|
+
verbose_output_dictionary["assistant"] = assistant_agent_specs
|
904
913
|
else:
|
905
914
|
assistants_table = rich.table.Table(
|
906
915
|
show_header=True,
|
@@ -938,6 +947,9 @@ class AgentsController:
|
|
938
947
|
)
|
939
948
|
rich.print(assistants_table)
|
940
949
|
|
950
|
+
if verbose:
|
951
|
+
rich.print_json(data=verbose_output_dictionary)
|
952
|
+
|
941
953
|
for error in parse_errors:
|
942
954
|
for l in error:
|
943
955
|
logger.error(l)
|
@@ -11,6 +11,7 @@ from ibm_watsonx_orchestrate.client.agents.agent_client import AgentClient
|
|
11
11
|
|
12
12
|
from ibm_watsonx_orchestrate.client.utils import instantiate_client
|
13
13
|
|
14
|
+
|
14
15
|
logger = logging.getLogger(__name__)
|
15
16
|
|
16
17
|
class ChannelsWebchatController:
|
@@ -96,8 +97,6 @@ class ChannelsWebchatController:
|
|
96
97
|
if target_env == 'draft' and is_saas == True:
|
97
98
|
logger.error(f'For SAAS, please ensure this agent exists in a Live Environment')
|
98
99
|
exit(1)
|
99
|
-
|
100
|
-
|
101
100
|
|
102
101
|
return filtered_environments[0].get("id")
|
103
102
|
|
@@ -182,7 +181,7 @@ class ChannelsWebchatController:
|
|
182
181
|
case _:
|
183
182
|
logger.error("Environment not recognized")
|
184
183
|
sys.exit(1)
|
185
|
-
|
184
|
+
|
186
185
|
host_url = self.get_host_url()
|
187
186
|
agent_id = self.get_agent_id(self.agent_name)
|
188
187
|
agent_env_id = self.get_environment_id(self.agent_name, self.env)
|
@@ -169,7 +169,7 @@ def _validate_connection_params(type: ConnectionType, **args) -> None:
|
|
169
169
|
|
170
170
|
|
171
171
|
def _parse_entry(entry: str) -> dict[str,str]:
|
172
|
-
split_entry = entry.split('=')
|
172
|
+
split_entry = entry.split('=', 1)
|
173
173
|
if len(split_entry) != 2:
|
174
174
|
message = f"The entry '{entry}' is not in the expected form '<key>=<value>'"
|
175
175
|
logger.error(message)
|
@@ -404,8 +404,11 @@ def list_connections(environment: ConnectionEnvironment | None, verbose: bool =
|
|
404
404
|
"❌"
|
405
405
|
)
|
406
406
|
continue
|
407
|
-
|
408
|
-
|
407
|
+
|
408
|
+
try:
|
409
|
+
connection_type = get_connection_type(security_scheme=conn.security_scheme, auth_type=conn.auth_type)
|
410
|
+
except:
|
411
|
+
connection_type = conn.auth_type
|
409
412
|
|
410
413
|
if conn.environment == ConnectionEnvironment.DRAFT:
|
411
414
|
draft_table.add_row(
|
@@ -10,10 +10,12 @@ from rich.progress import Progress, SpinnerColumn, TextColumn
|
|
10
10
|
from requests import ConnectionError
|
11
11
|
from typing import List
|
12
12
|
from ibm_watsonx_orchestrate.client.base_api_client import ClientAPIException
|
13
|
+
from ibm_watsonx_orchestrate.agent_builder.knowledge_bases.types import KnowledgeBaseSpec
|
13
14
|
from ibm_watsonx_orchestrate.agent_builder.tools import ToolSpec, ToolPermission, ToolRequestBody, ToolResponseBody
|
14
15
|
from ibm_watsonx_orchestrate.cli.commands.agents.agents_controller import AgentsController, AgentKind, SpecVersion
|
15
16
|
from ibm_watsonx_orchestrate.agent_builder.agents.types import DEFAULT_LLM, BaseAgentSpec
|
16
17
|
from ibm_watsonx_orchestrate.client.agents.agent_client import AgentClient
|
18
|
+
from ibm_watsonx_orchestrate.client.knowledge_bases.knowledge_base_client import KnowledgeBaseClient
|
17
19
|
from ibm_watsonx_orchestrate.client.tools.tool_client import ToolClient
|
18
20
|
from ibm_watsonx_orchestrate.client.copilot.cpe.copilot_cpe_client import CPEClient
|
19
21
|
from ibm_watsonx_orchestrate.client.utils import instantiate_client
|
@@ -56,10 +58,16 @@ def _get_incomplete_tool_from_name(tool_name: str) -> dict:
|
|
56
58
|
"input_schema": input_schema, "output_schema": output_schema})
|
57
59
|
return spec.model_dump()
|
58
60
|
|
61
|
+
|
59
62
|
def _get_incomplete_agent_from_name(agent_name: str) -> dict:
|
60
63
|
spec = BaseAgentSpec(**{"name": agent_name, "description": agent_name, "kind": AgentKind.NATIVE})
|
61
64
|
return spec.model_dump()
|
62
65
|
|
66
|
+
def _get_incomplete_knowledge_base_from_name(kb_name: str) -> dict:
|
67
|
+
spec = KnowledgeBaseSpec(**{"name": kb_name, "description": kb_name})
|
68
|
+
return spec.model_dump()
|
69
|
+
|
70
|
+
|
63
71
|
def _get_tools_from_names(tool_names: List[str]) -> List[dict]:
|
64
72
|
if not len(tool_names):
|
65
73
|
return []
|
@@ -115,6 +123,34 @@ def _get_agents_from_names(collaborators_names: List[str]) -> List[dict]:
|
|
115
123
|
|
116
124
|
return agents
|
117
125
|
|
126
|
+
def _get_knowledge_bases_from_names(kb_names: List[str]) -> List[dict]:
|
127
|
+
if not len(kb_names):
|
128
|
+
return []
|
129
|
+
|
130
|
+
kb_client = get_knowledge_bases_client()
|
131
|
+
|
132
|
+
try:
|
133
|
+
with _get_progress_spinner() as progress:
|
134
|
+
task = progress.add_task(description="Fetching Knowledge Bases", total=None)
|
135
|
+
knowledge_bases = kb_client.get_by_names(kb_names)
|
136
|
+
found_kbs = {kb.get("name") for kb in knowledge_bases}
|
137
|
+
progress.remove_task(task)
|
138
|
+
progress.refresh()
|
139
|
+
for kb_name in kb_names:
|
140
|
+
if kb_name not in found_kbs:
|
141
|
+
logger.warning(
|
142
|
+
f"Failed to find knowledge base named '{kb_name}'. Falling back to incomplete knowledge base definition. Copilot performance maybe effected.")
|
143
|
+
knowledge_bases.append(_get_incomplete_knowledge_base_from_name(kb_name))
|
144
|
+
except ConnectionError:
|
145
|
+
logger.warning(
|
146
|
+
f"Failed to fetch knowledge bases from server. For optimal results please start the server and import the relevant knowledge bases {', '.join(kb_names)}.")
|
147
|
+
knowledge_bases = []
|
148
|
+
for kb_name in kb_names:
|
149
|
+
knowledge_bases.append(_get_incomplete_knowledge_base_from_name(kb_name))
|
150
|
+
|
151
|
+
return knowledge_bases
|
152
|
+
|
153
|
+
|
118
154
|
def get_cpe_client() -> CPEClient:
|
119
155
|
url = os.getenv('CPE_URL', "http://localhost:8081")
|
120
156
|
return instantiate_client(client=CPEClient, url=url)
|
@@ -124,6 +160,10 @@ def get_tool_client(*args, **kwargs):
|
|
124
160
|
return instantiate_client(ToolClient)
|
125
161
|
|
126
162
|
|
163
|
+
def get_knowledge_bases_client(*args, **kwargs):
|
164
|
+
return instantiate_client(KnowledgeBaseClient)
|
165
|
+
|
166
|
+
|
127
167
|
def get_native_client(*args, **kwargs):
|
128
168
|
return instantiate_client(AgentClient)
|
129
169
|
|
@@ -137,9 +177,6 @@ def gather_utterances(max: int) -> list[str]:
|
|
137
177
|
while count < max:
|
138
178
|
utterance = Prompt.ask(" [green]>[/green]").strip()
|
139
179
|
|
140
|
-
if utterance.lower() == 'q':
|
141
|
-
break
|
142
|
-
|
143
180
|
if utterance:
|
144
181
|
utterances.append(utterance)
|
145
182
|
count += 1
|
@@ -147,18 +184,34 @@ def gather_utterances(max: int) -> list[str]:
|
|
147
184
|
return utterances
|
148
185
|
|
149
186
|
|
150
|
-
def
|
187
|
+
def get_knowledge_bases(client):
|
188
|
+
with _get_progress_spinner() as progress:
|
189
|
+
task = progress.add_task(description="Fetching Knowledge Bases", total=None)
|
190
|
+
try:
|
191
|
+
knowledge_bases = client.get()
|
192
|
+
progress.remove_task(task)
|
193
|
+
except ConnectionError:
|
194
|
+
knowledge_bases = []
|
195
|
+
progress.remove_task(task)
|
196
|
+
progress.refresh()
|
197
|
+
logger.warning("Failed to contact wxo server to fetch knowledge_bases. Proceeding with empty agent list")
|
198
|
+
return knowledge_bases
|
199
|
+
|
200
|
+
|
201
|
+
def get_deployed_tools_agents_and_knowledge_bases():
|
151
202
|
all_tools = find_tools_by_description(tool_client=get_tool_client(), description=None)
|
152
203
|
# TODO: this brings only the "native" agents. Can external and assistant agents also be collaborators?
|
153
204
|
all_agents = find_agents(agent_client=get_native_client())
|
154
|
-
|
205
|
+
all_knowledge_bases = get_knowledge_bases(get_knowledge_bases_client())
|
206
|
+
|
207
|
+
return {"tools": all_tools, "collaborators": all_agents, "knowledge_bases": all_knowledge_bases}
|
155
208
|
|
156
209
|
|
157
210
|
def pre_cpe_step(cpe_client):
|
158
|
-
|
211
|
+
tools_agents_and_knowledge_bases = get_deployed_tools_agents_and_knowledge_bases()
|
159
212
|
user_message = ""
|
160
213
|
with _get_progress_spinner() as progress:
|
161
|
-
task = progress.add_task(description="
|
214
|
+
task = progress.add_task(description="Initializing Prompt Engine", total=None)
|
162
215
|
response = cpe_client.submit_pre_cpe_chat(user_message=user_message)
|
163
216
|
progress.remove_task(task)
|
164
217
|
|
@@ -168,15 +221,26 @@ def pre_cpe_step(cpe_client):
|
|
168
221
|
rich.print('\n🤖 Copilot: ' + response["message"])
|
169
222
|
user_message = Prompt.ask("\n👤 You").strip()
|
170
223
|
message_content = {"user_message": user_message}
|
171
|
-
elif "description" in response and response["description"]:
|
224
|
+
elif "description" in response and response["description"]: # after we have a description, we pass the all tools
|
172
225
|
res["description"] = response["description"]
|
173
|
-
message_content =
|
174
|
-
elif "
|
175
|
-
|
176
|
-
res["
|
177
|
-
|
178
|
-
|
179
|
-
|
226
|
+
message_content = {"tools": tools_agents_and_knowledge_bases['tools']}
|
227
|
+
elif "tools" in response and response[
|
228
|
+
'tools'] is not None: # after tools were selected, we pass all collaborators
|
229
|
+
res["tools"] = [t for t in tools_agents_and_knowledge_bases["tools"] if
|
230
|
+
t["name"] in response["tools"]]
|
231
|
+
message_content = {"collaborators": tools_agents_and_knowledge_bases['collaborators']}
|
232
|
+
elif "collaborators" in response and response[
|
233
|
+
'collaborators'] is not None: # after we have collaborators, we pass all knowledge bases
|
234
|
+
res["collaborators"] = [a for a in tools_agents_and_knowledge_bases["collaborators"] if
|
235
|
+
a["name"] in response["collaborators"]]
|
236
|
+
message_content = {"knowledge_bases": tools_agents_and_knowledge_bases['knowledge_bases']}
|
237
|
+
elif "knowledge_bases" in response and response['knowledge_bases'] is not None: # after we have knowledge bases, we pass selected=True to mark that all selection were done
|
238
|
+
res["knowledge_bases"] = [a for a in tools_agents_and_knowledge_bases["knowledge_bases"] if
|
239
|
+
a["name"] in response["knowledge_bases"]]
|
240
|
+
message_content = {"selected": True}
|
241
|
+
elif "agent_name" in response and response['agent_name'] is not None: # once we have a name and style, this phase has ended
|
242
|
+
res["agent_name"] = response["agent_name"]
|
243
|
+
res["agent_style"] = response["agent_style"]
|
180
244
|
return res
|
181
245
|
with _get_progress_spinner() as progress:
|
182
246
|
task = progress.add_task(description="Thinking...", total=None)
|
@@ -197,6 +261,7 @@ def find_tools_by_description(description, tool_client):
|
|
197
261
|
logger.warning("Failed to contact wxo server to fetch tools. Proceeding with empty tool list")
|
198
262
|
return tools
|
199
263
|
|
264
|
+
|
200
265
|
def find_agents(agent_client):
|
201
266
|
with _get_progress_spinner() as progress:
|
202
267
|
task = progress.add_task(description="Fetching Agents", total=None)
|
@@ -282,16 +347,25 @@ def prompt_tune(agent_spec: str, output_file: str | None, samples_file: str | No
|
|
282
347
|
tools = _get_tools_from_names(agent.tools)
|
283
348
|
|
284
349
|
collaborators = _get_agents_from_names(agent.collaborators)
|
350
|
+
|
351
|
+
knowledge_bases = _get_knowledge_bases_from_names(agent.knowledge_base)
|
285
352
|
try:
|
286
|
-
new_prompt = talk_to_cpe(cpe_client=client,
|
287
|
-
|
288
|
-
|
353
|
+
new_prompt = talk_to_cpe(cpe_client=client,
|
354
|
+
samples_file=samples_file,
|
355
|
+
context_data={
|
356
|
+
"initial_instruction": instr,
|
357
|
+
'tools': tools,
|
358
|
+
'description': agent.description,
|
359
|
+
"collaborators": collaborators,
|
360
|
+
"knowledge_bases": knowledge_bases
|
361
|
+
})
|
289
362
|
except ConnectionError:
|
290
363
|
logger.error(
|
291
364
|
"Failed to connect to Copilot server. Please ensure Copilot is running via `orchestrate copilot start`")
|
292
365
|
sys.exit(1)
|
293
366
|
except ClientAPIException:
|
294
|
-
logger.error(
|
367
|
+
logger.error(
|
368
|
+
"An unexpected server error has occur with in the Copilot server. Please check the logs via `orchestrate server logs`")
|
295
369
|
sys.exit(1)
|
296
370
|
|
297
371
|
if new_prompt:
|
@@ -319,17 +393,21 @@ def create_agent(output_file: str, llm: str, samples_file: str | None, dry_run_f
|
|
319
393
|
"Failed to connect to Copilot server. Please ensure Copilot is running via `orchestrate copilot start`")
|
320
394
|
sys.exit(1)
|
321
395
|
except ClientAPIException:
|
322
|
-
logger.error(
|
396
|
+
logger.error(
|
397
|
+
"An unexpected server error has occur with in the Copilot server. Please check the logs via `orchestrate server logs`")
|
323
398
|
sys.exit(1)
|
324
|
-
|
399
|
+
|
325
400
|
tools = res["tools"]
|
326
401
|
collaborators = res["collaborators"]
|
402
|
+
knowledge_bases = res["knowledge_bases"]
|
327
403
|
description = res["description"]
|
328
404
|
agent_name = res["agent_name"]
|
329
405
|
agent_style = res["agent_style"]
|
330
406
|
|
331
407
|
# 4. discuss the instructions
|
332
|
-
instructions = talk_to_cpe(cpe_client, samples_file,
|
408
|
+
instructions = talk_to_cpe(cpe_client, samples_file,
|
409
|
+
{'description': description, 'tools': tools, 'collaborators': collaborators,
|
410
|
+
'knowledge_bases': knowledge_bases})
|
333
411
|
|
334
412
|
# 6. create and save the agent
|
335
413
|
llm = llm if llm else DEFAULT_LLM
|
@@ -337,7 +415,9 @@ def create_agent(output_file: str, llm: str, samples_file: str | None, dry_run_f
|
|
337
415
|
'style': agent_style,
|
338
416
|
'tools': [t['name'] for t in tools],
|
339
417
|
'llm': llm,
|
340
|
-
'collaborators': [c['name'] for c in collaborators]
|
418
|
+
'collaborators': [c['name'] for c in collaborators],
|
419
|
+
'knowledge_base': [k['name'] for k in knowledge_bases]
|
420
|
+
# generate_agent_spec expects knowledge_base and not knowledge_bases
|
341
421
|
}
|
342
422
|
agent = AgentsController.generate_agent_spec(agent_name, AgentKind.NATIVE, description, **params)
|
343
423
|
agent.instructions = instructions
|