ibm-watsonx-orchestrate 1.4.2__py3-none-any.whl → 1.5.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ibm_watsonx_orchestrate/__init__.py +1 -1
- ibm_watsonx_orchestrate/agent_builder/agents/types.py +10 -1
- ibm_watsonx_orchestrate/agent_builder/knowledge_bases/types.py +14 -0
- ibm_watsonx_orchestrate/agent_builder/model_policies/__init__.py +1 -0
- ibm_watsonx_orchestrate/{client → agent_builder}/model_policies/types.py +7 -8
- ibm_watsonx_orchestrate/agent_builder/models/__init__.py +1 -0
- ibm_watsonx_orchestrate/{client → agent_builder}/models/types.py +57 -9
- ibm_watsonx_orchestrate/agent_builder/tools/python_tool.py +46 -3
- ibm_watsonx_orchestrate/agent_builder/tools/types.py +47 -1
- ibm_watsonx_orchestrate/cli/commands/agents/agents_command.py +17 -0
- ibm_watsonx_orchestrate/cli/commands/agents/agents_controller.py +86 -39
- ibm_watsonx_orchestrate/cli/commands/models/model_provider_mapper.py +191 -0
- ibm_watsonx_orchestrate/cli/commands/models/models_command.py +140 -258
- ibm_watsonx_orchestrate/cli/commands/models/models_controller.py +437 -0
- ibm_watsonx_orchestrate/cli/commands/server/server_command.py +2 -1
- ibm_watsonx_orchestrate/cli/commands/tools/tools_controller.py +1 -1
- ibm_watsonx_orchestrate/client/connections/__init__.py +2 -1
- ibm_watsonx_orchestrate/client/connections/utils.py +30 -0
- ibm_watsonx_orchestrate/client/model_policies/model_policies_client.py +23 -4
- ibm_watsonx_orchestrate/client/models/models_client.py +23 -3
- ibm_watsonx_orchestrate/client/toolkit/toolkit_client.py +13 -8
- ibm_watsonx_orchestrate/client/tools/tool_client.py +2 -1
- ibm_watsonx_orchestrate/docker/compose-lite.yml +2 -0
- ibm_watsonx_orchestrate/docker/default.env +10 -11
- ibm_watsonx_orchestrate/experimental/flow_builder/data_map.py +19 -0
- ibm_watsonx_orchestrate/experimental/flow_builder/flows/__init__.py +4 -3
- ibm_watsonx_orchestrate/experimental/flow_builder/flows/constants.py +3 -1
- ibm_watsonx_orchestrate/experimental/flow_builder/flows/decorators.py +3 -2
- ibm_watsonx_orchestrate/experimental/flow_builder/flows/flow.py +245 -223
- ibm_watsonx_orchestrate/experimental/flow_builder/node.py +34 -15
- ibm_watsonx_orchestrate/experimental/flow_builder/resources/flow_status.openapi.yml +7 -39
- ibm_watsonx_orchestrate/experimental/flow_builder/types.py +285 -12
- ibm_watsonx_orchestrate/experimental/flow_builder/utils.py +3 -1
- {ibm_watsonx_orchestrate-1.4.2.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/METADATA +1 -1
- {ibm_watsonx_orchestrate-1.4.2.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/RECORD +38 -35
- ibm_watsonx_orchestrate/cli/commands/models/env_file_model_provider_mapper.py +0 -180
- ibm_watsonx_orchestrate/experimental/flow_builder/flows/data_map.py +0 -91
- {ibm_watsonx_orchestrate-1.4.2.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/WHEEL +0 -0
- {ibm_watsonx_orchestrate-1.4.2.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/entry_points.txt +0 -0
- {ibm_watsonx_orchestrate-1.4.2.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,437 @@
|
|
1
|
+
import logging
|
2
|
+
import os
|
3
|
+
import sys
|
4
|
+
import json
|
5
|
+
import yaml
|
6
|
+
import importlib
|
7
|
+
import inspect
|
8
|
+
from pathlib import Path
|
9
|
+
from typing import List
|
10
|
+
|
11
|
+
import requests
|
12
|
+
import rich
|
13
|
+
import rich.highlighter
|
14
|
+
|
15
|
+
from ibm_watsonx_orchestrate.cli.commands.server.server_command import get_default_env_file, merge_env
|
16
|
+
from ibm_watsonx_orchestrate.client.model_policies.model_policies_client import ModelPoliciesClient
|
17
|
+
from ibm_watsonx_orchestrate.agent_builder.model_policies.types import ModelPolicy, ModelPolicyInner, \
|
18
|
+
ModelPolicyRetry, ModelPolicyStrategy, ModelPolicyStrategyMode, ModelPolicyTarget
|
19
|
+
from ibm_watsonx_orchestrate.client.models.models_client import ModelsClient
|
20
|
+
from ibm_watsonx_orchestrate.agent_builder.models.types import VirtualModel, ProviderConfig, ModelType, ANTHROPIC_DEFAULT_MAX_TOKENS
|
21
|
+
from ibm_watsonx_orchestrate.client.utils import instantiate_client
|
22
|
+
from ibm_watsonx_orchestrate.client.connections import get_connection_id, ConnectionType
|
23
|
+
|
24
|
+
# Module logger for the models CLI controller.
logger = logging.getLogger(__name__)

# Initial watsonx.ai base URL taken from the process environment; list_models()
# may replace it with the value from the merged .env configuration.
WATSONX_URL = os.environ.get("WATSONX_URL")
|
27
|
+
|
28
|
+
class ModelHighlighter(rich.highlighter.RegexHighlighter):
    """Rich highlighter that styles model identifiers followed by a colon.

    Matches ``watsonx/...``, ``virtual-model/...`` and ``virtual-policy/...`` names and
    applies the ``model.name`` style (``base_style`` + group name).
    """

    base_style = "model."
    # The name may contain any number of '/'-separated segments (virtual-policy names
    # have only one), but it stops at whitespace or ':'. Otherwise a colon inside a
    # description on the same line would extend the highlight into that description.
    highlights = [r"(?P<name>(watsonx|virtual[-]model|virtual[-]policy)\/[^\s:]+):"]
|
31
|
+
|
32
|
+
def _get_wxai_foundational_models(timeout: float = 30) -> dict:
    """Fetch the watsonx.ai foundation model catalogue.

    Uses the module-level ``WATSONX_URL``; ``list_models`` may overwrite it from the
    merged environment before calling this.

    :param timeout: seconds to wait for the HTTP request. Without a timeout,
        ``requests`` can block indefinitely on an unresponsive server.
    :returns: the decoded JSON body of the response.
    :raises ValueError: if ``WATSONX_URL`` is not set.
    :raises requests.exceptions.RequestException: on connection errors (logged first).
    :raises Exception: if the server answers with a non-200 status code.
    """
    if not WATSONX_URL:
        raise ValueError("WATSONX_URL is not set; cannot retrieve watsonx.ai foundation models.")

    # Strip a trailing slash so the joined URL never contains '//ml/...'.
    foundation_models_url = WATSONX_URL.rstrip("/") + "/ml/v1/foundation_model_specs?version=2024-05-01"

    try:
        response = requests.get(foundation_models_url, timeout=timeout)
    except requests.exceptions.RequestException:
        logger.exception(f"Exception when connecting to Watsonx URL: {foundation_models_url}")
        raise

    if response.status_code != 200:
        error_message = (
            f"Failed to retrieve foundational models from {foundation_models_url}. "
            f"Status code: {response.status_code}. Response: {response.content}"
        )
        raise Exception(error_message)

    return response.json()
|
50
|
+
|
51
|
+
def _string_to_list(env_value) -> List[str]:
|
52
|
+
return [item.strip().lower() for item in env_value.split(",") if item.strip()]
|
53
|
+
|
54
|
+
def create_model_from_spec(spec: dict) -> VirtualModel:
    """Validate a raw mapping (e.g. parsed YAML/JSON) into a ``VirtualModel``."""
    validated = VirtualModel.model_validate(spec)
    return validated
|
56
|
+
|
57
|
+
def create_policy_from_spec(spec: dict) -> ModelPolicy:
    """Validate a raw mapping (e.g. parsed YAML/JSON) into a ``ModelPolicy``."""
    validated = ModelPolicy.model_validate(spec)
    return validated
|
59
|
+
|
60
|
+
def import_python_model(file: str) -> List[VirtualModel]:
    """Import a Python file and return every module-level ``VirtualModel`` it defines.

    The file's directory is appended to ``sys.path`` only while the import runs, so
    the file can import its sibling modules. The entry is removed again even if the
    import raises.

    :param file: path to a ``.py`` file.
    :returns: the ``VirtualModel`` instances, in ``inspect.getmembers`` (name) order.
    """
    file_path = Path(file)
    module_dir = str(file_path.parent)

    sys.path.append(module_dir)
    try:
        module = importlib.import_module(file_path.stem)
    finally:
        # Remove the entry we appended (the last occurrence), even on failure. Do not
        # blindly delete sys.path[-1]: the imported module may have appended paths itself.
        for index in range(len(sys.path) - 1, -1, -1):
            if sys.path[index] == module_dir:
                del sys.path[index]
                break

    return [obj for _, obj in inspect.getmembers(module) if isinstance(obj, VirtualModel)]
|
73
|
+
|
74
|
+
def import_python_policy(file: str) -> List[ModelPolicy]:
    """Import a Python file and return every module-level ``ModelPolicy`` it defines.

    The file's directory is appended to ``sys.path`` only while the import runs, so
    the file can import its sibling modules. The entry is removed again even if the
    import raises.

    :param file: path to a ``.py`` file.
    :returns: the ``ModelPolicy`` instances, in ``inspect.getmembers`` (name) order.
    """
    file_path = Path(file)
    module_dir = str(file_path.parent)

    sys.path.append(module_dir)
    try:
        module = importlib.import_module(file_path.stem)
    finally:
        # Remove the entry we appended (the last occurrence), even on failure. Do not
        # blindly delete sys.path[-1]: the imported module may have appended paths itself.
        for index in range(len(sys.path) - 1, -1, -1):
            if sys.path[index] == module_dir:
                del sys.path[index]
                break

    return [obj for _, obj in inspect.getmembers(module) if isinstance(obj, ModelPolicy)]
|
87
|
+
|
88
|
+
def parse_model_file(file: str) -> List[VirtualModel]:
    """Load virtual model definitions from a spec file.

    ``.json``/``.yaml``/``.yml`` files hold a single model spec. A ``.py`` file may
    define any number of module-level ``VirtualModel`` objects.

    :param file: path to the spec file.
    :returns: the parsed models.
    :raises ValueError: for any other file extension.
    """
    if file.endswith(('.yaml', '.yml', '.json')):
        # Read explicitly as UTF-8; the platform default encoding (e.g. cp1252 on
        # Windows) would corrupt non-ASCII names or descriptions.
        with open(file, 'r', encoding='utf-8') as f:
            if file.endswith(".json"):
                content = json.load(f)
            else:
                content = yaml.safe_load(f)
        return [create_model_from_spec(spec=content)]

    if file.endswith('.py'):
        return import_python_model(file)

    raise ValueError("file must end in .json, .yaml, .yml or .py")
|
102
|
+
|
103
|
+
def parse_policy_file(file: str) -> List[ModelPolicy]:
    """Load model policy definitions from a spec file.

    ``.json``/``.yaml``/``.yml`` files hold a single policy spec. A ``.py`` file may
    define any number of module-level ``ModelPolicy`` objects.

    :param file: path to the spec file.
    :returns: the parsed policies.
    :raises ValueError: for any other file extension.
    """
    if file.endswith(('.yaml', '.yml', '.json')):
        # Read explicitly as UTF-8; the platform default encoding (e.g. cp1252 on
        # Windows) would corrupt non-ASCII names or descriptions.
        with open(file, 'r', encoding='utf-8') as f:
            if file.endswith(".json"):
                content = json.load(f)
            else:
                content = yaml.safe_load(f)
        return [create_policy_from_spec(spec=content)]

    if file.endswith('.py'):
        return import_python_policy(file)

    raise ValueError("file must end in .json, .yaml, .yml or .py")
|
117
|
+
|
118
|
+
def extract_model_names_from_policy_inner(policy_inner: ModelPolicyInner) -> List[str]:
    """Collect every target model name reachable from *policy_inner*.

    Walks ``targets`` depth-first in order. ``ModelPolicyTarget`` entries contribute
    their ``model_name``, nested ``ModelPolicyInner`` entries are descended into, and
    anything else is ignored.
    """
    def walk(node):
        for entry in node.targets:
            if isinstance(entry, ModelPolicyTarget):
                yield entry.model_name
            elif isinstance(entry, ModelPolicyInner):
                yield from walk(entry)

    return list(walk(policy_inner))
|
126
|
+
|
127
|
+
def get_model_names_from_policy(policy: ModelPolicy) -> List[str]:
    """Return the names of all models targeted anywhere in *policy*'s routing tree."""
    root = policy.policy
    return extract_model_names_from_policy_inner(policy_inner=root)
|
129
|
+
|
130
|
+
class ModelsController:
    """Controller behind the ``models`` CLI commands.

    It lists available models (custom virtual models, virtual model policies and
    watsonx.ai foundation models). It also builds ``VirtualModel``/``ModelPolicy``
    objects from files or CLI arguments, and publishes, updates or removes them.
    The API clients are instantiated lazily and cached on the instance.
    """

    def __init__(self):
        self.models_client = None
        self.model_policies_client = None

    def get_models_client(self) -> ModelsClient:
        """Return the cached ``ModelsClient``, creating it on first use."""
        if not self.models_client:
            self.models_client = instantiate_client(ModelsClient)
        return self.models_client

    def get_model_policies_client(self) -> ModelPoliciesClient:
        """Return the cached ``ModelPoliciesClient``, creating it on first use."""
        if not self.model_policies_client:
            self.model_policies_client = instantiate_client(ModelPoliciesClient)
        return self.model_policies_client

    @staticmethod
    def _provider_from_name(name: str) -> str:
        """Return the first '/'-separated segment of *name* that is not a
        ``virtual-model``/``virtual-policy`` prefix (``StopIteration`` if none)."""
        return next(segment for segment in name.split('/') if segment not in ('virtual-policy', 'virtual-model'))

    @staticmethod
    def _filter_and_sort_wxai_models(models: list, preferred_list: List[str], incompatible_list: List[str]) -> list:
        """Drop incompatible watsonx.ai models and order the rest preferred-first, then by id.

        A model is incompatible when any entry of *incompatible_list* is a substring of
        its lowercased ``model_id`` or ``short_description``. It is preferred when any
        entry of *preferred_list* is a substring of its lowercased ``model_id``.
        """
        def is_incompatible(model: dict) -> bool:
            model_id = model.get("model_id", "").lower()
            short_desc = model.get("short_description", "").lower()
            return any(incomp in model_id or incomp in short_desc for incomp in incompatible_list)

        def sort_key(model: dict):
            model_id = model.get("model_id", "").lower()
            is_preferred = any(pref in model_id for pref in preferred_list)
            return (0 if is_preferred else 1, model_id)

        return sorted((model for model in models if not is_incompatible(model)), key=sort_key)

    def list_models(self, print_raw: bool = False) -> None:
        """Print every model available to agents.

        Custom virtual models and virtual model policies are listed first. Then come
        watsonx.ai foundation models, with ``INCOMPATIBLE_MODELS`` matches removed and
        ``PREFERRED_MODELS`` matches sorted first and starred.

        :param print_raw: print a plain highlighted list instead of a rich table.

        Exits with status 1 if ``WATSONX_URL`` is missing from the merged environment.
        """
        # ``import rich`` / ``rich.highlighter`` are not guaranteed to load these
        # submodules, so import what is used here explicitly.
        from rich.console import Console
        from rich.table import Table
        from rich.theme import Theme

        global WATSONX_URL
        models_client: ModelsClient = self.get_models_client()
        model_policies_client: ModelPoliciesClient = self.get_model_policies_client()

        merged_env_dict = merge_env(get_default_env_file(), None)

        watsonx_url = merged_env_dict.get("WATSONX_URL")
        if not watsonx_url:
            logger.error("Error: WATSONX_URL is required in the environment.")
            sys.exit(1)
        # _get_wxai_foundational_models() reads the module-level URL.
        WATSONX_URL = watsonx_url

        logger.info("Retrieving virtual-model models list...")
        virtual_models = models_client.list()

        logger.info("Retrieving virtual-policies models list...")
        virtual_model_policies = model_policies_client.list()

        logger.info("Retrieving watsonx.ai models list...")
        found_models = _get_wxai_foundational_models()

        preferred_list = _string_to_list(merged_env_dict.get('PREFERRED_MODELS', ''))
        incompatible_list = _string_to_list(merged_env_dict.get('INCOMPATIBLE_MODELS', ''))

        wxai_models = found_models.get("resources") or []
        if not wxai_models:
            # Keep going: custom virtual models/policies should still be listed.
            logger.error("No models found.")
        sorted_models = self._filter_and_sort_wxai_models(wxai_models, preferred_list, incompatible_list)

        custom_models = virtual_models + virtual_model_policies

        def is_preferred(model_id: str) -> bool:
            return any(pref in model_id.lower() for pref in preferred_list)

        if print_raw:
            theme = Theme({"model.name": "bold cyan"})
            console = Console(highlighter=ModelHighlighter(), theme=theme)
            console.print("[bold]Available Models:[/bold]")

            for model in custom_models:
                console.print(f"- ✨️ {model.name}:", model.description or 'No description provided.')

            for model in sorted_models:
                model_id = model.get("model_id", "N/A")
                short_desc = model.get("short_description", "No description provided.")
                full_model_name = f"watsonx/{model_id}: {short_desc}"
                marker = "★ " if is_preferred(model_id) else ""
                console.print(f"- [yellow]{marker}[/yellow]{full_model_name}")

            console.print("[yellow]★[/yellow] [italic dim]indicates a supported and preferred model[/italic dim]\n[blue dim]✨️[/blue dim] [italic dim]indicates a model from a custom provider[/italic dim]" )
        else:
            table = Table(
                show_header=True,
                title="[bold]Available Models[/bold]",
                caption="[yellow]★ [/yellow] indicates a supported and preferred model from watsonx\n[blue]✨️[/blue] indicates a model from a custom provider",
                show_lines=True,
            )
            for col in ["Model", "Description"]:
                table.add_column(col)

            for model in custom_models:
                table.add_row(f"✨️ {model.name}", model.description or 'No description provided.')

            for model in sorted_models:
                model_id = model.get("model_id", "N/A")
                short_desc = model.get("short_description", "No description provided.")
                marker = "★ " if is_preferred(model_id) else ""
                table.add_row(f"[yellow]{marker}[/yellow]watsonx/{model_id}", short_desc)

            rich.print(table)

    def import_model(self, file: str, app_id: str | None) -> List[VirtualModel]:
        """Load model specs from *file* and normalise them for publishing.

        For each model:
        - the name gets a ``virtual-model/`` prefix if missing;
        - the provider config's ``provider`` is set from the name;
        - Anthropic models get a default ``max_tokens``;
        - the connection id is resolved when *app_id* is given;
        - the provider config is passed to ``validate_ProviderConfig``.
        """
        from ibm_watsonx_orchestrate.cli.commands.models.model_provider_mapper import validate_ProviderConfig # lazily import this because the lut building is expensive
        models = parse_model_file(file)

        for model in models:
            if not model.name.startswith('virtual-model/'):
                model.name = f"virtual-model/{model.name}"

            provider = self._provider_from_name(model.name)
            if not model.provider_config:
                model.provider_config = ProviderConfig.model_validate({"provider": provider})
            else:
                model.provider_config.provider = provider

            # Anthropic has no default for max_tokens
            if "anthropic" in model.name:
                if not model.config:
                    model.config = {}
                if "max_tokens" not in model.config:
                    model.config["max_tokens"] = ANTHROPIC_DEFAULT_MAX_TOKENS

            if app_id:
                model.connection_id = get_connection_id(app_id, supported_schemas={ConnectionType.KEY_VALUE})
            validate_ProviderConfig(model.provider_config, app_id=app_id)
        return models

    def create_model(self, name: str, display_name: str | None = None, description: str | None = None, provider_config_dict: dict | None = None, model_type: ModelType = ModelType.CHAT, app_id: str | None = None) -> VirtualModel:
        """Build (without publishing) a ``VirtualModel`` from CLI arguments.

        The provider is derived from *name* and written into the provider config, which
        is passed to ``validate_ProviderConfig``. The name gets a ``virtual-model/``
        prefix if missing, and Anthropic models get a default ``max_tokens``.
        """
        from ibm_watsonx_orchestrate.cli.commands.models.model_provider_mapper import validate_ProviderConfig # lazily import this because the lut building is expensive

        provider = self._provider_from_name(name)

        if provider_config_dict:
            provider_config = ProviderConfig.model_validate(provider_config_dict)
            provider_config.provider = provider
        else:
            provider_config = ProviderConfig.model_validate({"provider": provider})
        validate_ProviderConfig(provider_config, app_id=app_id)

        if not name.startswith('virtual-model/'):
            name = f"virtual-model/{name}"

        config = None
        # Anthropic has no default for max_tokens
        if "anthropic" in name:
            config = {"max_tokens": ANTHROPIC_DEFAULT_MAX_TOKENS}

        return VirtualModel(
            name=name,
            display_name=display_name,
            description=description,
            tags=[],
            provider_config=provider_config,
            config=config,
            model_type=model_type,
            connection_id=get_connection_id(app_id, supported_schemas={ConnectionType.KEY_VALUE})
        )

    def publish_or_update_models(self, model: VirtualModel) -> None:
        """Update the draft model named ``model.name`` if one exists, else create it.

        Exits with status 1 if more than one draft has that name.
        """
        models_client = self.get_models_client()

        existing_models = models_client.get_draft_by_name(model.name)
        if len(existing_models) > 1:
            logger.error(f"Multiple models with the name '{model.name}' found. Failed to update model")
            sys.exit(1)

        if len(existing_models) == 1:
            self.update_model(model_id=existing_models[0].id, model=model)
        else:
            self.publish_model(model=model)

    def publish_model(self, model: VirtualModel) -> None:
        """Create *model* on the server."""
        self.get_models_client().create(model)
        logger.info(f"Successfully added the model '{model.name}'")

    def update_model(self, model_id: str, model: VirtualModel) -> None:
        """Overwrite the model with id *model_id* with *model*."""
        logger.info(f"Existing model '{model.name}' found. Updating...")
        self.get_models_client().update(model_id, model)
        logger.info(f"Model '{model.name}' updated successfully")

    def remove_model(self, name: str) -> None:
        """Delete the draft model named *name*; exits with status 1 unless exactly one matches."""
        models_client: ModelsClient = self.get_models_client()

        existing_models = models_client.get_draft_by_name(name)

        if len(existing_models) > 1:
            logger.error(f"Multiple models with the name '{name}' found. Failed to remove model")
            sys.exit(1)
        if len(existing_models) == 0:
            logger.error(f"No model found with the name '{name}'")
            sys.exit(1)

        model = existing_models[0]

        models_client.delete(model_id=model.id)
        logger.info(f"Successfully removed the model '{name}'")

    def import_model_policy(self, file: str) -> List[ModelPolicy]:
        """Load policies from *file*, check that every referenced model exists,
        and prefix policy names with ``virtual-policy/`` where missing.

        Exits with status 1 on the first reference to an unknown model.
        """
        policies = parse_policy_file(file)
        model_client: ModelsClient = self.get_models_client()
        existing_model_names = {m.name for m in model_client.list()}

        for policy in policies:
            for model_name in get_model_names_from_policy(policy):
                if model_name not in existing_model_names:
                    logger.error(f"No model found with the name '{model_name}'")
                    sys.exit(1)

            if not policy.name.startswith('virtual-policy/'):
                policy.name = f"virtual-policy/{policy.name}"

        return policies

    def create_model_policy(
        self,
        name: str,
        models: List[str],
        strategy: ModelPolicyStrategyMode,
        strategy_on_code: List[int],
        retry_on_code: List[int],
        retry_attempts: int,
        display_name: str | None = None,
        description: str | None = None
    ) -> ModelPolicy:
        """Build (without publishing) a ``ModelPolicy`` routing over *models*.

        Every model must already exist (otherwise exits with status 1). Each target gets
        ``weight=1``. A retry section is added only when *retry_on_code* is non-empty.
        *display_name* and *description* default to the prefixed policy name.
        """
        model_client: ModelsClient = self.get_models_client()
        existing_model_names = {m.name for m in model_client.list()}
        for model_name in models:
            if model_name not in existing_model_names:
                logger.error(f"No model found with the name '{model_name}'")
                sys.exit(1)

        if not name.startswith('virtual-policy/'):
            name = f"virtual-policy/{name}"

        inner = ModelPolicyInner()
        inner.strategy = ModelPolicyStrategy(
            mode=strategy,
            on_status_codes=strategy_on_code
        )
        inner.targets = [ModelPolicyTarget(weight=1, model_name=m) for m in models]
        if retry_on_code:
            inner.retry = ModelPolicyRetry(
                on_status_codes=retry_on_code,
                attempts=retry_attempts
            )

        return ModelPolicy(
            name=name,
            display_name=display_name or name,
            description=description or name,
            policy=inner
        )

    def publish_or_update_model_policies(self, policy: ModelPolicy) -> None:
        """Update the draft policy named ``policy.name`` if one exists, else create it.

        Exits with status 1 if more than one draft has that name.
        """
        model_policies_client: ModelPoliciesClient = self.get_model_policies_client()

        existing_policies = model_policies_client.get_draft_by_name(policy.name)
        if len(existing_policies) > 1:
            logger.error(f"Multiple model policies with the name '{policy.name}' found. Failed to update model policy")
            sys.exit(1)

        if len(existing_policies) == 1:
            self.update_policy(policy_id=existing_policies[0].id, policy=policy)
        else:
            self.publish_policy(policy=policy)

    def publish_policy(self, policy: ModelPolicy) -> None:
        """Create *policy* on the server."""
        self.get_model_policies_client().create(policy)
        logger.info(f"Successfully added the model policy '{policy.name}'")

    def update_policy(self, policy_id: str, policy: ModelPolicy) -> None:
        """Overwrite the model policy with id *policy_id* with *policy*."""
        logger.info(f"Existing model policy '{policy.name}' found. Updating...")
        self.get_model_policies_client().update(policy_id, policy)
        logger.info(f"Model policy '{policy.name}' updated successfully")

    def remove_policy(self, name: str) -> None:
        """Delete the draft policy named *name*; exits with status 1 unless exactly one matches."""
        model_policies_client: ModelPoliciesClient = self.get_model_policies_client()
        existing_model_policies = model_policies_client.get_draft_by_name(name)

        if len(existing_model_policies) > 1:
            logger.error(f"Multiple model policies with the name '{name}' found. Failed to remove model policy")
            sys.exit(1)
        if len(existing_model_policies) == 0:
            logger.error(f"No model policy found with the name '{name}'")
            sys.exit(1)

        policy = existing_model_policies[0]

        model_policies_client.delete(model_policy_id=policy.id)
        logger.info(f"Successfully removed the policy '{name}'")
|
@@ -22,6 +22,7 @@ from ibm_watsonx_orchestrate.cli.commands.environment.environment_controller imp
|
|
22
22
|
|
23
23
|
from ibm_watsonx_orchestrate.cli.config import LICENSE_HEADER, \
|
24
24
|
ENV_ACCEPT_LICENSE
|
25
|
+
|
25
26
|
from ibm_watsonx_orchestrate.cli.config import PROTECTED_ENV_NAME, clear_protected_env_credentials_token, Config, \
|
26
27
|
AUTH_CONFIG_FILE_FOLDER, AUTH_CONFIG_FILE, AUTH_MCSP_TOKEN_OPT, AUTH_SECTION_HEADER, USER_ENV_CACHE_HEADER, LICENSE_HEADER, \
|
27
28
|
ENV_ACCEPT_LICENSE
|
@@ -816,4 +817,4 @@ def run_db_migration() -> None:
|
|
816
817
|
sys.exit(1)
|
817
818
|
|
818
819
|
if __name__ == "__main__":
|
819
|
-
server_app()
|
820
|
+
server_app()
|
@@ -25,7 +25,7 @@ from rich.panel import Panel
|
|
25
25
|
|
26
26
|
from ibm_watsonx_orchestrate.agent_builder.tools import BaseTool, ToolSpec
|
27
27
|
from ibm_watsonx_orchestrate.agent_builder.tools.openapi_tool import create_openapi_json_tools_from_uri,create_openapi_json_tools_from_content
|
28
|
-
from ibm_watsonx_orchestrate.cli.commands.models.
|
28
|
+
from ibm_watsonx_orchestrate.cli.commands.models.models_controller import ModelHighlighter
|
29
29
|
from ibm_watsonx_orchestrate.cli.commands.tools.types import RegistryType
|
30
30
|
from ibm_watsonx_orchestrate.cli.commands.connections.connections_controller import configure_connection, remove_connection, add_connection
|
31
31
|
from ibm_watsonx_orchestrate.agent_builder.connections.types import ConnectionType, ConnectionEnvironment, ConnectionPreference
|
@@ -1,8 +1,11 @@
|
|
1
|
+
import logging
|
1
2
|
from ibm_watsonx_orchestrate.client.utils import instantiate_client, is_local_dev
|
2
3
|
from ibm_watsonx_orchestrate.client.connections.connections_client import ConnectionsClient
|
3
4
|
from ibm_watsonx_orchestrate.cli.config import Config, ENVIRONMENTS_SECTION_HEADER, CONTEXT_SECTION_HEADER, CONTEXT_ACTIVE_ENV_OPT, ENV_WXO_URL_OPT
|
4
5
|
from ibm_watsonx_orchestrate.agent_builder.connections.types import ConnectionType, ConnectionAuthType, ConnectionSecurityScheme
|
5
6
|
|
7
|
+
logger = logging.getLogger(__name__)
|
8
|
+
|
6
9
|
LOCAL_CONNECTION_MANAGER_PORT = 3001
|
7
10
|
|
8
11
|
def _get_connections_manager_url() -> str:
|
@@ -25,3 +28,30 @@ def get_connection_type(security_scheme: ConnectionSecurityScheme, auth_type: Co
|
|
25
28
|
if security_scheme != ConnectionSecurityScheme.OAUTH2:
|
26
29
|
return ConnectionType(security_scheme)
|
27
30
|
return ConnectionType(auth_type)
|
31
|
+
|
32
|
+
def get_connection_id(app_id: str, supported_schemas: set | None = None) -> str:
|
33
|
+
if not app_id:
|
34
|
+
return
|
35
|
+
|
36
|
+
connections_client = get_connections_client()
|
37
|
+
|
38
|
+
connection_id = None
|
39
|
+
if app_id is not None:
|
40
|
+
connection = connections_client.get(app_id=app_id)
|
41
|
+
if not connection:
|
42
|
+
logger.error(f"No connection exists with the app-id '{app_id}'")
|
43
|
+
exit(1)
|
44
|
+
connection_id = connection.connection_id
|
45
|
+
|
46
|
+
existing_draft_configuration = None
|
47
|
+
existing_live_configuration = None
|
48
|
+
|
49
|
+
if supported_schemas:
|
50
|
+
existing_draft_configuration = connections_client.get_config(app_id=app_id, env='draft')
|
51
|
+
existing_live_configuration = connections_client.get_config(app_id=app_id, env='live')
|
52
|
+
for config in [existing_draft_configuration, existing_live_configuration]:
|
53
|
+
if config and config.security_scheme not in supported_schemas:
|
54
|
+
logger.error(f"Only {', '.join(supported_schemas)} credentials are currently supported. Provided connection '{config.app_id}' is of type '{config.security_scheme}' in environment '{config.environment}'")
|
55
|
+
exit(1)
|
56
|
+
|
57
|
+
return connection_id
|
@@ -5,15 +5,12 @@ from ibm_watsonx_orchestrate.client.base_api_client import BaseAPIClient, Client
|
|
5
5
|
|
6
6
|
import logging
|
7
7
|
|
8
|
-
from ibm_watsonx_orchestrate.
|
9
|
-
from ibm_watsonx_orchestrate.client.models.types import ListVirtualModel, CreateVirtualModel
|
8
|
+
from ibm_watsonx_orchestrate.agent_builder.model_policies.types import ModelPolicy
|
10
9
|
|
11
10
|
logger = logging.getLogger(__name__)
|
12
11
|
|
13
12
|
|
14
13
|
|
15
|
-
|
16
|
-
|
17
14
|
class ModelPoliciesClient(BaseAPIClient):
|
18
15
|
"""
|
19
16
|
Client to handle CRUD operations for ModelPolicies endpoint
|
@@ -21,6 +18,10 @@ class ModelPoliciesClient(BaseAPIClient):
|
|
21
18
|
# POST api/v1/model_policy
|
22
19
|
def create(self, model: ModelPolicy) -> None:
    """Create a model policy from *model* (``POST /model_policy``)."""
    payload = model.model_dump()
    self._post("/model_policy", data=payload)
|
21
|
+
|
22
|
+
# PUT api/v1/model_policy/{model_policy_id}
|
23
|
+
def update(self, model_policy_id: str, model: ModelPolicy) -> None:
    """Replace the model policy *model_policy_id* with *model* (``PUT /model_policy/{id}``)."""
    endpoint = f"/model_policy/{model_policy_id}"
    self._put(endpoint, data=model.model_dump())
|
24
25
|
|
25
26
|
# DELETE api/v1/model_policy/{model_policy_id}
|
26
27
|
def delete(self, model_policy_id: str) -> dict:
|
@@ -43,5 +44,23 @@ class ModelPoliciesClient(BaseAPIClient):
|
|
43
44
|
if e.response.status_code == 404:
|
44
45
|
return []
|
45
46
|
raise e
|
47
|
+
|
48
|
+
def get_drafts_by_names(self, policy_names: List[str]) -> List[ModelPolicy]:
    """Fetch model policies whose name matches any of *policy_names* (``GET /model_policy?name=...``).

    Returns ``[]`` without a request when *policy_names* is empty, because an
    unfiltered query would return every policy. Also returns ``[]`` when the server
    answers 404.

    :raises ValidationError: if the response does not match ``ModelPolicy`` (logged first).
    :raises ClientAPIException: for any other API error.
    """
    from urllib.parse import urlencode

    if not policy_names:
        return []
    try:
        # Encode each name so '&', '#', '+' or spaces cannot corrupt the query string;
        # '/' is left literal because virtual-policy names contain it.
        query = urlencode([("name", name) for name in policy_names], safe="/")
        res = self._get(f"/model_policy?{query}")
        return [ModelPolicy.model_validate(policy) for policy in res]
    except ValidationError as e:
        logger.error("Received unexpected response from server")
        raise e
    except ClientAPIException as e:
        if e.response.status_code == 404:
            return []
        raise e
|
60
|
+
|
61
|
+
def get_draft_by_name(self, policy_name: str) -> List[ModelPolicy]:
    """Return all draft model policies named *policy_name*.

    The result is a list (possibly empty, possibly with duplicates), not a single
    policy; callers check its length.
    """
    return self.get_drafts_by_names([policy_name])
|
63
|
+
|
64
|
+
|
46
65
|
|
47
66
|
|
@@ -5,7 +5,7 @@ from ibm_watsonx_orchestrate.client.base_api_client import BaseAPIClient, Client
|
|
5
5
|
|
6
6
|
import logging
|
7
7
|
|
8
|
-
from ibm_watsonx_orchestrate.
|
8
|
+
from ibm_watsonx_orchestrate.agent_builder.models.types import ListVirtualModel, VirtualModel
|
9
9
|
|
10
10
|
logger = logging.getLogger(__name__)
|
11
11
|
|
@@ -18,8 +18,12 @@ class ModelsClient(BaseAPIClient):
|
|
18
18
|
Client to handle CRUD operations for Models endpoint
|
19
19
|
"""
|
20
20
|
# POST api/v1/models
|
21
|
-
def create(self, model:
|
22
|
-
self._post("/models", data=model.model_dump())
|
21
|
+
def create(self, model: VirtualModel) -> None:
|
22
|
+
self._post("/models", data=model.model_dump(exclude_none=True))
|
23
|
+
|
24
|
+
# PUT api/v1/models/{models_id}
|
25
|
+
def update(self, model_id: str, model: VirtualModel) -> None:
    """Replace the model *model_id* with *model*, omitting unset fields (``PUT /models/{id}``)."""
    payload = model.model_dump(exclude_none=True)
    self._put(f"/models/{model_id}", data=payload)
|
23
27
|
|
24
28
|
# DELETE api/v1/models/{model_id}
|
25
29
|
def delete(self, model_id: str) -> dict:
|
@@ -42,5 +46,21 @@ class ModelsClient(BaseAPIClient):
|
|
42
46
|
if e.response.status_code == 404:
|
43
47
|
return []
|
44
48
|
raise e
|
49
|
+
|
50
|
+
def get_drafts_by_names(self, model_names: List[str]) -> List[ListVirtualModel]:
    """Fetch models whose name matches any of *model_names* (``GET /models?name=...``).

    Returns ``[]`` without a request when *model_names* is empty, because an
    unfiltered query would return every model. Also returns ``[]`` when the server
    answers 404.

    :raises ValidationError: if the response does not match ``ListVirtualModel`` (logged first).
    :raises ClientAPIException: for any other API error.
    """
    from urllib.parse import urlencode

    if not model_names:
        return []
    try:
        # Encode each name so '&', '#', '+' or spaces cannot corrupt the query string;
        # '/' is left literal because virtual-model names contain it.
        query = urlencode([("name", name) for name in model_names], safe="/")
        res = self._get(f"/models?{query}")
        return [ListVirtualModel.model_validate(model) for model in res]
    except ValidationError as e:
        logger.error("Received unexpected response from server")
        raise e
    except ClientAPIException as e:
        if e.response.status_code == 404:
            return []
        raise e
|
62
|
+
|
63
|
+
def get_draft_by_name(self, model_name: str) -> List[ListVirtualModel]:
    """Return all draft models named *model_name*.

    The result is a list (possibly empty, possibly with duplicates), not a single
    model; callers check its length.
    """
    return self.get_drafts_by_names([model_name])
|
45
65
|
|
46
66
|
|
@@ -8,10 +8,10 @@ class ToolKitClient(BaseAPIClient):
|
|
8
8
|
|
9
9
|
def __init__(self, *args, **kwargs):
|
10
10
|
super().__init__(*args, **kwargs)
|
11
|
-
self.base_endpoint = "/orchestrate/toolkits" if is_local_dev(self.base_url) else "/toolkits"
|
12
11
|
|
13
12
|
def get(self) -> dict:
|
14
|
-
return self._get(
|
13
|
+
return self._get("/toolkits")
|
14
|
+
|
15
15
|
|
16
16
|
# POST /toolkits/prepare/list-tools
|
17
17
|
def list_tools(self, zip_file_path: str, command: str, args: List[str]) -> List[str]:
|
@@ -33,7 +33,8 @@ class ToolKitClient(BaseAPIClient):
|
|
33
33
|
"file": (filename, f, "application/zip"),
|
34
34
|
}
|
35
35
|
|
36
|
-
response = self._post(
|
36
|
+
response = self._post("/toolkits/prepare/list-tools", files=files)
|
37
|
+
|
37
38
|
|
38
39
|
return response.get("tools", [])
|
39
40
|
|
@@ -44,7 +45,8 @@ class ToolKitClient(BaseAPIClient):
|
|
44
45
|
Creates new toolkit metadata
|
45
46
|
"""
|
46
47
|
try:
|
47
|
-
return self._post(
|
48
|
+
return self._post("/toolkits", data=payload)
|
49
|
+
|
48
50
|
except ClientAPIException as e:
|
49
51
|
if e.response.status_code == 400 and "already exists" in e.response.text:
|
50
52
|
raise ClientAPIException(
|
@@ -63,25 +65,28 @@ class ToolKitClient(BaseAPIClient):
|
|
63
65
|
files = {
|
64
66
|
"file": (filename, f, "application/zip", {"Expires": "0"})
|
65
67
|
}
|
66
|
-
return self._post(f"
|
68
|
+
return self._post(f"/toolkits/{toolkit_id}/upload", files=files)
|
67
69
|
|
68
70
|
# DELETE /toolkits/{toolkit-id}
|
69
71
|
def delete(self, toolkit_id: str) -> dict:
|
70
|
-
return self._delete(f"
|
72
|
+
return self._delete(f"/toolkits/{toolkit_id}")
|
73
|
+
|
71
74
|
|
72
75
|
def get_draft_by_name(self, toolkit_name: str) -> List[dict]:
|
73
76
|
return self.get_drafts_by_names([toolkit_name])
|
74
77
|
|
75
78
|
def get_drafts_by_names(self, toolkit_names: List[str]) -> List[dict]:
|
76
79
|
formatted_toolkit_names = [f"names={x}" for x in toolkit_names]
|
77
|
-
|
80
|
+
|
81
|
+
return self._get(f"/toolkits?{'&'.join(formatted_toolkit_names)}")
|
82
|
+
|
78
83
|
|
79
84
|
def get_draft_by_id(self, toolkit_id: str) -> dict:
|
80
85
|
if toolkit_id is None:
|
81
86
|
return ""
|
82
87
|
else:
|
83
88
|
try:
|
84
|
-
toolkit = self._get(f"
|
89
|
+
toolkit = self._get(f"/toolkits/{toolkit_id}")
|
85
90
|
return toolkit
|
86
91
|
except ClientAPIException as e:
|
87
92
|
if e.response.status_code == 404 and "not found with the given name" in e.response.text:
|