ibm-watsonx-orchestrate 1.5.0b0__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ibm_watsonx_orchestrate/__init__.py +1 -2
- ibm_watsonx_orchestrate/agent_builder/agents/types.py +10 -1
- ibm_watsonx_orchestrate/agent_builder/knowledge_bases/types.py +13 -0
- ibm_watsonx_orchestrate/agent_builder/model_policies/__init__.py +1 -0
- ibm_watsonx_orchestrate/{client → agent_builder}/model_policies/types.py +7 -8
- ibm_watsonx_orchestrate/agent_builder/models/__init__.py +1 -0
- ibm_watsonx_orchestrate/{client → agent_builder}/models/types.py +43 -7
- ibm_watsonx_orchestrate/agent_builder/tools/python_tool.py +46 -3
- ibm_watsonx_orchestrate/agent_builder/tools/types.py +47 -1
- ibm_watsonx_orchestrate/cli/commands/agents/agents_command.py +17 -0
- ibm_watsonx_orchestrate/cli/commands/agents/agents_controller.py +86 -39
- ibm_watsonx_orchestrate/cli/commands/models/model_provider_mapper.py +191 -0
- ibm_watsonx_orchestrate/cli/commands/models/models_command.py +136 -259
- ibm_watsonx_orchestrate/cli/commands/models/models_controller.py +437 -0
- ibm_watsonx_orchestrate/cli/commands/server/server_command.py +2 -1
- ibm_watsonx_orchestrate/cli/commands/tools/tools_controller.py +1 -1
- ibm_watsonx_orchestrate/client/connections/__init__.py +2 -1
- ibm_watsonx_orchestrate/client/connections/utils.py +30 -0
- ibm_watsonx_orchestrate/client/model_policies/model_policies_client.py +23 -4
- ibm_watsonx_orchestrate/client/models/models_client.py +23 -3
- ibm_watsonx_orchestrate/client/tools/tool_client.py +2 -1
- ibm_watsonx_orchestrate/docker/compose-lite.yml +2 -0
- ibm_watsonx_orchestrate/docker/default.env +10 -11
- {ibm_watsonx_orchestrate-1.5.0b0.dist-info → ibm_watsonx_orchestrate-1.5.1.dist-info}/METADATA +1 -1
- {ibm_watsonx_orchestrate-1.5.0b0.dist-info → ibm_watsonx_orchestrate-1.5.1.dist-info}/RECORD +28 -25
- ibm_watsonx_orchestrate/cli/commands/models/env_file_model_provider_mapper.py +0 -180
- {ibm_watsonx_orchestrate-1.5.0b0.dist-info → ibm_watsonx_orchestrate-1.5.1.dist-info}/WHEEL +0 -0
- {ibm_watsonx_orchestrate-1.5.0b0.dist-info → ibm_watsonx_orchestrate-1.5.1.dist-info}/entry_points.txt +0 -0
- {ibm_watsonx_orchestrate-1.5.0b0.dist-info → ibm_watsonx_orchestrate-1.5.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,191 @@
|
|
1
|
+
# https://portkey.ai/
|
2
|
+
from dotenv import dotenv_values
|
3
|
+
import sys
|
4
|
+
import logging
|
5
|
+
from ibm_watsonx_orchestrate.agent_builder.models.types import ProviderConfig, ModelProvider
|
6
|
+
|
7
|
+
logger = logging.getLogger(__name__)

# Config keys that every provider accepts.
_BASIC_PROVIDER_CONFIG_KEYS = {
    'provider',
    'api_key',
    'custom_host',
    'url_to_fetch',
    'forward_headers',
    'request_timeout',
    'transform_to_form_data',
}

# Provider-specific keys accepted in addition to _BASIC_PROVIDER_CONFIG_KEYS.
# A provider that is absent from this table accepts only the basic keys.
#
# Candidate keys for providers that are not wired in here (kept for reference only;
# they are NOT used for validation):
#   AZURE_AI:     azure_resource_name, azure_deployment_id, azure_api_version, ad_auth,
#                 azure_auth_mode, azure_managed_client_id, azure_entra_client_id,
#                 azure_entra_client_secret, azure_entra_tenant_id, azure_ad_token,
#                 azure_model_name
#   BEDROCK:      aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region,
#                 aws_auth_type, aws_role_arn, aws_external_id, aws_s3_bucket,
#                 aws_s3_object_key, aws_bedrock_model, aws_server_side_encryption,
#                 aws_server_side_encryption_kms_key_id
#   VERTEX_AI:    vertex_region, vertex_project_id, vertex_service_account_json,
#                 vertex_storage_bucket_name, vertex_model_name, filename
#   HUGGINGFACE:  huggingfaceBaseUrl
#   STABILITY_AI: stability_client_id, stability_client_user_id, stability_client_version
#   fireworks-ai: fireworks_account_id
#   sagemaker:    amzn_sagemaker_custom_attributes, amzn_sagemaker_target_model,
#                 amzn_sagemaker_target_variant, amzn_sagemaker_target_container_hostname,
#                 amzn_sagemaker_inference_id, amzn_sagemaker_enable_explanations,
#                 amzn_sagemaker_inference_component, amzn_sagemaker_session_id,
#                 amzn_sagemaker_model_name
#   @cf (workers ai): workers_ai_account_id
#   snowflake:    snowflake_account
#   palm, nomic, perplexity-ai, segmind, deepinfra, novita-ai, deepseek, voyage,
#   moonshot, lingyi, zhipu, monsterapi, predibase, github, deepbricks: no extra keys
PROVIDER_EXTRA_PROPERTIES_LUT = {
    ModelProvider.ANTHROPIC: {'anthropic_beta', 'anthropic_version'},
    ModelProvider.MISTRAL_AI: {'mistral_fim_completion'},
    ModelProvider.WATSONX: {
        'watsonx_version',
        'watsonx_space_id',
        'watsonx_project_id',
        'api_key',
        'watsonx_deployment_id',
        'watsonx_cpd_url',
        'watsonx_cpd_username',
        'watsonx_cpd_password',
    },
}

# Required config fields per provider. Each list entry is either a field name that
# must be set, or a set of field names of which at least one must be set ("or").
# Every provider needs 'api_key'; watsonx and ollama need more.
PROVIDER_REQUIRED_FIELDS = {provider: ['api_key'] for provider in ModelProvider}
PROVIDER_REQUIRED_FIELDS[ModelProvider.WATSONX] = [
    'api_key',
    {'watsonx_space_id', 'watsonx_project_id', 'watsonx_deployment_id'},
]
PROVIDER_REQUIRED_FIELDS[ModelProvider.OLLAMA] = ['api_key', 'custom_host']
|
98
|
+
|
99
|
+
# def env_file_to_model_ProviderConfig(model_name: str, env_file_path: str) -> ProviderConfig | None:
|
100
|
+
# provider = next(filter(lambda x: x not in ('virtual-policy', 'virtual-model'), model_name.split('/')))
|
101
|
+
# if provider not in ModelProvider:
|
102
|
+
# logger.error(f"Unsupported model provider {provider}")
|
103
|
+
# sys.exit(1)
|
104
|
+
|
105
|
+
# values = dotenv_values(str(env_file_path))
|
106
|
+
|
107
|
+
# if values is None:
|
108
|
+
# logger.error(f"No provider configuration in env file {env_file_path}")
|
109
|
+
# sys.exit(1)
|
110
|
+
|
111
|
+
# cfg = ProviderConfig()
|
112
|
+
# cfg.provider = PROVIDER_LUT[provider]
|
113
|
+
|
114
|
+
# cred_lut = PROVIDER_PROPERTIES_LUT[provider]
|
115
|
+
|
116
|
+
|
117
|
+
# consumed_credentials = []
|
118
|
+
# # Ollama requires some apikey but its content don't matter
|
119
|
+
# # Default it to 'ollama' to avoid users needing to specify
|
120
|
+
# if cfg.provider == ModelProvider.OLLAMA:
|
121
|
+
# consumed_credentials.append('api_key')
|
122
|
+
# setattr(cfg, 'api_key', ModelProvider.OLLAMA)
|
123
|
+
|
124
|
+
# for key, value in values.items():
|
125
|
+
# if key in cred_lut:
|
126
|
+
# k = cred_lut[key]
|
127
|
+
# consumed_credentials.append(k)
|
128
|
+
# setattr(cfg, k, value)
|
129
|
+
|
130
|
+
# return cfg
|
131
|
+
|
132
|
+
def _validate_provider(provider: str | ModelProvider) -> None:
    """Abort the CLI (exit status 1) if *provider* is not a known ModelProvider value.

    Args:
        provider: Provider name or enum member, checked with ``ModelProvider.has_value``.
    """
    if ModelProvider.has_value(provider):
        return
    logger.error(f"Unsupported model provider {provider}")
    sys.exit(1)
|
136
|
+
|
137
|
+
def _validate_extra_fields(provider: ModelProvider, cfg: ProviderConfig) -> None:
    """Log a warning for each option set on *cfg* that *provider* does not use.

    An option counts as used when it is one of the basic keys shared by every
    provider, or one of *provider*'s extra keys from PROVIDER_EXTRA_PROPERTIES_LUT.
    Options whose value is ``None`` and dunder attribute names are ignored.
    This function only warns; it never exits.
    """
    accepted = _BASIC_PROVIDER_CONFIG_KEYS | (PROVIDER_EXTRA_PROPERTIES_LUT.get(provider) or set())

    for option_name, option_value in cfg.__dict__.items():
        # Dunder entries are object internals, not config options.
        if option_name.startswith("__"):
            continue
        if option_value is None or option_name in accepted:
            continue
        logger.warning(f"The config option '{option_name}' is not used by provider '{provider}'")
|
149
|
+
|
150
|
+
def _validate_requirements(provider: ModelProvider, cfg: ProviderConfig, app_id: str | None = None) -> None:
    """Check that *cfg* sets every field that *provider* requires.

    Requirements come from ``PROVIDER_REQUIRED_FIELDS[provider]``. A string entry
    must be set (non-``None``) on *cfg*. A set entry is satisfied when at least one
    of its members is set.

    Args:
        provider: The provider whose requirements are checked.
        cfg: The provider configuration; its non-``None`` fields count as provided.
        app_id: Optional key_value connection app id. When given, missing fields are
            only reported at info level and the user is asked to make sure the
            connection sets them.

    Exits the process with status 1 when fields are missing and no *app_id* is given.
    """
    provided_credentials = {key for key, value in dict(cfg).items() if value is not None}

    missing_credentials = []
    for requirement in PROVIDER_REQUIRED_FIELDS[provider]:
        if isinstance(requirement, (set, frozenset)):
            # "Any of" requirement: a single provided alternative is enough.
            satisfied = not provided_credentials.isdisjoint(requirement)
        else:
            satisfied = requirement in provided_credentials
        if not satisfied:
            missing_credentials.append(requirement)

    if not missing_credentials:
        return

    if app_id:
        header = f"The following configuration variable(s) for the provider {provider} are not in the spec provider_config:"
    else:
        header = f"Missing configuration variable(s) required for the provider {provider}:"

    message_lines = [header]
    for requirement in missing_credentials:
        if isinstance(requirement, (set, frozenset)):
            # Sort the alternatives. Iterating a set of strings follows hash order,
            # which is randomized per process, so the message would differ between runs.
            requirement_str = ' or '.join(sorted(requirement))
        else:
            requirement_str = requirement
        message_lines.append(f"\t - {requirement_str}")
    missing_credentials_string = "\n".join(message_lines)

    if app_id:
        logger.info(missing_credentials_string)
        logger.info(f"Please ensure these values are set in the connection '{app_id}'.")
    else:
        logger.error(missing_credentials_string)
        logger.error("Please provide the above values in the provider config. For secret values (e.g. 'api_key') create a key_value connection `orchestrate connections add` then bind it to the model with `--app-id`")
        sys.exit(1)
|
181
|
+
|
182
|
+
|
183
|
+
def validate_ProviderConfig(cfg: ProviderConfig, app_id: str) -> None:
    """Validate *cfg* against the rules for its own provider.

    Does nothing when *cfg* is falsy. Otherwise it runs three checks in order:
    1. Exits the process if the provider is unsupported.
    2. Warns about options the provider does not use.
    3. Checks the provider's required fields. Missing fields cause an exit when
       no *app_id* is given; with an *app_id* they are only reported.
    """
    if cfg:
        provider = cfg.provider
        _validate_provider(provider)
        _validate_extra_fields(provider=provider, cfg=cfg)
        _validate_requirements(provider=provider, cfg=cfg, app_id=app_id)
|
@@ -1,32 +1,20 @@
|
|
1
1
|
import logging
|
2
|
-
import os
|
3
|
-
import sys
|
4
2
|
from typing import List
|
3
|
+
import json
|
4
|
+
import sys
|
5
5
|
|
6
|
-
import requests
|
7
|
-
import rich
|
8
|
-
import rich.highlighter
|
9
6
|
import typer
|
10
7
|
from typing_extensions import Annotated
|
11
8
|
|
12
|
-
from ibm_watsonx_orchestrate.
|
13
|
-
from ibm_watsonx_orchestrate.
|
14
|
-
from ibm_watsonx_orchestrate.
|
15
|
-
|
16
|
-
from ibm_watsonx_orchestrate.client.models.models_client import ModelsClient
|
17
|
-
from ibm_watsonx_orchestrate.client.models.types import CreateVirtualModel, ModelType, ANTHROPIC_DEFAULT_MAX_TOKENS
|
18
|
-
from ibm_watsonx_orchestrate.client.utils import instantiate_client
|
9
|
+
from ibm_watsonx_orchestrate.agent_builder.models.types import ModelType
|
10
|
+
from ibm_watsonx_orchestrate.agent_builder.model_policies.types import ModelPolicyStrategyMode
|
11
|
+
from ibm_watsonx_orchestrate.cli.commands.models.models_controller import ModelsController
|
12
|
+
|
19
13
|
|
20
14
|
logger = logging.getLogger(__name__)
|
21
15
|
models_app = typer.Typer(no_args_is_help=True)
|
22
16
|
models_policy_app = typer.Typer(no_args_is_help=True)
|
23
|
-
|
24
|
-
|
25
|
-
WATSONX_URL = os.getenv("WATSONX_URL")
|
26
|
-
|
27
|
-
class ModelHighlighter(rich.highlighter.RegexHighlighter):
|
28
|
-
base_style = "model."
|
29
|
-
highlights = [r"(?P<name>(watsonx|virtual[-]model|virtual[-]policy)\/.+\/.+):"]
|
17
|
+
models_app.add_typer(models_policy_app, name='policy', help='Add or remove pseudo models which route traffic between multiple downstream models')
|
30
18
|
|
31
19
|
@models_app.command(name="list", help="List available models")
|
32
20
|
def model_list(
|
@@ -35,108 +23,33 @@ def model_list(
|
|
35
23
|
typer.Option("--raw", "-r", help="Display the list of models in a non-tabular format"),
|
36
24
|
] = False,
|
37
25
|
):
|
38
|
-
|
39
|
-
|
40
|
-
global WATSONX_URL
|
41
|
-
default_env_path = get_default_env_file()
|
42
|
-
merged_env_dict = merge_env(
|
43
|
-
default_env_path,
|
44
|
-
None
|
45
|
-
)
|
46
|
-
|
47
|
-
if 'WATSONX_URL' in merged_env_dict and merged_env_dict['WATSONX_URL']:
|
48
|
-
WATSONX_URL = merged_env_dict['WATSONX_URL']
|
49
|
-
|
50
|
-
watsonx_url = merged_env_dict.get("WATSONX_URL")
|
51
|
-
if not watsonx_url:
|
52
|
-
logger.error("Error: WATSONX_URL is required in the environment.")
|
53
|
-
sys.exit(1)
|
54
|
-
|
55
|
-
logger.info("Retrieving virtual-model models list...")
|
56
|
-
virtual_models = models_client.list()
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
logger.info("Retrieving virtual-policies models list...")
|
61
|
-
virtual_model_policies = model_policies_client.list()
|
62
|
-
|
63
|
-
logger.info("Retrieving watsonx.ai models list...")
|
64
|
-
found_models = _get_wxai_foundational_models()
|
65
|
-
|
66
|
-
|
67
|
-
preferred_str = merged_env_dict.get('PREFERRED_MODELS', '')
|
68
|
-
incompatible_str = merged_env_dict.get('INCOMPATIBLE_MODELS', '')
|
69
|
-
|
70
|
-
preferred_list = _string_to_list(preferred_str)
|
71
|
-
incompatible_list = _string_to_list(incompatible_str)
|
72
|
-
|
73
|
-
models = found_models.get("resources", [])
|
74
|
-
if not models:
|
75
|
-
logger.error("No models found.")
|
76
|
-
else:
|
77
|
-
# Remove incompatible models
|
78
|
-
filtered_models = []
|
79
|
-
for model in models:
|
80
|
-
model_id = model.get("model_id", "")
|
81
|
-
short_desc = model.get("short_description", "")
|
82
|
-
if any(incomp in model_id.lower() for incomp in incompatible_list):
|
83
|
-
continue
|
84
|
-
if any(incomp in short_desc.lower() for incomp in incompatible_list):
|
85
|
-
continue
|
86
|
-
filtered_models.append(model)
|
87
|
-
|
88
|
-
# Sort to put the preferred first
|
89
|
-
def sort_key(model):
|
90
|
-
model_id = model.get("model_id", "").lower()
|
91
|
-
is_preferred = any(pref in model_id for pref in preferred_list)
|
92
|
-
return (0 if is_preferred else 1, model_id)
|
93
|
-
|
94
|
-
sorted_models = sorted(filtered_models, key=sort_key)
|
95
|
-
|
96
|
-
if print_raw:
|
97
|
-
theme = rich.theme.Theme({"model.name": "bold cyan"})
|
98
|
-
console = rich.console.Console(highlighter=ModelHighlighter(), theme=theme)
|
99
|
-
console.print("[bold]Available Models:[/bold]")
|
100
|
-
|
101
|
-
for model in virtual_models:
|
102
|
-
console.print(f"- ✨️ {model.name}:", model.description or 'No description provided.')
|
103
|
-
|
104
|
-
for model in virtual_model_policies:
|
105
|
-
console.print(f"- ✨️ {model.name}:", 'No description provided.')
|
106
|
-
|
107
|
-
for model in sorted_models:
|
108
|
-
model_id = model.get("model_id", "N/A")
|
109
|
-
short_desc = model.get("short_description", "No description provided.")
|
110
|
-
full_model_name = f"watsonx/{model_id}: {short_desc}"
|
111
|
-
marker = "★ " if any(pref in model_id.lower() for pref in preferred_list) else ""
|
112
|
-
console.print(f"- [yellow]{marker}[/yellow]{full_model_name}")
|
113
|
-
|
114
|
-
console.print("[yellow]★[/yellow] [italic dim]indicates a supported and preferred model[/italic dim]\n[blue dim]✨️[/blue dim] [italic dim]indicates a model from a custom provider[/italic dim]" )
|
115
|
-
else:
|
116
|
-
table = rich.table.Table(
|
117
|
-
show_header=True,
|
118
|
-
title="[bold]Available Models[/bold]",
|
119
|
-
caption="[yellow]★ [/yellow] indicates a supported and preferred model from watsonx\n[blue]✨️[/blue] indicates a model from a custom provider",
|
120
|
-
show_lines=True)
|
121
|
-
columns = ["Model", "Description"]
|
122
|
-
for col in columns:
|
123
|
-
table.add_column(col)
|
124
|
-
|
125
|
-
for model in virtual_models:
|
126
|
-
table.add_row(f"✨️ {model.name}", model.description or 'No description provided.')
|
127
|
-
|
128
|
-
for model in virtual_model_policies:
|
129
|
-
table.add_row(f"✨️ {model.name}", 'No description provided.')
|
130
|
-
|
131
|
-
for model in sorted_models:
|
132
|
-
model_id = model.get("model_id", "N/A")
|
133
|
-
short_desc = model.get("short_description", "No description provided.")
|
134
|
-
marker = "★ " if any(pref in model_id.lower() for pref in preferred_list) else ""
|
135
|
-
table.add_row(f"[yellow]{marker}[/yellow]watsonx/{model_id}", short_desc)
|
136
|
-
|
137
|
-
rich.print(table)
|
138
|
-
|
26
|
+
models_controller = ModelsController()
|
27
|
+
models_controller.list_models(print_raw=print_raw)
|
139
28
|
|
29
|
+
@models_app.command(name="import", help="Import models from spec file")
def models_import(
    file: Annotated[
        str,
        typer.Option("--file", "-f", help="Path to spec file containing model details."),
    ],
    app_id: Annotated[
        str,
        typer.Option(
            '--app-id',
            '-a',
            help='The app id of a key_value connection containing authentications details for the model provider.',
        ),
    ] = None,
):
    """Import the models described in *file* and publish (or update) each one."""
    controller = ModelsController()
    imported_models = controller.import_model(file=file, app_id=app_id)
    for imported_model in imported_models:
        controller.publish_or_update_models(model=imported_model)
|
140
53
|
|
141
54
|
@models_app.command(name="add", help="Add an llm from a custom provider")
|
142
55
|
def models_add(
|
@@ -144,10 +57,6 @@ def models_add(
|
|
144
57
|
str,
|
145
58
|
typer.Option("--name", "-n", help="The name of the model to add"),
|
146
59
|
],
|
147
|
-
env_file: Annotated[
|
148
|
-
str,
|
149
|
-
typer.Option('--env-file', '-e', help='The path to an .env file containing the credentials for your llm provider'),
|
150
|
-
],
|
151
60
|
description: Annotated[
|
152
61
|
str,
|
153
62
|
typer.Option('--description', '-d', help='The description of the model to add'),
|
@@ -156,38 +65,42 @@ def models_add(
|
|
156
65
|
str,
|
157
66
|
typer.Option('--display-name', help='What name should this llm appear as within the ui'),
|
158
67
|
] = None,
|
68
|
+
provider_config: Annotated[
|
69
|
+
str,
|
70
|
+
typer.Option(
|
71
|
+
"--provider-config",
|
72
|
+
help="LLM provider configuration in JSON format (e.g., '{\"customHost\": \"xyz\"}')",
|
73
|
+
),
|
74
|
+
] = None,
|
75
|
+
app_id: Annotated[
|
76
|
+
str, typer.Option(
|
77
|
+
'--app-id', '-a',
|
78
|
+
help='The app id of a key_value connection containing authentications details for the model provider.'
|
79
|
+
)
|
80
|
+
] = None,
|
159
81
|
type: Annotated[
|
160
82
|
ModelType,
|
161
83
|
typer.Option('--type', help='What type of model is it'),
|
162
84
|
] = ModelType.CHAT,
|
163
|
-
|
164
85
|
):
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
config = {
|
176
|
-
"max_tokens": ANTHROPIC_DEFAULT_MAX_TOKENS
|
177
|
-
}
|
178
|
-
|
179
|
-
model = CreateVirtualModel(
|
86
|
+
provider_config_dict = {}
|
87
|
+
if provider_config:
|
88
|
+
try:
|
89
|
+
provider_config_dict = json.loads(provider_config)
|
90
|
+
except:
|
91
|
+
logger.error(f"Failed to parse provider config. '{provider_config}' is not valid json")
|
92
|
+
sys.exit(1)
|
93
|
+
|
94
|
+
models_controller = ModelsController()
|
95
|
+
model = models_controller.create_model(
|
180
96
|
name=name,
|
181
|
-
display_name=display_name or name,
|
182
97
|
description=description,
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
98
|
+
display_name=display_name,
|
99
|
+
provider_config_dict = provider_config_dict,
|
100
|
+
model_type=type,
|
101
|
+
app_id=app_id,
|
187
102
|
)
|
188
|
-
|
189
|
-
models_client.create(model)
|
190
|
-
logger.info(f"Successfully added the model '{name}'")
|
103
|
+
models_controller.publish_or_update_models(model=model)
|
191
104
|
|
192
105
|
|
193
106
|
|
@@ -198,122 +111,86 @@ def models_remove(
|
|
198
111
|
typer.Option("--name", "-n", help="The name of the model to remove"),
|
199
112
|
]
|
200
113
|
):
|
201
|
-
|
202
|
-
|
203
|
-
model = next(filter(lambda x: x.name == name or x.name == f"virtual-model/{name}", models), None)
|
204
|
-
if not model:
|
205
|
-
logger.error(f"No model found with the name '{name}'")
|
206
|
-
sys.exit(1)
|
207
|
-
|
208
|
-
models_client.delete(model_id=model.id)
|
209
|
-
logger.info(f"Successfully removed the model '{name}'")
|
210
|
-
|
211
|
-
|
212
|
-
# @models_policy_app.command(name='add', help='Add a model policy')
|
213
|
-
# def models_policy_add(
|
214
|
-
# name: Annotated[
|
215
|
-
# str,
|
216
|
-
# typer.Option("--name", "-n", help="The name of the model to remove"),
|
217
|
-
# ],
|
218
|
-
# models: Annotated[
|
219
|
-
# List[str],
|
220
|
-
# typer.Option('--model', '-m', help='The name of the model to add'),
|
221
|
-
# ],
|
222
|
-
# strategy: Annotated[
|
223
|
-
# ModelPolicyStrategyMode,
|
224
|
-
# typer.Option('--strategy', '-s', help='How to spread traffic across models'),
|
225
|
-
# ],
|
226
|
-
# strategy_on_code: Annotated[
|
227
|
-
# List[int],
|
228
|
-
# typer.Option('--strategy-on-code', help='The http status to consider invoking the strategy'),
|
229
|
-
# ],
|
230
|
-
# retry_on_code: Annotated[
|
231
|
-
# List[int],
|
232
|
-
# typer.Option('--retry-on-code', help='The http status to consider retrying the llm call'),
|
233
|
-
# ],
|
234
|
-
# retry_attempts: Annotated[
|
235
|
-
# int,
|
236
|
-
# typer.Option('--retry-attempts', help='The number of attempts to retry'),
|
237
|
-
# ],
|
238
|
-
# display_name: Annotated[
|
239
|
-
# str,
|
240
|
-
# typer.Option('--display-name', help='What name should this llm appear as within the ui'),
|
241
|
-
# ] = None
|
242
|
-
# ):
|
243
|
-
# model_policies_client: ModelPoliciesClient = instantiate_client(ModelPoliciesClient)
|
244
|
-
# model_client: ModelsClient = instantiate_client(ModelsClient)
|
245
|
-
# model_lut = {m.name: m.id for m in model_client.list()}
|
246
|
-
# for m in models:
|
247
|
-
# if m not in model_lut:
|
248
|
-
# logger.error(f"No model found with the name '{m}'")
|
249
|
-
# exit(1)
|
250
|
-
|
251
|
-
# inner = ModelPolicyInner()
|
252
|
-
# inner.strategy = ModelPolicyStrategy(
|
253
|
-
# mode=strategy,
|
254
|
-
# on_status_codes=strategy_on_code
|
255
|
-
# )
|
256
|
-
# inner.targets = [ModelPolicyTarget(model_id=model_lut[m], weight=1) for m in models]
|
257
|
-
# if retry_on_code:
|
258
|
-
# inner.retry = ModelPolicyRetry(
|
259
|
-
# on_status_codes=retry_on_code,
|
260
|
-
# attempts=retry_attempts
|
261
|
-
# )
|
262
|
-
|
263
|
-
# if not display_name:
|
264
|
-
# display_name = name
|
265
|
-
|
266
|
-
|
267
|
-
# policy = ModelPolicy(
|
268
|
-
# name=name,
|
269
|
-
# display_name=display_name,
|
270
|
-
# policy=inner
|
271
|
-
# )
|
272
|
-
# model_policies_client.create(policy)
|
273
|
-
# logger.info(f"Successfully added the model policy '{name}'")
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
# @models_policy_app.command(name='remove', help='Remove a model policy')
|
278
|
-
# def models_policy_remove(
|
279
|
-
# name: Annotated[
|
280
|
-
# str,
|
281
|
-
# typer.Option("--name", "-n", help="The name of the model policy to remove"),
|
282
|
-
# ]
|
283
|
-
# ):
|
284
|
-
# model_policies_client: ModelPoliciesClient = instantiate_client(ModelPoliciesClient)
|
285
|
-
# model_policies = model_policies_client.list()
|
286
|
-
|
287
|
-
# policy = next(filter(lambda x: x.name == name or x.name == f"virtual-policy/{name}", model_policies), None)
|
288
|
-
# if not policy:
|
289
|
-
# logger.error(f"No model found with the name '{name}'")
|
290
|
-
# exit(1)
|
291
|
-
|
292
|
-
# model_policies_client.delete(model_policy_id=policy.id)
|
293
|
-
# logger.info(f"Successfully removed the model '{name}'")
|
114
|
+
models_controller = ModelsController()
|
115
|
+
models_controller.remove_model(name=name)
|
294
116
|
|
117
|
+
@models_policy_app.command(name='import', help='Import model policies from spec file')
def models_policy_import(
    file: Annotated[
        str,
        typer.Option(
            "--file",
            "-f",
            help="Path to spec file containing model policy details.",
        ),
    ],
):
    """Import the model policies described in *file* and publish (or update) each one.

    The help text previously read 'Add a model policy' (copied from the ``add``
    command), and ``--file`` described a *model* spec. Both now describe policy import.
    """
    controller = ModelsController()
    for policy in controller.import_model_policy(file=file):
        controller.publish_or_update_model_policies(policy=policy)
|
295
134
|
|
296
|
-
|
297
|
-
|
135
|
+
@models_policy_app.command(name='add', help='Add a model policy')
def models_policy_add(
    name: Annotated[
        str,
        # Previously read "The name of the model to remove" (copy-paste error).
        typer.Option("--name", "-n", help="The name of the model policy to add"),
    ],
    models: Annotated[
        List[str],
        # List option: pass --model once per model the policy should route to.
        typer.Option('--model', '-m', help='The name of a model to include in the policy (can be repeated)'),
    ],
    strategy: Annotated[
        ModelPolicyStrategyMode,
        typer.Option('--strategy', '-s', help='How to spread traffic across models'),
    ],
    strategy_on_code: Annotated[
        List[int],
        typer.Option('--strategy-on-code', help='The http status to consider invoking the strategy'),
    ],
    retry_on_code: Annotated[
        List[int],
        typer.Option('--retry-on-code', help='The http status to consider retrying the llm call'),
    ],
    retry_attempts: Annotated[
        int,
        typer.Option('--retry-attempts', help='The number of attempts to retry'),
    ],
    display_name: Annotated[
        str,
        typer.Option('--display-name', help='What name should this llm appear as within the ui'),
    ] = None,
    description: Annotated[
        str,
        typer.Option('--description', help='Description of the policy for display in the ui'),
    ] = None
):
    """Build a model policy over *models* and publish (or update) it.

    Policy construction is delegated to ``ModelsController.create_model_policy``.
    Publishing is delegated to ``ModelsController.publish_or_update_model_policies``.
    """
    controller = ModelsController()
    policy = controller.create_model_policy(
        name=name,
        models=models,
        strategy=strategy,
        strategy_on_code=strategy_on_code,
        retry_on_code=retry_on_code,
        retry_attempts=retry_attempts,
        display_name=display_name,
        description=description,
    )
    controller.publish_or_update_model_policies(policy=policy)
|
298
182
|
|
299
|
-
try:
|
300
|
-
response = requests.get(foundation_models_url)
|
301
|
-
except requests.exceptions.RequestException as e:
|
302
|
-
logger.exception(f"Exception when connecting to Watsonx URL: {foundation_models_url}")
|
303
|
-
raise
|
304
183
|
|
305
|
-
if response.status_code != 200:
|
306
|
-
error_message = (
|
307
|
-
f"Failed to retrieve foundational models from {foundation_models_url}. "
|
308
|
-
f"Status code: {response.status_code}. Response: {response.content}"
|
309
|
-
)
|
310
|
-
raise Exception(error_message)
|
311
|
-
|
312
|
-
json_response = response.json()
|
313
|
-
return json_response
|
314
184
|
|
315
|
-
|
316
|
-
|
185
|
+
@models_policy_app.command(name='remove', help='Remove a model policy')
def models_policy_remove(
    name: Annotated[
        str,
        typer.Option("--name", "-n", help="The name of the model policy to remove"),
    ]
):
    """Remove the model policy called *name* via ``ModelsController.remove_policy``."""
    controller = ModelsController()
    controller.remove_policy(name=name)
|
317
194
|
|
318
195
|
if __name__ == "__main__":
|
319
196
|
models_app()
|