ibm-watsonx-orchestrate 1.5.0b0__py3-none-any.whl → 1.5.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29) hide show
  1. ibm_watsonx_orchestrate/__init__.py +1 -2
  2. ibm_watsonx_orchestrate/agent_builder/agents/types.py +10 -1
  3. ibm_watsonx_orchestrate/agent_builder/knowledge_bases/types.py +13 -0
  4. ibm_watsonx_orchestrate/agent_builder/model_policies/__init__.py +1 -0
  5. ibm_watsonx_orchestrate/{client → agent_builder}/model_policies/types.py +7 -8
  6. ibm_watsonx_orchestrate/agent_builder/models/__init__.py +1 -0
  7. ibm_watsonx_orchestrate/{client → agent_builder}/models/types.py +43 -7
  8. ibm_watsonx_orchestrate/agent_builder/tools/python_tool.py +46 -3
  9. ibm_watsonx_orchestrate/agent_builder/tools/types.py +47 -1
  10. ibm_watsonx_orchestrate/cli/commands/agents/agents_command.py +17 -0
  11. ibm_watsonx_orchestrate/cli/commands/agents/agents_controller.py +86 -39
  12. ibm_watsonx_orchestrate/cli/commands/models/model_provider_mapper.py +191 -0
  13. ibm_watsonx_orchestrate/cli/commands/models/models_command.py +136 -259
  14. ibm_watsonx_orchestrate/cli/commands/models/models_controller.py +437 -0
  15. ibm_watsonx_orchestrate/cli/commands/server/server_command.py +2 -1
  16. ibm_watsonx_orchestrate/cli/commands/tools/tools_controller.py +1 -1
  17. ibm_watsonx_orchestrate/client/connections/__init__.py +2 -1
  18. ibm_watsonx_orchestrate/client/connections/utils.py +30 -0
  19. ibm_watsonx_orchestrate/client/model_policies/model_policies_client.py +23 -4
  20. ibm_watsonx_orchestrate/client/models/models_client.py +23 -3
  21. ibm_watsonx_orchestrate/client/tools/tool_client.py +2 -1
  22. ibm_watsonx_orchestrate/docker/compose-lite.yml +2 -0
  23. ibm_watsonx_orchestrate/docker/default.env +10 -11
  24. {ibm_watsonx_orchestrate-1.5.0b0.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/METADATA +1 -1
  25. {ibm_watsonx_orchestrate-1.5.0b0.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/RECORD +28 -25
  26. ibm_watsonx_orchestrate/cli/commands/models/env_file_model_provider_mapper.py +0 -180
  27. {ibm_watsonx_orchestrate-1.5.0b0.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/WHEEL +0 -0
  28. {ibm_watsonx_orchestrate-1.5.0b0.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/entry_points.txt +0 -0
  29. {ibm_watsonx_orchestrate-1.5.0b0.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,191 @@
1
+ # https://portkey.ai/
2
+ from dotenv import dotenv_values
3
+ import sys
4
+ import logging
5
+ from ibm_watsonx_orchestrate.agent_builder.models.types import ProviderConfig, ModelProvider
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
# Configuration keys accepted for every provider, independent of any
# provider-specific extras. Frozen so this shared baseline cannot be mutated
# by accident; per-provider sets are built from it without modifying it.
_BASIC_PROVIDER_CONFIG_KEYS = frozenset({
    'provider',
    'api_key',
    'custom_host',
    'url_to_fetch',
    'forward_headers',
    'request_timeout',
    'transform_to_form_data',
})
10
+
11
# Provider-specific configuration keys, keyed by ModelProvider member.
# Only the uncommented providers have extras; the commented entries below are
# kept for reference and are not active.
PROVIDER_EXTRA_PROPERTIES_LUT = {
    ModelProvider.ANTHROPIC: {
        'anthropic_beta',
        'anthropic_version',
    },
    # ModelProvider.AZURE_AI: {
    #     'azure_resource_name',
    #     'azure_deployment_id',
    #     'azure_api_version',
    #     'ad_auth',
    #     'azure_auth_mode',
    #     'azure_managed_client_id',
    #     'azure_entra_client_id',
    #     'azure_entra_client_secret',
    #     'azure_entra_tenant_id',
    #     'azure_ad_token',
    #     'azure_model_name'
    # },
    # ModelProvider.BEDROCK: {
    #     'aws_secret_access_key',
    #     'aws_access_key_id',
    #     'aws_session_token',
    #     'aws_region',
    #     'aws_auth_type',
    #     'aws_role_arn',
    #     'aws_external_id',
    #     'aws_s3_bucket',
    #     'aws_s3_object_key',
    #     'aws_bedrock_model',
    #     'aws_server_side_encryption',
    #     'aws_server_side_encryption_kms_key_id'
    # },
    # ModelProvider.VERTEX_AI: {
    #     'vertex_region',
    #     'vertex_project_id',
    #     'vertex_service_account_json',
    #     'vertex_storage_bucket_name',
    #     'vertex_model_name',
    #     'filename'
    # },
    # ModelProvider.HUGGINGFACE: {'huggingfaceBaseUrl'},
    ModelProvider.MISTRAL_AI: {
        'mistral_fim_completion',
    },
    # ModelProvider.STABILITY_AI: {'stability_client_id', 'stability_client_user_id', 'stability_client_version'},
    ModelProvider.WATSONX: {
        'watsonx_version',
        'watsonx_space_id',
        'watsonx_project_id',
        'api_key',
        'watsonx_deployment_id',
        'watsonx_cpd_url',
        'watsonx_cpd_username',
        'watsonx_cpd_password',
    },

    # 'palm': _bpp('PALM', {}),
    # 'nomic': _bpp('NOMIC', {}),
    # 'perplexity-ai': _bpp('PERPLEXITY_AI', {}),
    # 'segmind': _bpp('SEGMIND', {}),
    # 'deepinfra': _bpp('DEEPINFRA', {}),
    # 'novita-ai': _bpp('NOVITA_AI', {}),
    # 'fireworks-ai': _bpp('FIREWORKS',{
    #     'FIREWORKS_ACCOUNT_ID': 'fireworks_account_id'
    # }),
    # 'deepseek': _bpp('DEEPSEEK', {}),
    # 'voyage': _bpp('VOYAGE', {}),
    # 'moonshot': _bpp('MOONSHOT', {}),
    # 'lingyi': _bpp('LINGYI', {}),
    # 'zhipu': _bpp('ZHIPU', {}),
    # 'monsterapi': _bpp('MONSTERAPI', {}),
    # 'predibase': _bpp('PREDIBASE', {}),

    # 'github': _bpp('GITHUB', {}),
    # 'deepbricks': _bpp('DEEPBRICKS', {}),
    # 'sagemaker': _bpp('AMZN_SAGEMAKER', {
    #     'AMZN_SAGEMAKER_CUSTOM_ATTRIBUTES': 'amzn_sagemaker_custom_attributes',
    #     'AMZN_SAGEMAKER_TARGET_MODEL': 'amzn_sagemaker_target_model',
    #     'AMZN_SAGEMAKER_TARGET_VARIANT': 'amzn_sagemaker_target_variant',
    #     'AMZN_SAGEMAKER_TARGET_CONTAINER_HOSTNAME': 'amzn_sagemaker_target_container_hostname',
    #     'AMZN_SAGEMAKER_INFERENCE_ID': 'amzn_sagemaker_inference_id',
    #     'AMZN_SAGEMAKER_ENABLE_EXPLANATIONS': 'amzn_sagemaker_enable_explanations',
    #     'AMZN_SAGEMAKER_INFERENCE_COMPONENT': 'amzn_sagemaker_inference_component',
    #     'AMZN_SAGEMAKER_SESSION_ID': 'amzn_sagemaker_session_id',
    #     'AMZN_SAGEMAKER_MODEL_NAME': 'amzn_sagemaker_model_name'
    # }),
    # '@cf': _bpp('WORKERS_AI', {  # workers ai
    #     'WORKERS_AI_ACCOUNT_ID': 'workers_ai_account_id'
    # }),
    # 'snowflake': _bpp('SNOWFLAKE', {  # no provider prefix found
    #     'SNOWFLAKE_ACCOUNT': 'snowflake_account'
    # })
}
90
+
91
# Required configuration per provider. Every provider starts with 'api_key'.
# A plain string entry is mandatory; a set entry denotes an 'or' requirement
# (at least one member of the set must be provided).
PROVIDER_REQUIRED_FIELDS = {provider: ['api_key'] for provider in ModelProvider}

# watsonx additionally needs one of: space id, project id or deployment id.
PROVIDER_REQUIRED_FIELDS[ModelProvider.WATSONX] = PROVIDER_REQUIRED_FIELDS[ModelProvider.WATSONX] + [
    {'watsonx_space_id', 'watsonx_project_id', 'watsonx_deployment_id'},
]
# Ollama additionally needs the host it is served from.
PROVIDER_REQUIRED_FIELDS[ModelProvider.OLLAMA] = PROVIDER_REQUIRED_FIELDS[ModelProvider.OLLAMA] + [
    'custom_host',
]
98
+
99
+ # def env_file_to_model_ProviderConfig(model_name: str, env_file_path: str) -> ProviderConfig | None:
100
+ # provider = next(filter(lambda x: x not in ('virtual-policy', 'virtual-model'), model_name.split('/')))
101
+ # if provider not in ModelProvider:
102
+ # logger.error(f"Unsupported model provider {provider}")
103
+ # sys.exit(1)
104
+
105
+ # values = dotenv_values(str(env_file_path))
106
+
107
+ # if values is None:
108
+ # logger.error(f"No provider configuration in env file {env_file_path}")
109
+ # sys.exit(1)
110
+
111
+ # cfg = ProviderConfig()
112
+ # cfg.provider = PROVIDER_LUT[provider]
113
+
114
+ # cred_lut = PROVIDER_PROPERTIES_LUT[provider]
115
+
116
+
117
+ # consumed_credentials = []
118
+ # # Ollama requires some apikey but its content don't matter
119
+ # # Default it to 'ollama' to avoid users needing to specify
120
+ # if cfg.provider == ModelProvider.OLLAMA:
121
+ # consumed_credentials.append('api_key')
122
+ # setattr(cfg, 'api_key', ModelProvider.OLLAMA)
123
+
124
+ # for key, value in values.items():
125
+ # if key in cred_lut:
126
+ # k = cred_lut[key]
127
+ # consumed_credentials.append(k)
128
+ # setattr(cfg, k, value)
129
+
130
+ # return cfg
131
+
132
def _validate_provider(provider: str | ModelProvider) -> None:
    """Abort the process when ``provider`` is not a supported model provider.

    Args:
        provider: Provider value to check against ``ModelProvider.has_value``.

    Raises:
        SystemExit: With exit code 1, after logging an error, when the
            provider is not recognised.
    """
    if ModelProvider.has_value(provider):
        return

    logger.error(f"Unsupported model provider {provider}")
    sys.exit(1)
136
+
137
def _validate_extra_fields(provider: ModelProvider, cfg: ProviderConfig) -> None:
    """Warn about config options that are set but not used by ``provider``.

    The accepted options are the basic provider keys plus any extras
    registered for the provider in ``PROVIDER_EXTRA_PROPERTIES_LUT``. This is
    advisory only: offending options are logged as warnings and nothing is
    raised.

    Args:
        provider: Provider whose accepted options are looked up.
        cfg: Provider configuration whose instance attributes are inspected.
    """
    provider_extras = PROVIDER_EXTRA_PROPERTIES_LUT.get(provider)
    allowed = set(_BASIC_PROVIDER_CONFIG_KEYS)
    if provider_extras:
        allowed |= set(provider_extras)

    for name, value in vars(cfg).items():
        # Skip dunder attributes; only real config options are checked.
        if name.startswith("__"):
            continue
        if value is None or name in allowed:
            continue
        logger.warning(f"The config option '{name}' is not used by provider '{provider}'")
149
+
150
def _validate_requirements(provider: ModelProvider, cfg: ProviderConfig, app_id: str | None = None) -> None:
    """Check that ``cfg`` supplies every field required by ``provider``.

    Requirements come from ``PROVIDER_REQUIRED_FIELDS``. A string entry must be
    present with a non-None value; a set entry is satisfied when at least one
    of its members is present.

    Args:
        provider: Provider whose requirements are checked.
        cfg: Provider configuration; read through ``dict(cfg)``.
        app_id: Optional connection app id. When given, missing values are
            only reported at info level, with a reminder to set them in that
            connection.

    Raises:
        SystemExit: With exit code 1 when requirements are missing and no
            ``app_id`` was supplied.
    """
    provided = {key for key, value in dict(cfg).items() if value is not None}

    missing = []
    for requirement in PROVIDER_REQUIRED_FIELDS[provider]:
        if isinstance(requirement, set):
            # 'or' requirement: satisfied by any single member.
            if provided.isdisjoint(requirement):
                missing.append(requirement)
        elif requirement not in provided:
            missing.append(requirement)

    if not missing:
        return

    if app_id:
        header = f"The following configuration variable(s) for the provider {provider} are not in the spec provider_config:"
    else:
        header = f"Missing configuration variable(s) required for the provider {provider}:"

    parts = [header]
    for requirement in missing:
        if isinstance(requirement, set):
            # Sort the alternatives so the message is stable between runs
            # (set iteration order of strings is not deterministic).
            label = ' or '.join(sorted(requirement))
        else:
            label = requirement
        parts.append(f"\n\t - {label}")
    message = ''.join(parts)

    if app_id:
        # The values may live in the connection, so this is informational only.
        logger.info(message)
        logger.info(f"Please ensure these values are set in the connection '{app_id}'.")
        return

    logger.error(message)
    logger.error("Please provide the above values in the provider config. For secret values (e.g. 'api_key') create a key_value connection `orchestrate connections add` then bind it to the model with `--app-id`")
    sys.exit(1)
181
+
182
+
183
def validate_ProviderConfig(cfg: ProviderConfig, app_id: str | None = None) -> None:
    """Run all provider-config validations for ``cfg``.

    Does nothing when ``cfg`` is falsy. Otherwise validates, in order: that the
    provider is supported, that no unused options are set (warnings only), and
    that the provider's required fields are present.

    Args:
        cfg: Provider configuration to validate; its ``provider`` attribute
            selects the rules applied.
        app_id: Optional connection app id, forwarded to the requirements
            check. ``None`` means no connection is bound.

    Raises:
        SystemExit: Propagated from the provider and requirements checks when
            validation fails.
    """
    if not cfg:
        return

    provider = cfg.provider

    _validate_provider(provider)
    _validate_extra_fields(provider=provider, cfg=cfg)
    _validate_requirements(provider=provider, cfg=cfg, app_id=app_id)
@@ -1,32 +1,20 @@
1
1
  import logging
2
- import os
3
- import sys
4
2
  from typing import List
3
+ import json
4
+ import sys
5
5
 
6
- import requests
7
- import rich
8
- import rich.highlighter
9
6
  import typer
10
7
  from typing_extensions import Annotated
11
8
 
12
- from ibm_watsonx_orchestrate.cli.commands.server.server_command import get_default_env_file, merge_env
13
- from ibm_watsonx_orchestrate.client.model_policies.model_policies_client import ModelPoliciesClient
14
- from ibm_watsonx_orchestrate.client.model_policies.types import ModelPolicy, ModelPolicyInner, \
15
- ModelPolicyRetry, ModelPolicyStrategy, ModelPolicyStrategyMode, ModelPolicyTarget
16
- from ibm_watsonx_orchestrate.client.models.models_client import ModelsClient
17
- from ibm_watsonx_orchestrate.client.models.types import CreateVirtualModel, ModelType, ANTHROPIC_DEFAULT_MAX_TOKENS
18
- from ibm_watsonx_orchestrate.client.utils import instantiate_client
9
+ from ibm_watsonx_orchestrate.agent_builder.models.types import ModelType
10
+ from ibm_watsonx_orchestrate.agent_builder.model_policies.types import ModelPolicyStrategyMode
11
+ from ibm_watsonx_orchestrate.cli.commands.models.models_controller import ModelsController
12
+
19
13
 
20
14
  logger = logging.getLogger(__name__)
21
15
  models_app = typer.Typer(no_args_is_help=True)
22
16
  models_policy_app = typer.Typer(no_args_is_help=True)
23
- # models_app.add_typer(models_policy_app, name='policy', help='Add or remove pseudo models which route traffic between multiple downstream models')
24
-
25
- WATSONX_URL = os.getenv("WATSONX_URL")
26
-
27
- class ModelHighlighter(rich.highlighter.RegexHighlighter):
28
- base_style = "model."
29
- highlights = [r"(?P<name>(watsonx|virtual[-]model|virtual[-]policy)\/.+\/.+):"]
17
+ models_app.add_typer(models_policy_app, name='policy', help='Add or remove pseudo models which route traffic between multiple downstream models')
30
18
 
31
19
  @models_app.command(name="list", help="List available models")
32
20
  def model_list(
@@ -35,108 +23,33 @@ def model_list(
35
23
  typer.Option("--raw", "-r", help="Display the list of models in a non-tabular format"),
36
24
  ] = False,
37
25
  ):
38
- models_client: ModelsClient = instantiate_client(ModelsClient)
39
- model_policies_client: ModelPoliciesClient = instantiate_client(ModelPoliciesClient)
40
- global WATSONX_URL
41
- default_env_path = get_default_env_file()
42
- merged_env_dict = merge_env(
43
- default_env_path,
44
- None
45
- )
46
-
47
- if 'WATSONX_URL' in merged_env_dict and merged_env_dict['WATSONX_URL']:
48
- WATSONX_URL = merged_env_dict['WATSONX_URL']
49
-
50
- watsonx_url = merged_env_dict.get("WATSONX_URL")
51
- if not watsonx_url:
52
- logger.error("Error: WATSONX_URL is required in the environment.")
53
- sys.exit(1)
54
-
55
- logger.info("Retrieving virtual-model models list...")
56
- virtual_models = models_client.list()
57
-
58
-
59
-
60
- logger.info("Retrieving virtual-policies models list...")
61
- virtual_model_policies = model_policies_client.list()
62
-
63
- logger.info("Retrieving watsonx.ai models list...")
64
- found_models = _get_wxai_foundational_models()
65
-
66
-
67
- preferred_str = merged_env_dict.get('PREFERRED_MODELS', '')
68
- incompatible_str = merged_env_dict.get('INCOMPATIBLE_MODELS', '')
69
-
70
- preferred_list = _string_to_list(preferred_str)
71
- incompatible_list = _string_to_list(incompatible_str)
72
-
73
- models = found_models.get("resources", [])
74
- if not models:
75
- logger.error("No models found.")
76
- else:
77
- # Remove incompatible models
78
- filtered_models = []
79
- for model in models:
80
- model_id = model.get("model_id", "")
81
- short_desc = model.get("short_description", "")
82
- if any(incomp in model_id.lower() for incomp in incompatible_list):
83
- continue
84
- if any(incomp in short_desc.lower() for incomp in incompatible_list):
85
- continue
86
- filtered_models.append(model)
87
-
88
- # Sort to put the preferred first
89
- def sort_key(model):
90
- model_id = model.get("model_id", "").lower()
91
- is_preferred = any(pref in model_id for pref in preferred_list)
92
- return (0 if is_preferred else 1, model_id)
93
-
94
- sorted_models = sorted(filtered_models, key=sort_key)
95
-
96
- if print_raw:
97
- theme = rich.theme.Theme({"model.name": "bold cyan"})
98
- console = rich.console.Console(highlighter=ModelHighlighter(), theme=theme)
99
- console.print("[bold]Available Models:[/bold]")
100
-
101
- for model in virtual_models:
102
- console.print(f"- ✨️ {model.name}:", model.description or 'No description provided.')
103
-
104
- for model in virtual_model_policies:
105
- console.print(f"- ✨️ {model.name}:", 'No description provided.')
106
-
107
- for model in sorted_models:
108
- model_id = model.get("model_id", "N/A")
109
- short_desc = model.get("short_description", "No description provided.")
110
- full_model_name = f"watsonx/{model_id}: {short_desc}"
111
- marker = "★ " if any(pref in model_id.lower() for pref in preferred_list) else ""
112
- console.print(f"- [yellow]{marker}[/yellow]{full_model_name}")
113
-
114
- console.print("[yellow]★[/yellow] [italic dim]indicates a supported and preferred model[/italic dim]\n[blue dim]✨️[/blue dim] [italic dim]indicates a model from a custom provider[/italic dim]" )
115
- else:
116
- table = rich.table.Table(
117
- show_header=True,
118
- title="[bold]Available Models[/bold]",
119
- caption="[yellow]★ [/yellow] indicates a supported and preferred model from watsonx\n[blue]✨️[/blue] indicates a model from a custom provider",
120
- show_lines=True)
121
- columns = ["Model", "Description"]
122
- for col in columns:
123
- table.add_column(col)
124
-
125
- for model in virtual_models:
126
- table.add_row(f"✨️ {model.name}", model.description or 'No description provided.')
127
-
128
- for model in virtual_model_policies:
129
- table.add_row(f"✨️ {model.name}", 'No description provided.')
130
-
131
- for model in sorted_models:
132
- model_id = model.get("model_id", "N/A")
133
- short_desc = model.get("short_description", "No description provided.")
134
- marker = "★ " if any(pref in model_id.lower() for pref in preferred_list) else ""
135
- table.add_row(f"[yellow]{marker}[/yellow]watsonx/{model_id}", short_desc)
136
-
137
- rich.print(table)
138
-
26
+ models_controller = ModelsController()
27
+ models_controller.list_models(print_raw=print_raw)
139
28
 
29
# CLI: `models import` - load model definitions from a spec file and publish
# (or update) each one through the ModelsController.
@models_app.command(name="import", help="Import models from spec file")
def models_import(
    file: Annotated[
        str,
        typer.Option(
            "--file",
            "-f",
            help="Path to spec file containing model details.",
        ),
    ],
    app_id: Annotated[
        str, typer.Option(
            '--app-id', '-a',
            help='The app id of a key_value connection containing authentications details for the model provider.'
        )
    ] = None,
):
    controller = ModelsController()
    imported_models = controller.import_model(file=file, app_id=app_id)
    # Each model found in the spec is published individually.
    for imported in imported_models:
        controller.publish_or_update_models(model=imported)
140
53
 
141
54
  @models_app.command(name="add", help="Add an llm from a custom provider")
142
55
  def models_add(
@@ -144,10 +57,6 @@ def models_add(
144
57
  str,
145
58
  typer.Option("--name", "-n", help="The name of the model to add"),
146
59
  ],
147
- env_file: Annotated[
148
- str,
149
- typer.Option('--env-file', '-e', help='The path to an .env file containing the credentials for your llm provider'),
150
- ],
151
60
  description: Annotated[
152
61
  str,
153
62
  typer.Option('--description', '-d', help='The description of the model to add'),
@@ -156,38 +65,42 @@ def models_add(
156
65
  str,
157
66
  typer.Option('--display-name', help='What name should this llm appear as within the ui'),
158
67
  ] = None,
68
+ provider_config: Annotated[
69
+ str,
70
+ typer.Option(
71
+ "--provider-config",
72
+ help="LLM provider configuration in JSON format (e.g., '{\"customHost\": \"xyz\"}')",
73
+ ),
74
+ ] = None,
75
+ app_id: Annotated[
76
+ str, typer.Option(
77
+ '--app-id', '-a',
78
+ help='The app id of a key_value connection containing authentications details for the model provider.'
79
+ )
80
+ ] = None,
159
81
  type: Annotated[
160
82
  ModelType,
161
83
  typer.Option('--type', help='What type of model is it'),
162
84
  ] = ModelType.CHAT,
163
-
164
85
  ):
165
- from ibm_watsonx_orchestrate.cli.commands.models.env_file_model_provider_mapper import env_file_to_model_ProviderConfig # lazily import this because the lut building is expensive
166
-
167
- models_client: ModelsClient = instantiate_client(ModelsClient)
168
- provider_config = env_file_to_model_ProviderConfig(model_name=name, env_file_path=env_file)
169
- if not name.startswith('virtual-model/'):
170
- name = f"virtual-model/{name}"
171
-
172
- config=None
173
- # Anthropic has no default for max_tokens
174
- if "anthropic" in name:
175
- config = {
176
- "max_tokens": ANTHROPIC_DEFAULT_MAX_TOKENS
177
- }
178
-
179
- model = CreateVirtualModel(
86
+ provider_config_dict = {}
87
+ if provider_config:
88
+ try:
89
+ provider_config_dict = json.loads(provider_config)
90
+ except:
91
+ logger.error(f"Failed to parse provider config. '{provider_config}' is not valid json")
92
+ sys.exit(1)
93
+
94
+ models_controller = ModelsController()
95
+ model = models_controller.create_model(
180
96
  name=name,
181
- display_name=display_name or name,
182
97
  description=description,
183
- tags=[],
184
- provider_config=provider_config,
185
- config=config,
186
- model_type=type
98
+ display_name=display_name,
99
+ provider_config_dict = provider_config_dict,
100
+ model_type=type,
101
+ app_id=app_id,
187
102
  )
188
-
189
- models_client.create(model)
190
- logger.info(f"Successfully added the model '{name}'")
103
+ models_controller.publish_or_update_models(model=model)
191
104
 
192
105
 
193
106
 
@@ -198,122 +111,86 @@ def models_remove(
198
111
  typer.Option("--name", "-n", help="The name of the model to remove"),
199
112
  ]
200
113
  ):
201
- models_client: ModelsClient = instantiate_client(ModelsClient)
202
- models = models_client.list()
203
- model = next(filter(lambda x: x.name == name or x.name == f"virtual-model/{name}", models), None)
204
- if not model:
205
- logger.error(f"No model found with the name '{name}'")
206
- sys.exit(1)
207
-
208
- models_client.delete(model_id=model.id)
209
- logger.info(f"Successfully removed the model '{name}'")
210
-
211
-
212
- # @models_policy_app.command(name='add', help='Add a model policy')
213
- # def models_policy_add(
214
- # name: Annotated[
215
- # str,
216
- # typer.Option("--name", "-n", help="The name of the model to remove"),
217
- # ],
218
- # models: Annotated[
219
- # List[str],
220
- # typer.Option('--model', '-m', help='The name of the model to add'),
221
- # ],
222
- # strategy: Annotated[
223
- # ModelPolicyStrategyMode,
224
- # typer.Option('--strategy', '-s', help='How to spread traffic across models'),
225
- # ],
226
- # strategy_on_code: Annotated[
227
- # List[int],
228
- # typer.Option('--strategy-on-code', help='The http status to consider invoking the strategy'),
229
- # ],
230
- # retry_on_code: Annotated[
231
- # List[int],
232
- # typer.Option('--retry-on-code', help='The http status to consider retrying the llm call'),
233
- # ],
234
- # retry_attempts: Annotated[
235
- # int,
236
- # typer.Option('--retry-attempts', help='The number of attempts to retry'),
237
- # ],
238
- # display_name: Annotated[
239
- # str,
240
- # typer.Option('--display-name', help='What name should this llm appear as within the ui'),
241
- # ] = None
242
- # ):
243
- # model_policies_client: ModelPoliciesClient = instantiate_client(ModelPoliciesClient)
244
- # model_client: ModelsClient = instantiate_client(ModelsClient)
245
- # model_lut = {m.name: m.id for m in model_client.list()}
246
- # for m in models:
247
- # if m not in model_lut:
248
- # logger.error(f"No model found with the name '{m}'")
249
- # exit(1)
250
-
251
- # inner = ModelPolicyInner()
252
- # inner.strategy = ModelPolicyStrategy(
253
- # mode=strategy,
254
- # on_status_codes=strategy_on_code
255
- # )
256
- # inner.targets = [ModelPolicyTarget(model_id=model_lut[m], weight=1) for m in models]
257
- # if retry_on_code:
258
- # inner.retry = ModelPolicyRetry(
259
- # on_status_codes=retry_on_code,
260
- # attempts=retry_attempts
261
- # )
262
-
263
- # if not display_name:
264
- # display_name = name
265
-
266
-
267
- # policy = ModelPolicy(
268
- # name=name,
269
- # display_name=display_name,
270
- # policy=inner
271
- # )
272
- # model_policies_client.create(policy)
273
- # logger.info(f"Successfully added the model policy '{name}'")
274
-
275
-
276
-
277
- # @models_policy_app.command(name='remove', help='Remove a model policy')
278
- # def models_policy_remove(
279
- # name: Annotated[
280
- # str,
281
- # typer.Option("--name", "-n", help="The name of the model policy to remove"),
282
- # ]
283
- # ):
284
- # model_policies_client: ModelPoliciesClient = instantiate_client(ModelPoliciesClient)
285
- # model_policies = model_policies_client.list()
286
-
287
- # policy = next(filter(lambda x: x.name == name or x.name == f"virtual-policy/{name}", model_policies), None)
288
- # if not policy:
289
- # logger.error(f"No model found with the name '{name}'")
290
- # exit(1)
291
-
292
- # model_policies_client.delete(model_policy_id=policy.id)
293
- # logger.info(f"Successfully removed the model '{name}'")
114
+ models_controller = ModelsController()
115
+ models_controller.remove_model(name=name)
294
116
 
117
# CLI: `models policy import` - load model policies from a spec file and
# publish (or update) each one through the ModelsController.
# The help texts describe the import of policies; they previously repeated the
# wording of the `add` command and of the model import option.
@models_policy_app.command(name='import', help='Import model policies from spec file')
def models_policy_import(
    file: Annotated[
        str,
        typer.Option(
            "--file",
            "-f",
            help="Path to spec file containing model policy details.",
        ),
    ],
):
    models_controller = ModelsController()
    policies = models_controller.import_model_policy(
        file=file
    )
    # Each policy found in the spec is published individually.
    for policy in policies:
        models_controller.publish_or_update_model_policies(policy=policy)
295
134
 
296
- def _get_wxai_foundational_models():
297
- foundation_models_url = WATSONX_URL + "/ml/v1/foundation_model_specs?version=2024-05-01"
135
# CLI: `models policy add` - build a model policy from command-line options
# and publish (or update) it through the ModelsController.
# The `--name` help now describes the policy being added; it previously said
# "The name of the model to remove".
@models_policy_app.command(name='add', help='Add a model policy')
def models_policy_add(
    name: Annotated[
        str,
        typer.Option("--name", "-n", help="The name of the model policy to add"),
    ],
    models: Annotated[
        List[str],
        typer.Option('--model', '-m', help='The name of the model to add'),
    ],
    strategy: Annotated[
        ModelPolicyStrategyMode,
        typer.Option('--strategy', '-s', help='How to spread traffic across models'),
    ],
    strategy_on_code: Annotated[
        List[int],
        typer.Option('--strategy-on-code', help='The http status to consider invoking the strategy'),
    ],
    retry_on_code: Annotated[
        List[int],
        typer.Option('--retry-on-code', help='The http status to consider retrying the llm call'),
    ],
    retry_attempts: Annotated[
        int,
        typer.Option('--retry-attempts', help='The number of attempts to retry'),
    ],
    display_name: Annotated[
        str,
        typer.Option('--display-name', help='What name should this llm appear as within the ui'),
    ] = None,
    description: Annotated[
        str,
        typer.Option('--description', help='Description of the policy for display in the ui'),
    ] = None
):
    models_controller = ModelsController()
    # All options are forwarded unchanged; the controller builds the policy.
    policy = models_controller.create_model_policy(
        name=name,
        models=models,
        strategy=strategy,
        strategy_on_code=strategy_on_code,
        retry_on_code=retry_on_code,
        retry_attempts=retry_attempts,
        display_name=display_name,
        description=description
    )
    models_controller.publish_or_update_model_policies(policy=policy)
298
182
 
299
- try:
300
- response = requests.get(foundation_models_url)
301
- except requests.exceptions.RequestException as e:
302
- logger.exception(f"Exception when connecting to Watsonx URL: {foundation_models_url}")
303
- raise
304
183
 
305
- if response.status_code != 200:
306
- error_message = (
307
- f"Failed to retrieve foundational models from {foundation_models_url}. "
308
- f"Status code: {response.status_code}. Response: {response.content}"
309
- )
310
- raise Exception(error_message)
311
-
312
- json_response = response.json()
313
- return json_response
314
184
 
315
- def _string_to_list(env_value):
316
- return [item.strip().lower() for item in env_value.split(",") if item.strip()]
185
# CLI: `models policy remove` - delete the named model policy via the
# ModelsController.
@models_policy_app.command(name='remove', help='Remove a model policy')
def models_policy_remove(
    name: Annotated[
        str,
        typer.Option("--name", "-n", help="The name of the model policy to remove"),
    ]
):
    controller = ModelsController()
    controller.remove_policy(name=name)
317
194
 
318
195
  if __name__ == "__main__":
319
196
  models_app()