ibm-watsonx-orchestrate 1.5.0b0__py3-none-any.whl → 1.5.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29):
  1. ibm_watsonx_orchestrate/__init__.py +1 -2
  2. ibm_watsonx_orchestrate/agent_builder/agents/types.py +10 -1
  3. ibm_watsonx_orchestrate/agent_builder/knowledge_bases/types.py +13 -0
  4. ibm_watsonx_orchestrate/agent_builder/model_policies/__init__.py +1 -0
  5. ibm_watsonx_orchestrate/{client → agent_builder}/model_policies/types.py +7 -8
  6. ibm_watsonx_orchestrate/agent_builder/models/__init__.py +1 -0
  7. ibm_watsonx_orchestrate/{client → agent_builder}/models/types.py +43 -7
  8. ibm_watsonx_orchestrate/agent_builder/tools/python_tool.py +46 -3
  9. ibm_watsonx_orchestrate/agent_builder/tools/types.py +47 -1
  10. ibm_watsonx_orchestrate/cli/commands/agents/agents_command.py +17 -0
  11. ibm_watsonx_orchestrate/cli/commands/agents/agents_controller.py +86 -39
  12. ibm_watsonx_orchestrate/cli/commands/models/model_provider_mapper.py +191 -0
  13. ibm_watsonx_orchestrate/cli/commands/models/models_command.py +136 -259
  14. ibm_watsonx_orchestrate/cli/commands/models/models_controller.py +437 -0
  15. ibm_watsonx_orchestrate/cli/commands/server/server_command.py +2 -1
  16. ibm_watsonx_orchestrate/cli/commands/tools/tools_controller.py +1 -1
  17. ibm_watsonx_orchestrate/client/connections/__init__.py +2 -1
  18. ibm_watsonx_orchestrate/client/connections/utils.py +30 -0
  19. ibm_watsonx_orchestrate/client/model_policies/model_policies_client.py +23 -4
  20. ibm_watsonx_orchestrate/client/models/models_client.py +23 -3
  21. ibm_watsonx_orchestrate/client/tools/tool_client.py +2 -1
  22. ibm_watsonx_orchestrate/docker/compose-lite.yml +2 -0
  23. ibm_watsonx_orchestrate/docker/default.env +10 -11
  24. {ibm_watsonx_orchestrate-1.5.0b0.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/METADATA +1 -1
  25. {ibm_watsonx_orchestrate-1.5.0b0.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/RECORD +28 -25
  26. ibm_watsonx_orchestrate/cli/commands/models/env_file_model_provider_mapper.py +0 -180
  27. {ibm_watsonx_orchestrate-1.5.0b0.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/WHEEL +0 -0
  28. {ibm_watsonx_orchestrate-1.5.0b0.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/entry_points.txt +0 -0
  29. {ibm_watsonx_orchestrate-1.5.0b0.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,437 @@
1
+ import logging
2
+ import os
3
+ import sys
4
+ import json
5
+ import yaml
6
+ import importlib
7
+ import inspect
8
+ from pathlib import Path
9
+ from typing import List
10
+
11
+ import requests
12
+ import rich
13
+ import rich.highlighter
14
+
15
+ from ibm_watsonx_orchestrate.cli.commands.server.server_command import get_default_env_file, merge_env
16
+ from ibm_watsonx_orchestrate.client.model_policies.model_policies_client import ModelPoliciesClient
17
+ from ibm_watsonx_orchestrate.agent_builder.model_policies.types import ModelPolicy, ModelPolicyInner, \
18
+ ModelPolicyRetry, ModelPolicyStrategy, ModelPolicyStrategyMode, ModelPolicyTarget
19
+ from ibm_watsonx_orchestrate.client.models.models_client import ModelsClient
20
+ from ibm_watsonx_orchestrate.agent_builder.models.types import VirtualModel, ProviderConfig, ModelType, ANTHROPIC_DEFAULT_MAX_TOKENS
21
+ from ibm_watsonx_orchestrate.client.utils import instantiate_client
22
+ from ibm_watsonx_orchestrate.client.connections import get_connection_id, ConnectionType
23
+
24
# Module-level logger named after this module, so its records can be filtered per module.
logger = logging.getLogger(name=__name__)

# Base URL of the watsonx.ai service, read from the process environment at import time.
# It may be None here; ModelsController.list_models() overwrites it from the merged
# .env configuration before the foundation-model catalogue is fetched.
WATSONX_URL = os.environ.get("WATSONX_URL")
27
+
28
class ModelHighlighter(rich.highlighter.RegexHighlighter):
    """Highlight fully-qualified model names in CLI output.

    Matches names starting with ``watsonx/``, ``virtual-model/`` or
    ``virtual-policy/`` that contain at least one more ``/`` and are immediately
    followed by a colon. The captured ``name`` group is styled with
    ``base_style + "name"``, i.e. the ``model.name`` theme entry.
    """

    base_style = "model."
    # Lazy quantifiers stop the match at the first colon after the name. The previous
    # greedy ``.+`` ran to the *last* colon on the line, so a description containing
    # ':' was highlighted as part of the model name.
    highlights = [r"(?P<name>(watsonx|virtual-model|virtual-policy)/.+?/.+?):"]
31
+
32
def _get_wxai_foundational_models(timeout: float = 30) -> dict:
    """Fetch the foundation-model catalogue from the watsonx.ai service.

    The base URL comes from the module-level ``WATSONX_URL``.

    Args:
        timeout: seconds to wait for the service before giving up; passed to
            ``requests.get`` so a hung endpoint cannot block the CLI forever.

    Returns:
        The decoded JSON body of
        ``GET {WATSONX_URL}/ml/v1/foundation_model_specs?version=2024-05-01``.

    Raises:
        ValueError: if ``WATSONX_URL`` is not configured.
        requests.exceptions.RequestException: if the HTTP request fails (logged, then re-raised).
        Exception: if the service answers with a status code other than 200.
    """
    if not WATSONX_URL:
        # Previously this surfaced as an opaque ``None + str`` TypeError.
        raise ValueError("WATSONX_URL is not set; cannot retrieve watsonx.ai foundation models.")

    # Strip a trailing '/' so a configured "https://host/" does not yield "https://host//ml/...".
    foundation_models_url = WATSONX_URL.rstrip("/") + "/ml/v1/foundation_model_specs?version=2024-05-01"

    try:
        response = requests.get(foundation_models_url, timeout=timeout)
    except requests.exceptions.RequestException:
        logger.exception(f"Exception when connecting to Watsonx URL: {foundation_models_url}")
        raise

    if response.status_code != 200:
        error_message = (
            f"Failed to retrieve foundational models from {foundation_models_url}. "
            f"Status code: {response.status_code}. Response: {response.content}"
        )
        raise Exception(error_message)

    return response.json()
50
+
51
def _string_to_list(env_value) -> List[str]:
    """Split a comma-separated setting into lower-cased, whitespace-stripped items.

    Empty items (from ``"a,,b"`` or a trailing comma) are dropped. A missing
    value (``None`` or ``""``) yields an empty list instead of raising.
    """
    if not env_value:
        return []
    return [item.strip().lower() for item in env_value.split(",") if item.strip()]
53
+
54
def create_model_from_spec(spec: dict) -> VirtualModel:
    """Build a VirtualModel from a plain mapping such as parsed JSON/YAML content.

    Validation is delegated to ``VirtualModel.model_validate``; any validation
    error it raises propagates to the caller.
    """
    validated_model = VirtualModel.model_validate(spec)
    return validated_model
56
+
57
def create_policy_from_spec(spec: dict) -> ModelPolicy:
    """Build a ModelPolicy from a plain mapping such as parsed JSON/YAML content.

    Validation is delegated to ``ModelPolicy.model_validate``; any validation
    error it raises propagates to the caller.
    """
    validated_policy = ModelPolicy.model_validate(spec)
    return validated_policy
59
+
60
def _import_module_members(file: str, member_type: type) -> list:
    """Import the Python file *file* and return its module-level members of *member_type*.

    The file's directory is on ``sys.path`` only for the duration of the import
    (so the file can import sibling modules). It is removed again even if the
    import raises.

    NOTE(review): ``importlib.import_module`` is keyed on the file stem. A file whose
    stem matches an already-imported module (e.g. ``json.py``) returns that cached
    module instead of loading the file — confirm this is acceptable for callers.
    """
    file_path = Path(file)
    search_dir = str(file_path.parent)
    sys.path.append(search_dir)
    try:
        module = importlib.import_module(file_path.stem)
    finally:
        # Remove the entry we appended. Search from the end instead of blindly
        # popping the last element, because the imported module may itself have
        # appended to sys.path.
        for index in range(len(sys.path) - 1, -1, -1):
            if sys.path[index] == search_dir:
                del sys.path[index]
                break
    return [obj for _, obj in inspect.getmembers(module) if isinstance(obj, member_type)]


def import_python_model(file: str) -> List[VirtualModel]:
    """Return every ``VirtualModel`` instance defined at module level in the Python file *file*."""
    return _import_module_members(file, VirtualModel)


def import_python_policy(file: str) -> List[ModelPolicy]:
    """Return every ``ModelPolicy`` instance defined at module level in the Python file *file*."""
    return _import_module_members(file, ModelPolicy)
87
+
88
def parse_model_file(file: str) -> List[VirtualModel]:
    """Load virtual-model definitions from *file*.

    ``.json``, ``.yaml`` and ``.yml`` files hold a single model spec and yield a
    one-element list. ``.py`` files are imported, and every module-level
    ``VirtualModel`` instance is returned.

    Raises:
        ValueError: if *file* has none of the supported extensions.
    """
    if file.endswith((".yaml", ".yml", ".json")):
        # JSON (RFC 8259) and YAML are Unicode formats normally stored as UTF-8.
        # Read them as such instead of relying on the platform's locale encoding.
        with open(file, "r", encoding="utf-8") as f:
            if file.endswith(".json"):
                content = json.load(f)
            else:
                content = yaml.safe_load(f)
        return [create_model_from_spec(spec=content)]
    if file.endswith(".py"):
        return import_python_model(file)
    raise ValueError("file must end in .json, .yaml, .yml or .py")
102
+
103
def parse_policy_file(file: str) -> List[ModelPolicy]:
    """Load model-policy definitions from *file*.

    ``.json``, ``.yaml`` and ``.yml`` files hold a single policy spec and yield a
    one-element list. ``.py`` files are imported, and every module-level
    ``ModelPolicy`` instance is returned.

    Raises:
        ValueError: if *file* has none of the supported extensions.
    """
    if file.endswith((".yaml", ".yml", ".json")):
        # JSON (RFC 8259) and YAML are Unicode formats normally stored as UTF-8.
        # Read them as such instead of relying on the platform's locale encoding.
        with open(file, "r", encoding="utf-8") as f:
            if file.endswith(".json"):
                content = json.load(f)
            else:
                content = yaml.safe_load(f)
        return [create_policy_from_spec(spec=content)]
    if file.endswith(".py"):
        return import_python_policy(file)
    raise ValueError("file must end in .json, .yaml, .yml or .py")
117
+
118
def extract_model_names_from_policy_inner(policy_inner: ModelPolicyInner) -> List[str]:
    """Collect the model names referenced by a policy node, depth-first in target order.

    A ``ModelPolicyTarget`` contributes its ``model_name``. A nested
    ``ModelPolicyInner`` contributes the names collected from it recursively.
    Targets of any other type are ignored.
    """
    collected: List[str] = []
    for node in policy_inner.targets:
        if isinstance(node, ModelPolicyTarget):
            collected.append(node.model_name)
        elif isinstance(node, ModelPolicyInner):
            collected.extend(extract_model_names_from_policy_inner(node))
    return collected
126
+
127
def get_model_names_from_policy(policy: ModelPolicy) -> List[str]:
    """Return every model name referenced anywhere in *policy*'s target tree."""
    root_node = policy.policy
    return extract_model_names_from_policy_inner(policy_inner=root_node)
129
+
130
class ModelsController:
    """CLI-facing operations for virtual models and virtual model policies.

    API clients are created lazily on first use and cached on the instance.
    On user errors (ambiguous or unknown names, missing WATSONX_URL) several
    methods log an error and terminate the process with ``sys.exit(1)``.
    """

    def __init__(self):
        self.models_client = None
        self.model_policies_client = None

    def get_models_client(self) -> ModelsClient:
        """Return the cached ModelsClient, instantiating it on first call."""
        if not self.models_client:
            self.models_client = instantiate_client(ModelsClient)
        return self.models_client

    def get_model_policies_client(self) -> ModelPoliciesClient:
        """Return the cached ModelPoliciesClient, instantiating it on first call."""
        if not self.model_policies_client:
            self.model_policies_client = instantiate_client(ModelPoliciesClient)
        return self.model_policies_client

    @staticmethod
    def _provider_from_name(name: str) -> str:
        """Return the first '/'-separated segment of *name* that is not a virtual-* prefix.

        e.g. ``virtual-model/openai/gpt-4o`` -> ``openai``.

        Raises:
            ValueError: if *name* consists only of ``virtual-model``/``virtual-policy``
                segments (previously a bare ``StopIteration`` leaked out).
        """
        provider = next(
            (part for part in name.split('/') if part not in ('virtual-policy', 'virtual-model')),
            None,
        )
        if provider is None:
            raise ValueError(f"Unable to determine the model provider from the name '{name}'")
        return provider

    @staticmethod
    def _filter_and_sort_wx_models(models: list, preferred: List[str], incompatible: List[str]) -> list:
        """Drop incompatible watsonx.ai models and order the rest preferred-first, then by id.

        A model is incompatible when any entry of *incompatible* is a substring of its
        lower-cased ``model_id`` or ``short_description``. It is preferred when any
        entry of *preferred* is a substring of its lower-cased ``model_id``.
        """
        def is_incompatible(model: dict) -> bool:
            model_id = model.get("model_id", "").lower()
            short_desc = model.get("short_description", "").lower()
            return any(incomp in model_id or incomp in short_desc for incomp in incompatible)

        def sort_key(model: dict):
            model_id = model.get("model_id", "").lower()
            is_preferred = any(pref in model_id for pref in preferred)
            return (0 if is_preferred else 1, model_id)

        return sorted((m for m in models if not is_incompatible(m)), key=sort_key)

    def list_models(self, print_raw: bool = False) -> None:
        """Print custom (virtual) models/policies followed by the available watsonx.ai models.

        Reads ``WATSONX_URL``, ``PREFERRED_MODELS`` and ``INCOMPATIBLE_MODELS`` from the
        merged default env file. Exits with status 1 if ``WATSONX_URL`` is not configured.

        Args:
            print_raw: print one model per line instead of a rich table.
        """
        global WATSONX_URL
        # ``import rich`` alone does not guarantee these submodules are loaded,
        # so import what is used here explicitly.
        from rich.console import Console
        from rich.table import Table
        from rich.theme import Theme

        models_client: ModelsClient = self.get_models_client()
        model_policies_client: ModelPoliciesClient = self.get_model_policies_client()
        merged_env_dict = merge_env(get_default_env_file(), None)

        watsonx_url = merged_env_dict.get("WATSONX_URL")
        if not watsonx_url:
            logger.error("Error: WATSONX_URL is required in the environment.")
            sys.exit(1)
        # _get_wxai_foundational_models() reads the module-level WATSONX_URL.
        WATSONX_URL = watsonx_url

        logger.info("Retrieving virtual-model models list...")
        virtual_models = models_client.list()

        logger.info("Retrieving virtual-policies models list...")
        virtual_model_policies = model_policies_client.list()

        logger.info("Retrieving watsonx.ai models list...")
        found_models = _get_wxai_foundational_models()

        # `or ''` guards against keys that are present but unset (None).
        preferred_list = _string_to_list(merged_env_dict.get('PREFERRED_MODELS') or '')
        incompatible_list = _string_to_list(merged_env_dict.get('INCOMPATIBLE_MODELS') or '')

        custom_models = list(virtual_models or []) + list(virtual_model_policies or [])
        wx_models = found_models.get("resources", [])
        if not wx_models:
            if not custom_models:
                logger.error("No models found.")
                return
            # Still show the custom models instead of hiding everything.
            logger.warning("No watsonx.ai models found.")

        sorted_models = self._filter_and_sort_wx_models(wx_models, preferred_list, incompatible_list)

        def is_preferred(model_id: str) -> bool:
            return any(pref in model_id.lower() for pref in preferred_list)

        if print_raw:
            theme = Theme({"model.name": "bold cyan"})
            console = Console(highlighter=ModelHighlighter(), theme=theme)
            console.print("[bold]Available Models:[/bold]")

            for model in custom_models:
                console.print(f"- ✨️ {model.name}:", model.description or 'No description provided.')

            for model in sorted_models:
                model_id = model.get("model_id", "N/A")
                short_desc = model.get("short_description", "No description provided.")
                full_model_name = f"watsonx/{model_id}: {short_desc}"
                marker = "★ " if is_preferred(model_id) else ""
                console.print(f"- [yellow]{marker}[/yellow]{full_model_name}")

            console.print("[yellow]★[/yellow] [italic dim]indicates a supported and preferred model[/italic dim]\n[blue dim]✨️[/blue dim] [italic dim]indicates a model from a custom provider[/italic dim]" )
            return

        table = Table(
            show_header=True,
            title="[bold]Available Models[/bold]",
            caption="[yellow]★ [/yellow] indicates a supported and preferred model from watsonx\n[blue]✨️[/blue] indicates a model from a custom provider",
            show_lines=True)
        for col in ["Model", "Description"]:
            table.add_column(col)

        for model in custom_models:
            table.add_row(f"✨️ {model.name}", model.description or 'No description provided.')

        for model in sorted_models:
            model_id = model.get("model_id", "N/A")
            short_desc = model.get("short_description", "No description provided.")
            marker = "★ " if is_preferred(model_id) else ""
            table.add_row(f"[yellow]{marker}[/yellow]watsonx/{model_id}", short_desc)

        rich.print(table)

    def import_model(self, file: str, app_id: str | None) -> List[VirtualModel]:
        """Parse model definitions from *file* and normalise them for publishing.

        For each model:
        - prefixes the name with ``virtual-model/`` if needed;
        - sets ``provider_config.provider`` from the name;
        - defaults ``max_tokens`` for Anthropic models;
        - when *app_id* is given, attaches that connection's id.
        Each resulting provider config is then validated.
        """
        from ibm_watsonx_orchestrate.cli.commands.models.model_provider_mapper import validate_ProviderConfig  # lazily import this because the lut building is expensive
        models = parse_model_file(file)

        for model in models:
            if not model.name.startswith('virtual-model/'):
                model.name = f"virtual-model/{model.name}"

            provider = self._provider_from_name(model.name)
            if not model.provider_config:
                model.provider_config = ProviderConfig.model_validate({"provider": provider})
            else:
                model.provider_config.provider = provider

            # Anthropic has no default for max_tokens
            if "anthropic" in model.name:
                if not model.config:
                    model.config = {}
                if "max_tokens" not in model.config:
                    model.config["max_tokens"] = ANTHROPIC_DEFAULT_MAX_TOKENS

            if app_id:
                model.connection_id = get_connection_id(app_id, supported_schemas={ConnectionType.KEY_VALUE})
            validate_ProviderConfig(model.provider_config, app_id=app_id)
        return models

    def create_model(self, name: str, display_name: str | None = None, description: str | None = None, provider_config_dict: dict | None = None, model_type: ModelType = ModelType.CHAT, app_id: str | None = None) -> VirtualModel:
        """Build (but do not publish) a VirtualModel from CLI arguments.

        The provider is derived from *name*. The provider config is built from
        *provider_config_dict* if given, otherwise from the provider alone, and is
        validated. The name is prefixed with ``virtual-model/`` if needed.
        """
        from ibm_watsonx_orchestrate.cli.commands.models.model_provider_mapper import validate_ProviderConfig  # lazily import this because the lut building is expensive

        provider = self._provider_from_name(name)

        if provider_config_dict:
            provider_config = ProviderConfig.model_validate(provider_config_dict)
            provider_config.provider = provider
        else:
            provider_config = ProviderConfig.model_validate({"provider": provider})
        validate_ProviderConfig(provider_config, app_id=app_id)

        if not name.startswith('virtual-model/'):
            name = f"virtual-model/{name}"

        # Anthropic has no default for max_tokens
        config = {"max_tokens": ANTHROPIC_DEFAULT_MAX_TOKENS} if "anthropic" in name else None

        return VirtualModel(
            name=name,
            display_name=display_name,
            description=description,
            tags=[],
            provider_config=provider_config,
            config=config,
            model_type=model_type,
            connection_id=get_connection_id(app_id, supported_schemas={ConnectionType.KEY_VALUE})
        )

    def publish_or_update_models(self, model: VirtualModel) -> None:
        """Create *model*, or update the single existing draft with the same name.

        Exits with status 1 if more than one draft shares the name.
        """
        existing_models = self.get_models_client().get_draft_by_name(model.name)
        if len(existing_models) > 1:
            logger.error(f"Multiple models with the name '{model.name}' found. Failed to update model")
            sys.exit(1)

        if len(existing_models) == 1:
            self.update_model(model_id=existing_models[0].id, model=model)
        else:
            self.publish_model(model=model)

    def publish_model(self, model: VirtualModel) -> None:
        """Create *model* on the server."""
        self.get_models_client().create(model)
        logger.info(f"Successfully added the model '{model.name}'")

    def update_model(self, model_id: str, model: VirtualModel) -> None:
        """Overwrite the model *model_id* with *model*."""
        logger.info(f"Existing model '{model.name}' found. Updating...")
        self.get_models_client().update(model_id, model)
        logger.info(f"Model '{model.name}' updated successfully")

    def remove_model(self, name: str) -> None:
        """Delete the single draft model called *name*; exit 1 if none or several match."""
        models_client: ModelsClient = self.get_models_client()
        existing_models = models_client.get_draft_by_name(name)

        if len(existing_models) > 1:
            logger.error(f"Multiple models with the name '{name}' found. Failed to remove model")
            sys.exit(1)
        if len(existing_models) == 0:
            logger.error(f"No model found with the name '{name}'")
            sys.exit(1)

        models_client.delete(model_id=existing_models[0].id)
        logger.info(f"Successfully removed the model '{name}'")

    def _existing_model_names(self) -> set:
        """Return the names of all models listed by the models service."""
        return {m.name for m in self.get_models_client().list()}

    @staticmethod
    def _exit_if_unknown_models(model_names: List[str], known_names: set) -> None:
        """Log and exit with status 1 on the first name in *model_names* not in *known_names*."""
        for m in model_names:
            if m not in known_names:
                logger.error(f"No model found with the name '{m}'")
                sys.exit(1)

    def import_model_policy(self, file: str) -> List[ModelPolicy]:
        """Parse policies from *file*, check their target models exist, and prefix their names.

        Names are prefixed with ``virtual-policy/`` when missing. Exits with status 1 if
        any referenced model is unknown.
        """
        policies = parse_policy_file(file)
        known_names = self._existing_model_names()

        for policy in policies:
            self._exit_if_unknown_models(get_model_names_from_policy(policy), known_names)

            if not policy.name.startswith('virtual-policy/'):
                policy.name = f"virtual-policy/{policy.name}"

        return policies

    def create_model_policy(
        self,
        name: str,
        models: List[str],
        strategy: ModelPolicyStrategyMode,
        strategy_on_code: List[int],
        retry_on_code: List[int],
        retry_attempts: int,
        display_name: str | None = None,
        description: str | None = None
    ) -> ModelPolicy:
        """Build (but do not publish) a ModelPolicy routing across *models*.

        Every target gets weight 1. A retry block is added only when *retry_on_code*
        is non-empty. *display_name* and *description* default to the
        (``virtual-policy/``-prefixed) name. Exits with status 1 if any model in
        *models* is unknown.
        """
        self._exit_if_unknown_models(models, self._existing_model_names())

        if not name.startswith('virtual-policy/'):
            name = f"virtual-policy/{name}"

        inner = ModelPolicyInner()
        inner.strategy = ModelPolicyStrategy(
            mode=strategy,
            on_status_codes=strategy_on_code
        )
        inner.targets = [ModelPolicyTarget(weight=1, model_name=m) for m in models]
        if retry_on_code:
            inner.retry = ModelPolicyRetry(
                on_status_codes=retry_on_code,
                attempts=retry_attempts
            )

        return ModelPolicy(
            name=name,
            display_name=display_name or name,
            description=description or name,
            policy=inner
        )

    def publish_or_update_model_policies(self, policy: ModelPolicy) -> None:
        """Create *policy*, or update the single existing draft with the same name.

        Exits with status 1 if more than one draft shares the name.
        """
        existing_policies = self.get_model_policies_client().get_draft_by_name(policy.name)
        if len(existing_policies) > 1:
            logger.error(f"Multiple model policies with the name '{policy.name}' found. Failed to update model policy")
            sys.exit(1)

        if len(existing_policies) == 1:
            self.update_policy(policy_id=existing_policies[0].id, policy=policy)
        else:
            self.publish_policy(policy=policy)

    def publish_policy(self, policy: ModelPolicy) -> None:
        """Create *policy* on the server."""
        self.get_model_policies_client().create(policy)
        logger.info(f"Successfully added the model policy '{policy.name}'")

    def update_policy(self, policy_id: str, policy: ModelPolicy) -> None:
        """Overwrite the model policy *policy_id* with *policy*."""
        logger.info(f"Existing model policy '{policy.name}' found. Updating...")
        self.get_model_policies_client().update(policy_id, policy)
        logger.info(f"Model policy '{policy.name}' updated successfully")

    def remove_policy(self, name: str) -> None:
        """Delete the single draft model policy called *name*; exit 1 if none or several match."""
        model_policies_client: ModelPoliciesClient = self.get_model_policies_client()
        existing_model_policies = model_policies_client.get_draft_by_name(name)

        if len(existing_model_policies) > 1:
            logger.error(f"Multiple model policies with the name '{name}' found. Failed to remove model policy")
            sys.exit(1)
        if len(existing_model_policies) == 0:
            logger.error(f"No model policy found with the name '{name}'")
            sys.exit(1)

        model_policies_client.delete(model_policy_id=existing_model_policies[0].id)
        logger.info(f"Successfully removed the policy '{name}'")
@@ -22,6 +22,7 @@ from ibm_watsonx_orchestrate.cli.commands.environment.environment_controller imp
22
22
 
23
23
  from ibm_watsonx_orchestrate.cli.config import LICENSE_HEADER, \
24
24
  ENV_ACCEPT_LICENSE
25
+
25
26
  from ibm_watsonx_orchestrate.cli.config import PROTECTED_ENV_NAME, clear_protected_env_credentials_token, Config, \
26
27
  AUTH_CONFIG_FILE_FOLDER, AUTH_CONFIG_FILE, AUTH_MCSP_TOKEN_OPT, AUTH_SECTION_HEADER, USER_ENV_CACHE_HEADER, LICENSE_HEADER, \
27
28
  ENV_ACCEPT_LICENSE
@@ -816,4 +817,4 @@ def run_db_migration() -> None:
816
817
  sys.exit(1)
817
818
 
818
819
  if __name__ == "__main__":
819
- server_app()
820
+ server_app()
@@ -25,7 +25,7 @@ from rich.panel import Panel
25
25
 
26
26
  from ibm_watsonx_orchestrate.agent_builder.tools import BaseTool, ToolSpec
27
27
  from ibm_watsonx_orchestrate.agent_builder.tools.openapi_tool import create_openapi_json_tools_from_uri,create_openapi_json_tools_from_content
28
- from ibm_watsonx_orchestrate.cli.commands.models.models_command import ModelHighlighter
28
+ from ibm_watsonx_orchestrate.cli.commands.models.models_controller import ModelHighlighter
29
29
  from ibm_watsonx_orchestrate.cli.commands.tools.types import RegistryType
30
30
  from ibm_watsonx_orchestrate.cli.commands.connections.connections_controller import configure_connection, remove_connection, add_connection
31
31
  from ibm_watsonx_orchestrate.agent_builder.connections.types import ConnectionType, ConnectionEnvironment, ConnectionPreference
@@ -6,5 +6,6 @@ from .connections_client import (
6
6
 
7
7
  from .utils import (
8
8
  get_connections_client,
9
- get_connection_type
9
+ get_connection_type,
10
+ get_connection_id
10
11
  )
@@ -1,8 +1,11 @@
1
+ import logging
1
2
  from ibm_watsonx_orchestrate.client.utils import instantiate_client, is_local_dev
2
3
  from ibm_watsonx_orchestrate.client.connections.connections_client import ConnectionsClient
3
4
  from ibm_watsonx_orchestrate.cli.config import Config, ENVIRONMENTS_SECTION_HEADER, CONTEXT_SECTION_HEADER, CONTEXT_ACTIVE_ENV_OPT, ENV_WXO_URL_OPT
4
5
  from ibm_watsonx_orchestrate.agent_builder.connections.types import ConnectionType, ConnectionAuthType, ConnectionSecurityScheme
5
6
 
7
+ logger = logging.getLogger(__name__)
8
+
6
9
  LOCAL_CONNECTION_MANAGER_PORT = 3001
7
10
 
8
11
  def _get_connections_manager_url() -> str:
@@ -25,3 +28,30 @@ def get_connection_type(security_scheme: ConnectionSecurityScheme, auth_type: Co
25
28
  if security_scheme != ConnectionSecurityScheme.OAUTH2:
26
29
  return ConnectionType(security_scheme)
27
30
  return ConnectionType(auth_type)
31
+
32
+ def get_connection_id(app_id: str, supported_schemas: set | None = None) -> str:
33
+ if not app_id:
34
+ return
35
+
36
+ connections_client = get_connections_client()
37
+
38
+ connection_id = None
39
+ if app_id is not None:
40
+ connection = connections_client.get(app_id=app_id)
41
+ if not connection:
42
+ logger.error(f"No connection exists with the app-id '{app_id}'")
43
+ exit(1)
44
+ connection_id = connection.connection_id
45
+
46
+ existing_draft_configuration = None
47
+ existing_live_configuration = None
48
+
49
+ if supported_schemas:
50
+ existing_draft_configuration = connections_client.get_config(app_id=app_id, env='draft')
51
+ existing_live_configuration = connections_client.get_config(app_id=app_id, env='live')
52
+ for config in [existing_draft_configuration, existing_live_configuration]:
53
+ if config and config.security_scheme not in supported_schemas:
54
+ logger.error(f"Only {', '.join(supported_schemas)} credentials are currently supported. Provided connection '{config.app_id}' is of type '{config.security_scheme}' in environment '{config.environment}'")
55
+ exit(1)
56
+
57
+ return connection_id
@@ -5,15 +5,12 @@ from ibm_watsonx_orchestrate.client.base_api_client import BaseAPIClient, Client
5
5
 
6
6
  import logging
7
7
 
8
- from ibm_watsonx_orchestrate.client.model_policies.types import ModelPolicy
9
- from ibm_watsonx_orchestrate.client.models.types import ListVirtualModel, CreateVirtualModel
8
+ from ibm_watsonx_orchestrate.agent_builder.model_policies.types import ModelPolicy
10
9
 
11
10
  logger = logging.getLogger(__name__)
12
11
 
13
12
 
14
13
 
15
-
16
-
17
14
  class ModelPoliciesClient(BaseAPIClient):
18
15
  """
19
16
  Client to handle CRUD operations for ModelPolicies endpoint
@@ -21,6 +18,10 @@ class ModelPoliciesClient(BaseAPIClient):
21
18
  # POST api/v1/model_policy
22
19
  def create(self, model: ModelPolicy) -> None:
23
20
  self._post("/model_policy", data=model.model_dump())
21
+
22
+ # PUT api/v1/model_policy/{model_policy_id}
23
def update(self, model_policy_id: str, model: ModelPolicy) -> None:
    """Replace policy *model_policy_id* via ``PUT /model_policy/{model_policy_id}``."""
    endpoint = f"/model_policy/{model_policy_id}"
    payload = model.model_dump()
    self._put(endpoint, data=payload)
24
25
 
25
26
  # DELETE api/v1/model_policy/{model_policy_id}
26
27
  def delete(self, model_policy_id: str) -> dict:
@@ -43,5 +44,23 @@ class ModelPoliciesClient(BaseAPIClient):
43
44
  if e.response.status_code == 404:
44
45
  return []
45
46
  raise e
47
+
48
def get_drafts_by_names(self, policy_names: List[str]) -> List[ModelPolicy]:
    """Return the draft model policies whose names are in *policy_names*.

    Returns an empty list when *policy_names* is empty or the server answers 404.

    Raises:
        ValidationError: if the response cannot be parsed into ModelPolicy objects.
        ClientAPIException: for any other API error.
    """
    from urllib.parse import urlencode

    # With no names the query would carry no filter at all (presumably listing
    # every policy), which is never what a by-name lookup wants.
    if not policy_names:
        return []
    # Percent-encode each name. Policy names contain '/' and may contain '&', '#',
    # '+' or spaces, which would otherwise corrupt the query string.
    query = urlencode([("name", x) for x in policy_names])
    try:
        res = self._get(f"/model_policy?{query}")
        return [ModelPolicy.model_validate(policy) for policy in res]
    except ValidationError as e:
        logger.error("Received unexpected response from server")
        raise e
    except ClientAPIException as e:
        if e.response.status_code == 404:
            return []
        raise e
60
+
61
def get_draft_by_name(self, policy_name: str) -> List[ModelPolicy]:
    """Return the draft policies named *policy_name*.

    The result is a list (the previous ``-> ModelPolicy`` annotation was wrong).
    It may be empty or hold more than one entry, so callers must check its length.
    """
    return self.get_drafts_by_names([policy_name])
63
+
64
+
46
65
 
47
66
 
@@ -5,7 +5,7 @@ from ibm_watsonx_orchestrate.client.base_api_client import BaseAPIClient, Client
5
5
 
6
6
  import logging
7
7
 
8
- from ibm_watsonx_orchestrate.client.models.types import ListVirtualModel, CreateVirtualModel
8
+ from ibm_watsonx_orchestrate.agent_builder.models.types import ListVirtualModel, VirtualModel
9
9
 
10
10
  logger = logging.getLogger(__name__)
11
11
 
@@ -18,8 +18,12 @@ class ModelsClient(BaseAPIClient):
18
18
  Client to handle CRUD operations for Models endpoint
19
19
  """
20
20
  # POST api/v1/models
21
- def create(self, model: CreateVirtualModel) -> None:
22
- self._post("/models", data=model.model_dump())
21
def create(self, model: VirtualModel) -> None:
    """Create *model* via ``POST /models``; fields whose value is None are left out of the payload."""
    payload = model.model_dump(exclude_none=True)
    self._post("/models", data=payload)
23
+
24
+ # PUT api/v1/models/{models_id}
25
def update(self, model_id: str, model: VirtualModel) -> None:
    """Replace model *model_id* via ``PUT /models/{model_id}``; None-valued fields are omitted."""
    endpoint = f"/models/{model_id}"
    payload = model.model_dump(exclude_none=True)
    self._put(endpoint, data=payload)
23
27
 
24
28
  # DELETE api/v1/models/{model_id}
25
29
  def delete(self, model_id: str) -> dict:
@@ -42,5 +46,21 @@ class ModelsClient(BaseAPIClient):
42
46
  if e.response.status_code == 404:
43
47
  return []
44
48
  raise e
49
+
50
def get_drafts_by_names(self, model_names: List[str]) -> List[ListVirtualModel]:
    """Return the draft virtual models whose names are in *model_names*.

    Returns an empty list when *model_names* is empty or the server answers 404.

    Raises:
        ValidationError: if the response cannot be parsed into ListVirtualModel objects.
        ClientAPIException: for any other API error.
    """
    from urllib.parse import urlencode

    # With no names the query would carry no filter at all (presumably listing
    # every model), which is never what a by-name lookup wants.
    if not model_names:
        return []
    # Percent-encode each name. Model names contain '/' and may contain '&', '#',
    # '+' or spaces, which would otherwise corrupt the query string.
    query = urlencode([("name", x) for x in model_names])
    try:
        res = self._get(f"/models?{query}")
        return [ListVirtualModel.model_validate(model) for model in res]
    except ValidationError as e:
        logger.error("Received unexpected response from server")
        raise e
    except ClientAPIException as e:
        if e.response.status_code == 404:
            return []
        raise e
62
+
63
def get_draft_by_name(self, model_name: str) -> List[ListVirtualModel]:
    """Return the draft models named *model_name*.

    The result is a list (the previous ``-> ListVirtualModel`` annotation was wrong).
    It may be empty or hold more than one entry, so callers must check its length.
    """
    return self.get_drafts_by_names([model_name])
45
65
 
46
66
 
@@ -1,3 +1,4 @@
1
+ from typing import Literal
1
2
  from ibm_watsonx_orchestrate.client.base_api_client import BaseAPIClient, ClientAPIException
2
3
  from typing_extensions import List
3
4
 
@@ -32,7 +33,7 @@ class ToolClient(BaseAPIClient):
32
33
  formatted_tool_names = [f"names={x}" for x in tool_names]
33
34
  return self._get(f"/tools?{'&'.join(formatted_tool_names)}")
34
35
 
35
- def get_draft_by_id(self, tool_id: str) -> List[dict]:
36
+ def get_draft_by_id(self, tool_id: str) -> dict | Literal[""]:
36
37
  if tool_id is None:
37
38
  return ""
38
39
  else:
@@ -323,6 +323,7 @@ services:
323
323
  AI_GATEWAY_BASE_URL: ${AI_GATEWAY_BASE_URL}
324
324
  AI_GATEWAY_ENABLED : ${AI_GATEWAY_ENABLED}
325
325
 
326
+
326
327
  wxo-server-worker:
327
328
  image: ${WORKER_REGISTRY:-us.icr.io/watson-orchestrate-private}/wxo-server-conversation_controller:${WORKER_TAG:-latest}
328
329
  platform: linux/amd64
@@ -450,6 +451,7 @@ services:
450
451
  STORAGE_S3_REGION: us-east-1
451
452
  AWS_ACCESS_KEY_ID: ${MINIO_ROOT_USER:-minioadmin}
452
453
  AWS_SECRET_ACCESS_KEY: ${MINIO_ROOT_PASSWORD:-watsonxorchestrate}
454
+ TEMPUS_HOST_NAME: http://wxo-tempus-runtime:9044
453
455
  extra_hosts:
454
456
  - "host.docker.internal:host-gateway"
455
457