ibm-watsonx-orchestrate 1.3.0__py3-none-any.whl → 1.4.2__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as published in the public registry.
- ibm_watsonx_orchestrate/__init__.py +1 -1
- ibm_watsonx_orchestrate/agent_builder/agents/types.py +2 -0
- ibm_watsonx_orchestrate/agent_builder/knowledge_bases/types.py +9 -2
- ibm_watsonx_orchestrate/agent_builder/toolkits/base_toolkit.py +32 -0
- ibm_watsonx_orchestrate/agent_builder/toolkits/types.py +42 -0
- ibm_watsonx_orchestrate/agent_builder/tools/openapi_tool.py +10 -1
- ibm_watsonx_orchestrate/agent_builder/tools/python_tool.py +4 -2
- ibm_watsonx_orchestrate/agent_builder/tools/types.py +2 -1
- ibm_watsonx_orchestrate/cli/commands/agents/agents_command.py +29 -0
- ibm_watsonx_orchestrate/cli/commands/agents/agents_controller.py +271 -12
- ibm_watsonx_orchestrate/cli/commands/knowledge_bases/knowledge_bases_controller.py +17 -2
- ibm_watsonx_orchestrate/cli/commands/models/env_file_model_provider_mapper.py +180 -0
- ibm_watsonx_orchestrate/cli/commands/models/models_command.py +194 -8
- ibm_watsonx_orchestrate/cli/commands/server/server_command.py +117 -48
- ibm_watsonx_orchestrate/cli/commands/server/types.py +105 -0
- ibm_watsonx_orchestrate/cli/commands/toolkit/toolkit_command.py +55 -7
- ibm_watsonx_orchestrate/cli/commands/toolkit/toolkit_controller.py +123 -42
- ibm_watsonx_orchestrate/cli/commands/tools/tools_command.py +22 -1
- ibm_watsonx_orchestrate/cli/commands/tools/tools_controller.py +197 -12
- ibm_watsonx_orchestrate/client/agents/agent_client.py +4 -1
- ibm_watsonx_orchestrate/client/agents/assistant_agent_client.py +5 -1
- ibm_watsonx_orchestrate/client/agents/external_agent_client.py +5 -1
- ibm_watsonx_orchestrate/client/analytics/llm/analytics_llm_client.py +2 -6
- ibm_watsonx_orchestrate/client/base_api_client.py +5 -2
- ibm_watsonx_orchestrate/client/connections/connections_client.py +3 -9
- ibm_watsonx_orchestrate/client/model_policies/__init__.py +0 -0
- ibm_watsonx_orchestrate/client/model_policies/model_policies_client.py +47 -0
- ibm_watsonx_orchestrate/client/model_policies/types.py +36 -0
- ibm_watsonx_orchestrate/client/models/__init__.py +0 -0
- ibm_watsonx_orchestrate/client/models/models_client.py +46 -0
- ibm_watsonx_orchestrate/client/models/types.py +177 -0
- ibm_watsonx_orchestrate/client/toolkit/toolkit_client.py +15 -6
- ibm_watsonx_orchestrate/client/tools/tempus_client.py +40 -0
- ibm_watsonx_orchestrate/client/tools/tool_client.py +8 -0
- ibm_watsonx_orchestrate/docker/compose-lite.yml +68 -13
- ibm_watsonx_orchestrate/docker/default.env +22 -12
- ibm_watsonx_orchestrate/docker/tempus/common-config.yaml +1 -1
- ibm_watsonx_orchestrate/experimental/flow_builder/__init__.py +0 -0
- ibm_watsonx_orchestrate/experimental/flow_builder/flows/__init__.py +41 -0
- ibm_watsonx_orchestrate/experimental/flow_builder/flows/constants.py +17 -0
- ibm_watsonx_orchestrate/experimental/flow_builder/flows/data_map.py +91 -0
- ibm_watsonx_orchestrate/experimental/flow_builder/flows/decorators.py +143 -0
- ibm_watsonx_orchestrate/experimental/flow_builder/flows/events.py +72 -0
- ibm_watsonx_orchestrate/experimental/flow_builder/flows/flow.py +1288 -0
- ibm_watsonx_orchestrate/experimental/flow_builder/node.py +97 -0
- ibm_watsonx_orchestrate/experimental/flow_builder/resources/flow_status.openapi.yml +98 -0
- ibm_watsonx_orchestrate/experimental/flow_builder/types.py +492 -0
- ibm_watsonx_orchestrate/experimental/flow_builder/utils.py +113 -0
- ibm_watsonx_orchestrate/utils/utils.py +5 -2
- {ibm_watsonx_orchestrate-1.3.0.dist-info → ibm_watsonx_orchestrate-1.4.2.dist-info}/METADATA +4 -1
- {ibm_watsonx_orchestrate-1.3.0.dist-info → ibm_watsonx_orchestrate-1.4.2.dist-info}/RECORD +54 -32
- {ibm_watsonx_orchestrate-1.3.0.dist-info → ibm_watsonx_orchestrate-1.4.2.dist-info}/WHEEL +0 -0
- {ibm_watsonx_orchestrate-1.3.0.dist-info → ibm_watsonx_orchestrate-1.4.2.dist-info}/entry_points.txt +0 -0
- {ibm_watsonx_orchestrate-1.3.0.dist-info → ibm_watsonx_orchestrate-1.4.2.dist-info}/licenses/LICENSE +0 -0
ibm_watsonx_orchestrate/cli/commands/knowledge_bases/knowledge_bases_controller.py (+17 -2)

@@ -49,6 +49,15 @@ def get_file_name(path: str):
     # return to_column_name(path.split("/")[-1].split(".")[0])
     return path.split("/")[-1]

+def get_relative_file_path(path, dir):
+    if path.startswith("/"):
+        return path
+    elif path.startswith("./"):
+        return f"{dir}{path.removeprefix('.')}"
+    else:
+        return f"{dir}/{path}"
+
+
 class KnowledgeBaseController:
     def __init__(self):
         self.client = None
@@ -68,8 +77,9 @@ class KnowledgeBaseController:
         kb.validate_documents_or_index_exists()
         if kb.documents:
             file_dir = "/".join(file.split("/")[:-1])
-            files = [('files', (get_file_name(file_path), open(
+            files = [('files', (get_file_name(file_path), open(get_relative_file_path(file_path, file_dir), 'rb'))) for file_path in kb.documents]

+            kb.prioritize_built_in_index = True
             payload = kb.model_dump(exclude_none=True);
             payload.pop('documents');

@@ -78,6 +88,7 @@ class KnowledgeBaseController:
         if len(kb.conversational_search_tool.index_config) != 1:
             raise ValueError(f"Must provide exactly one conversational_search_tool.index_config. Provided {len(kb.conversational_search_tool.index_config)}.")

+
         if app_id:
             connections_client = get_connections_client()
             connection_id = None
@@ -90,6 +101,7 @@ class KnowledgeBaseController:
             connection_id = connections.connection_id
             kb.conversational_search_tool.index_config[0].connection_id = connection_id

+        kb.prioritize_built_in_index = False
         client.create(payload=kb.model_dump(exclude_none=True))

         logger.info(f"Successfully imported knowledge base '{kb.name}'")
@@ -126,13 +138,16 @@ class KnowledgeBaseController:

         if update_request.documents:
             file_dir = "/".join(file.split("/")[:-1])
-            files = [('files', (get_file_name(file_path), open(file_path
+            files = [('files', (get_file_name(file_path), open(get_relative_file_path(file_path, file_dir), 'rb'))) for file_path in update_request.documents]

+            update_request.prioritize_built_in_index = True
             payload = update_request.model_dump(exclude_none=True);
             payload.pop('documents');

             self.get_client().update_with_documents(knowledge_base_id, payload=payload, files=files)
         else:
+            if update_request.conversational_search_tool and update_request.conversational_search_tool.index_config:
+                update_request.prioritize_built_in_index = False
             self.get_client().update(knowledge_base_id, update_request.model_dump(exclude_none=True))

         logEnding = f"with ID '{id}'" if id else f"'{name}'"
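The new `get_relative_file_path` helper resolves each knowledge-base document path against the directory of the knowledge-base spec file (`file_dir`) before the document is opened for upload: absolute paths pass through unchanged, `./`-prefixed paths are re-rooted at the spec directory, and bare relative names are joined onto it. A minimal standalone sketch of that behaviour; the helper body is copied from the hunk above and the example paths are illustrative only:

def get_relative_file_path(path, dir):
    # Absolute paths are returned as-is.
    if path.startswith("/"):
        return path
    # "./docs/faq.pdf" becomes "<dir>/docs/faq.pdf".
    elif path.startswith("./"):
        return f"{dir}{path.removeprefix('.')}"
    # Bare relative names are joined onto the spec directory.
    else:
        return f"{dir}/{path}"

# Illustrative spec directory and document paths (not taken from the package):
spec_dir = "knowledge/specs"
print(get_relative_file_path("/data/faq.pdf", spec_dir))   # /data/faq.pdf
print(get_relative_file_path("./docs/faq.pdf", spec_dir))  # knowledge/specs/docs/faq.pdf
print(get_relative_file_path("faq.pdf", spec_dir))         # knowledge/specs/faq.pdf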
ibm_watsonx_orchestrate/cli/commands/models/env_file_model_provider_mapper.py (+180 -0, new file)

@@ -0,0 +1,180 @@
+# https://portkey.ai/
+from dotenv import dotenv_values
+
+from ibm_watsonx_orchestrate.client.models.types import ProviderConfig, ModelProvider
+
+def _bpp(name: str, provider_config=None) -> dict:
+    if provider_config is None:
+        provider_config = {}
+
+    name = name.upper()
+
+    provider_config.update({
+        f"{name}_API_KEY": 'api_key',
+        f"{name}_CUSTOM_HOST": 'custom_host',
+        f"{name}_URL_TO_FETCH": 'url_to_fetch',
+        f"{name}_FORWARD_HEADERS": lambda value, config: setattr(config, 'forward_headers', list(map(lambda v: v.strip(), value.split(';')))),
+        f"{name}_REQUEST_TIMEOUT": lambda value, config: setattr(config, 'request_timeout', int(value)),
+        f"{name}_TRANSFORM_TO_FORM_DATA": 'transform_to_form_data'
+    })
+
+    return provider_config
+
+# model provider prefix => ENV_VAR key => provider config key
+PROVIDER_PROPERTIES_LUT = {
+    ModelProvider.OPENAI: _bpp('OPENAI', {}),
+    # ModelProvider.A21: _bpp('A21', {}),
+    ModelProvider.ANTHROPIC: _bpp('ANTHROPIC', {
+        'ANTHROPIC_BETA': 'anthropic_beta',
+        'ANTHROPIC_VERSION': 'anthropic_version'
+    }),
+    # ModelProvider.ANYSCALE: _bpp('ANYSCALE', {}),
+    # ModelProvider.AZURE_OPENAI: _bpp('AZURE_OPENAI', {}),
+    # ModelProvider.AZURE_AI: _bpp('AZURE', {
+    #     'AZURE_AI_RESOURCE_NAME': 'azure_resource_name',
+    #     'AZURE_AI_DEPLOYMENT_ID': 'azure_deployment_id',
+    #     'AZURE_AI_API_VERSION': 'azure_api_version',
+    #     'AZURE_AI_AD_AUTH': 'ad_auth',
+    #     'AZURE_AI_AUTH_MODE': 'azure_auth_mode',
+    #     'AZURE_AI_MANAGED_CLIENT_ID': 'azure_managed_client_id',
+    #     'AZURE_AI_ENTRA_CLIENT_ID': 'azure_entra_client_id',
+    #     'AZURE_AI_ENTRA_CLIENT_SECRET': 'azure_entra_client_secret',
+    #     'AZURE_AI_ENTRA_TENANT_ID': 'azure_entra_tenant_id',
+    #     'AZURE_AI_AD_TOKEN': 'azure_ad_token',
+    #     'AZURE_AI_MODEL_NAME': 'azure_model_name',
+    # }),
+    # ModelProvider.BEDROCK: _bpp('BEDROCK', {
+    #     'AWS_SECRET_ACCESS_KEY': 'aws_secret_access_key',
+    #     'AWS_ACCESS_KEY_ID': 'aws_access_key_id',
+    #     'AWS_SESSION_TOKEN': 'aws_session_token',
+    #     'AWS_REGION': 'aws_region',
+    #     'AWS_AUTH_TYPE': 'aws_auth_type',
+    #     'AWS_ROLE_ARN': 'aws_role_arn',
+    #     'AWS_EXTERNAL_ID': 'aws_external_id',
+    #     'AWS_S3_BUCKET': 'aws_s3_bucket',
+    #     'AWS_S3_OBJECT_KEY': 'aws_s3_object_key',
+    #     'AWS_BEDROCK_MODEL': 'aws_bedrock_model',
+    #     'AWS_SERVER_SIDE_ENCRYPTION': 'aws_server_side_encryption',
+    #     'AWS_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID': 'aws_server_side_encryption_kms_key_id',
+    # }),
+    # ModelProvider.CEREBRAS: _bpp('COHERE', {}),
+    # ModelProvider.COHERE: _bpp('COHERE', {}),
+    ModelProvider.GOOGLE: _bpp('GOOGLE', {}),
+    # ModelProvider.VERTEX_AI: _bpp('GOOGLE_VERTEX_AI', {
+    #     'GOOGLE_VERTEX_AI_REGION': 'vertex_region',
+    #     'GOOGLE_VERTEX_AI_PROJECT_ID': 'vertex_project_id',
+    #     'GOOGLE_VERTEX_AI_SERVICE_ACCOUNT_JSON': 'vertex_service_account_json',
+    #     'GOOGLE_VERTEX_AI_STORAGE_BUCKET_NAME': 'vertex_storage_bucket_name',
+    #     'GOOGLE_VERTEX_AI_MODEL_NAME': 'vertex_model_name',
+    #     'GOOGLE_VERTEX_AI_FILENAME': 'filename'
+    # }),
+    # ModelProvider.GROQ: _bpp('GROQ', {}),
+    # ModelProvider.HUGGINGFACE: _bpp('HUGGINGFACE', {
+    #     'HUGGINGFACE_BASE_URL': 'huggingfaceBaseUrl'
+    # }),
+    ModelProvider.MISTRAL_AI: _bpp('MISTRAL', {
+        'MISTRAL_FIM_COMPLETION': 'mistral_fim_completion'
+    }),
+    # ModelProvider.JINA: _bpp('JINA', {}),
+    ModelProvider.OLLAMA: _bpp('OLLAMA', {}),
+    ModelProvider.OPENROUTER: _bpp('OPENROUTER', {}),
+    # ModelProvider.STABILITY_AI: _bpp('STABILITY', {
+    #     'STABILITY_CLIENT_ID': 'stability_client_id',
+    #     'STABILITY_CLIENT_USER_ID': 'stability_client_user_id',
+    #     'STABILITY_CLIENT_VERSION': 'stability_client_version'
+    # }),
+    # ModelProvider.TOGETHER_AI: _bpp('TOGETHER_AI', {}),
+    ModelProvider.WATSONX: _bpp('WATSONX', {
+        'WATSONX_VERSION': 'watsonx_version',
+        'WATSONX_SPACE_ID': 'watsonx_space_id',
+        'WATSONX_PROJECT_ID': 'watsonx_project_id',
+        'WATSONX_APIKEY': 'api_key'
+    })
+
+    # 'palm': _bpp('PALM', {}),
+    # 'nomic': _bpp('NOMIC', {}),
+    # 'perplexity-ai': _bpp('PERPLEXITY_AI', {}),
+    # 'segmind': _bpp('SEGMIND', {}),
+    # 'deepinfra': _bpp('DEEPINFRA', {}),
+    # 'novita-ai': _bpp('NOVITA_AI', {}),
+    # 'fireworks-ai': _bpp('FIREWORKS',{
+    #     'FIREWORKS_ACCOUNT_ID': 'fireworks_account_id'
+    # }),
+    # 'deepseek': _bpp('DEEPSEEK', {}),
+    # 'voyage': _bpp('VOYAGE', {}),
+    # 'moonshot': _bpp('MOONSHOT', {}),
+    # 'lingyi': _bpp('LINGYI', {}),
+    # 'zhipu': _bpp('ZHIPU', {}),
+    # 'monsterapi': _bpp('MONSTERAPI', {}),
+    # 'predibase': _bpp('PREDIBASE', {}),
+
+    # 'github': _bpp('GITHUB', {}),
+    # 'deepbricks': _bpp('DEEPBRICKS', {}),
+    # 'sagemaker': _bpp('AMZN_SAGEMAKER', {
+    #     'AMZN_SAGEMAKER_CUSTOM_ATTRIBUTES': 'amzn_sagemaker_custom_attributes',
+    #     'AMZN_SAGEMAKER_TARGET_MODEL': 'amzn_sagemaker_target_model',
+    #     'AMZN_SAGEMAKER_TARGET_VARIANT': 'amzn_sagemaker_target_variant',
+    #     'AMZN_SAGEMAKER_TARGET_CONTAINER_HOSTNAME': 'amzn_sagemaker_target_container_hostname',
+    #     'AMZN_SAGEMAKER_INFERENCE_ID': 'amzn_sagemaker_inference_id',
+    #     'AMZN_SAGEMAKER_ENABLE_EXPLANATIONS': 'amzn_sagemaker_enable_explanations',
+    #     'AMZN_SAGEMAKER_INFERENCE_COMPONENT': 'amzn_sagemaker_inference_component',
+    #     'AMZN_SAGEMAKER_SESSION_ID': 'amzn_sagemaker_session_id',
+    #     'AMZN_SAGEMAKER_MODEL_NAME': 'amzn_sagemaker_model_name'
+    # }),
+    # '@cf': _bpp('WORKERS_AI', { # workers ai
+    #     'WORKERS_AI_ACCOUNT_ID': 'workers_ai_account_id'
+    # }),
+    # 'snowflake': _bpp('SNOWFLAKE', { # no provider prefix found
+    #     'SNOWFLAKE_ACCOUNT': 'snowflake_account'
+    # })
+}
+PROVIDER_PROPERTIES_RLUT= {}
+for provider in PROVIDER_PROPERTIES_LUT.keys():
+    PROVIDER_PROPERTIES_RLUT[provider] = {v:k for k,v in PROVIDER_PROPERTIES_LUT[provider].items()}
+
+PROVIDER_LUT = {k:k for k in PROVIDER_PROPERTIES_LUT.keys() }
+PROVIDER_LUT.update({
+    # any overrides for the provider prefix to provider name can be provided here on PROVIDER_LUT
+})
+
+
+
+PROVIDER_REQUIRED_FIELDS = {k:['api_key'] if k not in ['ollama'] else [] for k in PROVIDER_PROPERTIES_LUT.keys()}
+PROVIDER_REQUIRED_FIELDS.update({
+    # Mark the required fields for a provider
+})
+
+def env_file_to_model_ProviderConfig(model_name: str, env_file_path: str) -> ProviderConfig | None:
+    provider = next(filter(lambda x: x not in ('virtual-policy', 'virtual-model'), model_name.split('/')))
+    if provider not in PROVIDER_LUT:
+        raise ValueError(f"Unsupported model provider {provider}")
+
+    values = dotenv_values(str(env_file_path))
+
+    if values is None:
+        raise ValueError(f"No provider configuration in env file {env_file_path}")
+
+    cfg = ProviderConfig()
+    cfg.provider = PROVIDER_LUT[provider]
+
+    cred_lut = PROVIDER_PROPERTIES_LUT[provider]
+
+
+    consumed_credentials = []
+    for key, value in values.items():
+        if key in cred_lut:
+            k = cred_lut[key]
+            consumed_credentials.append(k)
+            setattr(cfg, k, value)
+
+
+    required_creds = PROVIDER_REQUIRED_FIELDS[provider]
+    missing_credentials = []
+    for cred in required_creds:
+        if cred not in consumed_credentials:
+            missing_credentials.append(cred)
+
+    if len(missing_credentials) > 0:
+        raise ValueError(f"Missing environment variable(s) {', '.join(map(lambda c: PROVIDER_PROPERTIES_RLUT[provider][c], missing_credentials))} required for the provider {provider}")
+
+    return cfg
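The new module builds a lookup table from provider-prefixed environment variable names to `ProviderConfig` fields; `env_file_to_model_ProviderConfig` then reads the given `.env` file with `dotenv_values`, copies every recognised key onto the config, and raises a `ValueError` if a required field (for every enabled provider except ollama, `api_key`) was never supplied. A simplified, standalone sketch of that mapping for the watsonx provider; the environment variable names come from the LUT above, while the values and the plain-dict config are placeholders rather than the shipped `ProviderConfig` type:

# Placeholder .env contents for a watsonx-backed model (values are illustrative):
env = {
    "WATSONX_APIKEY": "***",
    "WATSONX_SPACE_ID": "1234-abcd",
    "WATSONX_URL_TO_FETCH": "https://us-south.ml.cloud.ibm.com",
}

# String-valued subset of PROVIDER_PROPERTIES_LUT for the watsonx provider:
lut = {
    "WATSONX_APIKEY": "api_key",
    "WATSONX_SPACE_ID": "watsonx_space_id",
    "WATSONX_PROJECT_ID": "watsonx_project_id",
    "WATSONX_URL_TO_FETCH": "url_to_fetch",
}

# Copy recognised environment variables onto the provider config.
config = {lut[key]: value for key, value in env.items() if key in lut}

# api_key is the required field for non-ollama providers per PROVIDER_REQUIRED_FIELDS.
missing = [field for field in ["api_key"] if field not in config]
if missing:
    raise ValueError(f"Missing environment variable(s) required for the provider watsonx: {missing}")

print(config)
# {'api_key': '***', 'watsonx_space_id': '1234-abcd', 'url_to_fetch': 'https://us-south.ml.cloud.ibm.com'}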
ibm_watsonx_orchestrate/cli/commands/models/models_command.py (+194 -8)

@@ -1,30 +1,42 @@
 import logging
 import os
-import requests
 import sys
-import
-
+from typing import List
+
+import requests
 import rich
+import rich.highlighter
 import typer
 from typing_extensions import Annotated
+
 from ibm_watsonx_orchestrate.cli.commands.server.server_command import get_default_env_file, merge_env
+from ibm_watsonx_orchestrate.client.model_policies.model_policies_client import ModelPoliciesClient
+from ibm_watsonx_orchestrate.client.model_policies.types import ModelPolicy, ModelPolicyInner, \
+    ModelPolicyRetry, ModelPolicyStrategy, ModelPolicyStrategyMode, ModelPolicyTarget
+from ibm_watsonx_orchestrate.client.models.models_client import ModelsClient
+from ibm_watsonx_orchestrate.client.models.types import CreateVirtualModel, ANTHROPIC_DEFAULT_MAX_TOKENS
+from ibm_watsonx_orchestrate.client.utils import instantiate_client

 logger = logging.getLogger(__name__)
 models_app = typer.Typer(no_args_is_help=True)
+models_policy_app = typer.Typer(no_args_is_help=True)
+# models_app.add_typer(models_policy_app, name='policy', help='Add or remove pseudo models which route traffic between multiple downstream models')

 WATSONX_URL = os.getenv("WATSONX_URL")

 class ModelHighlighter(rich.highlighter.RegexHighlighter):
     base_style = "model."
-    highlights = [r"(?P<name>watsonx\/.+\/.+):"]
+    highlights = [r"(?P<name>(watsonx|virtual[-]model|virtual[-]policy)\/.+\/.+):"]

-@models_app.command(name="list")
+@models_app.command(name="list", help="List available models")
 def model_list(
     print_raw: Annotated[
         bool,
         typer.Option("--raw", "-r", help="Display the list of models in a non-tabular format"),
     ] = False,
 ):
+    models_client: ModelsClient = instantiate_client(ModelsClient)
+    model_policies_client: ModelPoliciesClient = instantiate_client(ModelPoliciesClient)
     global WATSONX_URL
     default_env_path = get_default_env_file()
     merged_env_dict = merge_env(
@@ -40,9 +52,18 @@ def model_list(
         logger.error("Error: WATSONX_URL is required in the environment.")
         sys.exit(1)

+    logger.info("Retrieving virtual-model models list...")
+    virtual_models = models_client.list()
+
+
+
+    logger.info("Retrieving virtual-policies models list...")
+    virtual_model_policies = model_policies_client.list()
+
     logger.info("Retrieving watsonx.ai models list...")
     found_models = _get_wxai_foundational_models()

+
     preferred_str = merged_env_dict.get('PREFERRED_MODELS', '')
     incompatible_str = merged_env_dict.get('INCOMPATIBLE_MODELS', '')

@@ -64,7 +85,7 @@ def model_list(
             continue
         filtered_models.append(model)

-    # Sort to put preferred first
+    # Sort to put the preferred first
     def sort_key(model):
         model_id = model.get("model_id", "").lower()
         is_preferred = any(pref in model_id for pref in preferred_list)
@@ -76,6 +97,13 @@ def model_list(
         theme = rich.theme.Theme({"model.name": "bold cyan"})
         console = rich.console.Console(highlighter=ModelHighlighter(), theme=theme)
         console.print("[bold]Available Models:[/bold]")
+
+        for model in virtual_models:
+            console.print(f"- ✨️ {model.name}:", model.description or 'No description provided.')
+
+        for model in virtual_model_policies:
+            console.print(f"- ✨️ {model.name}:", 'No description provided.')
+
         for model in sorted_models:
             model_id = model.get("model_id", "N/A")
             short_desc = model.get("short_description", "No description provided.")
@@ -83,17 +111,23 @@ def model_list(
             marker = "★ " if any(pref in model_id.lower() for pref in preferred_list) else ""
             console.print(f"- [yellow]{marker}[/yellow]{full_model_name}")

-        console.print("[yellow]★[/yellow] [italic dim]indicates a supported and preferred model[/italic dim]" )
+        console.print("[yellow]★[/yellow] [italic dim]indicates a supported and preferred model[/italic dim]\n[blue dim]✨️[/blue dim] [italic dim]indicates a model from a custom provider[/italic dim]" )
     else:
         table = rich.table.Table(
             show_header=True,
             title="[bold]Available Models[/bold]",
-            caption="[yellow]★[/yellow] indicates a supported and preferred model",
+            caption="[yellow]★ [/yellow] indicates a supported and preferred model from watsonx\n[blue]✨️[/blue] indicates a model from a custom provider",
             show_lines=True)
         columns = ["Model", "Description"]
         for col in columns:
             table.add_column(col)

+        for model in virtual_models:
+            table.add_row(f"✨️ {model.name}", model.description or 'No description provided.')
+
+        for model in virtual_model_policies:
+            table.add_row(f"✨️ {model.name}", 'No description provided.')
+
         for model in sorted_models:
             model_id = model.get("model_id", "N/A")
             short_desc = model.get("short_description", "No description provided.")
@@ -102,6 +136,158 @@ def model_list(

     rich.print(table)

+
+
+@models_app.command(name="add", help="Add an llm from a custom provider")
+def models_add(
+    name: Annotated[
+        str,
+        typer.Option("--name", "-n", help="The name of the model to add"),
+    ],
+    env_file: Annotated[
+        str,
+        typer.Option('--env-file', '-e', help='The path to an .env file containing the credentials for your llm provider'),
+    ],
+    description: Annotated[
+        str,
+        typer.Option('--description', '-d', help='The description of the model to add'),
+    ] = None,
+    display_name: Annotated[
+        str,
+        typer.Option('--display-name', help='What name should this llm appear as within the ui'),
+    ] = None,
+
+):
+    from ibm_watsonx_orchestrate.cli.commands.models.env_file_model_provider_mapper import env_file_to_model_ProviderConfig # lazily import this because the lut building is expensive
+
+    models_client: ModelsClient = instantiate_client(ModelsClient)
+    provider_config = env_file_to_model_ProviderConfig(model_name=name, env_file_path=env_file)
+    if not name.startswith('virtual-model/'):
+        name = f"virtual-model/{name}"
+
+    config=None
+    # Anthropic has no default for max_tokens
+    if "anthropic" in name:
+        config = {
+            "max_tokens": ANTHROPIC_DEFAULT_MAX_TOKENS
+        }
+
+    model = CreateVirtualModel(
+        name=name,
+        display_name=display_name or name,
+        description=description,
+        tags=[],
+        provider_config=provider_config,
+        config=config
+    )
+
+    models_client.create(model)
+    logger.info(f"Successfully added the model '{name}'")
+
+
+
+
+@models_app.command(name="remove", help="Remove an llm from a custom provider")
+def models_remove(
+    name: Annotated[
+        str,
+        typer.Option("--name", "-n", help="The name of the model to remove"),
+    ]
+):
+    models_client: ModelsClient = instantiate_client(ModelsClient)
+    models = models_client.list()
+    model = next(filter(lambda x: x.name == name or x.name == f"virtual-model/{name}", models), None)
+    if not model:
+        logger.error(f"No model found with the name '{name}'")
+        sys.exit(1)
+
+    models_client.delete(model_id=model.id)
+    logger.info(f"Successfully removed the model '{name}'")
+
+
+# @models_policy_app.command(name='add', help='Add a model policy')
+# def models_policy_add(
+#     name: Annotated[
+#         str,
+#         typer.Option("--name", "-n", help="The name of the model to remove"),
+#     ],
+#     models: Annotated[
+#         List[str],
+#         typer.Option('--model', '-m', help='The name of the model to add'),
+#     ],
+#     strategy: Annotated[
+#         ModelPolicyStrategyMode,
+#         typer.Option('--strategy', '-s', help='How to spread traffic across models'),
+#     ],
+#     strategy_on_code: Annotated[
+#         List[int],
+#         typer.Option('--strategy-on-code', help='The http status to consider invoking the strategy'),
+#     ],
+#     retry_on_code: Annotated[
+#         List[int],
+#         typer.Option('--retry-on-code', help='The http status to consider retrying the llm call'),
+#     ],
+#     retry_attempts: Annotated[
+#         int,
+#         typer.Option('--retry-attempts', help='The number of attempts to retry'),
+#     ],
+#     display_name: Annotated[
+#         str,
+#         typer.Option('--display-name', help='What name should this llm appear as within the ui'),
+#     ] = None
+# ):
+#     model_policies_client: ModelPoliciesClient = instantiate_client(ModelPoliciesClient)
+#     model_client: ModelsClient = instantiate_client(ModelsClient)
+#     model_lut = {m.name: m.id for m in model_client.list()}
+#     for m in models:
+#         if m not in model_lut:
+#             logger.error(f"No model found with the name '{m}'")
+#             exit(1)
+
+#     inner = ModelPolicyInner()
+#     inner.strategy = ModelPolicyStrategy(
+#         mode=strategy,
+#         on_status_codes=strategy_on_code
+#     )
+#     inner.targets = [ModelPolicyTarget(model_id=model_lut[m], weight=1) for m in models]
+#     if retry_on_code:
+#         inner.retry = ModelPolicyRetry(
+#             on_status_codes=retry_on_code,
+#             attempts=retry_attempts
+#         )
+
+#     if not display_name:
+#         display_name = name
+
+
+#     policy = ModelPolicy(
+#         name=name,
+#         display_name=display_name,
+#         policy=inner
+#     )
+#     model_policies_client.create(policy)
+#     logger.info(f"Successfully added the model policy '{name}'")
+
+
+
+
+# @models_policy_app.command(name='remove', help='Remove a model policy')
+# def models_policy_remove(
+#     name: Annotated[
+#         str,
+#         typer.Option("--name", "-n", help="The name of the model policy to remove"),
+#     ]
+# ):
+#     model_policies_client: ModelPoliciesClient = instantiate_client(ModelPoliciesClient)
+#     model_policies = model_policies_client.list()

+#     policy = next(filter(lambda x: x.name == name or x.name == f"virtual-policy/{name}", model_policies), None)
+#     if not policy:
+#         logger.error(f"No model found with the name '{name}'")
+#         exit(1)

+#     model_policies_client.delete(model_policy_id=policy.id)
+#     logger.info(f"Successfully removed the model '{name}'")
+
+
 def _get_wxai_foundational_models():
     foundation_models_url = WATSONX_URL + "/ml/v1/foundation_model_specs?version=2024-05-01"

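The new `models add` command ties these pieces together: it maps the supplied env file to a `ProviderConfig`, prefixes the model name with `virtual-model/` when needed, supplies a default `max_tokens` for Anthropic models (which have no server-side default), and registers the result through `ModelsClient`. A condensed restatement of that flow as a script, assuming the imported clients behave as the hunks above suggest; the model name and env file path are illustrative only, and the env file is assumed to define at least `ANTHROPIC_API_KEY`:

from ibm_watsonx_orchestrate.cli.commands.models.env_file_model_provider_mapper import (
    env_file_to_model_ProviderConfig,
)
from ibm_watsonx_orchestrate.client.models.models_client import ModelsClient
from ibm_watsonx_orchestrate.client.models.types import CreateVirtualModel, ANTHROPIC_DEFAULT_MAX_TOKENS
from ibm_watsonx_orchestrate.client.utils import instantiate_client

# Illustrative values -- not taken from the package.
name = "anthropic/claude-3-5-sonnet"
env_file = "./anthropic.env"  # assumed to contain ANTHROPIC_API_KEY=...

# Map the env file onto a ProviderConfig for the "anthropic" provider prefix.
provider_config = env_file_to_model_ProviderConfig(model_name=name, env_file_path=env_file)

# Custom models are always namespaced under virtual-model/.
if not name.startswith("virtual-model/"):
    name = f"virtual-model/{name}"

# Anthropic has no default for max_tokens, so the command supplies one.
config = {"max_tokens": ANTHROPIC_DEFAULT_MAX_TOKENS} if "anthropic" in name else None

models_client: ModelsClient = instantiate_client(ModelsClient)
models_client.create(CreateVirtualModel(
    name=name,
    display_name=name,
    description="Illustrative registration of a custom-provider model",
    tags=[],
    provider_config=provider_config,
    config=config,
))

`models remove` reverses this: it looks the model up by name, with or without the `virtual-model/` prefix, and deletes it by id.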