ibm-watsonx-orchestrate 1.3.0__py3-none-any.whl → 1.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. ibm_watsonx_orchestrate/__init__.py +1 -1
  2. ibm_watsonx_orchestrate/agent_builder/agents/types.py +2 -0
  3. ibm_watsonx_orchestrate/agent_builder/knowledge_bases/types.py +9 -2
  4. ibm_watsonx_orchestrate/agent_builder/toolkits/base_toolkit.py +32 -0
  5. ibm_watsonx_orchestrate/agent_builder/toolkits/types.py +42 -0
  6. ibm_watsonx_orchestrate/agent_builder/tools/openapi_tool.py +10 -1
  7. ibm_watsonx_orchestrate/agent_builder/tools/python_tool.py +4 -2
  8. ibm_watsonx_orchestrate/agent_builder/tools/types.py +2 -1
  9. ibm_watsonx_orchestrate/cli/commands/agents/agents_command.py +29 -0
  10. ibm_watsonx_orchestrate/cli/commands/agents/agents_controller.py +271 -12
  11. ibm_watsonx_orchestrate/cli/commands/knowledge_bases/knowledge_bases_controller.py +17 -2
  12. ibm_watsonx_orchestrate/cli/commands/models/env_file_model_provider_mapper.py +180 -0
  13. ibm_watsonx_orchestrate/cli/commands/models/models_command.py +194 -8
  14. ibm_watsonx_orchestrate/cli/commands/server/server_command.py +117 -48
  15. ibm_watsonx_orchestrate/cli/commands/server/types.py +105 -0
  16. ibm_watsonx_orchestrate/cli/commands/toolkit/toolkit_command.py +55 -7
  17. ibm_watsonx_orchestrate/cli/commands/toolkit/toolkit_controller.py +123 -42
  18. ibm_watsonx_orchestrate/cli/commands/tools/tools_command.py +22 -1
  19. ibm_watsonx_orchestrate/cli/commands/tools/tools_controller.py +197 -12
  20. ibm_watsonx_orchestrate/client/agents/agent_client.py +4 -1
  21. ibm_watsonx_orchestrate/client/agents/assistant_agent_client.py +5 -1
  22. ibm_watsonx_orchestrate/client/agents/external_agent_client.py +5 -1
  23. ibm_watsonx_orchestrate/client/analytics/llm/analytics_llm_client.py +2 -6
  24. ibm_watsonx_orchestrate/client/base_api_client.py +5 -2
  25. ibm_watsonx_orchestrate/client/connections/connections_client.py +3 -9
  26. ibm_watsonx_orchestrate/client/model_policies/__init__.py +0 -0
  27. ibm_watsonx_orchestrate/client/model_policies/model_policies_client.py +47 -0
  28. ibm_watsonx_orchestrate/client/model_policies/types.py +36 -0
  29. ibm_watsonx_orchestrate/client/models/__init__.py +0 -0
  30. ibm_watsonx_orchestrate/client/models/models_client.py +46 -0
  31. ibm_watsonx_orchestrate/client/models/types.py +177 -0
  32. ibm_watsonx_orchestrate/client/toolkit/toolkit_client.py +15 -6
  33. ibm_watsonx_orchestrate/client/tools/tempus_client.py +40 -0
  34. ibm_watsonx_orchestrate/client/tools/tool_client.py +8 -0
  35. ibm_watsonx_orchestrate/docker/compose-lite.yml +68 -13
  36. ibm_watsonx_orchestrate/docker/default.env +22 -12
  37. ibm_watsonx_orchestrate/docker/tempus/common-config.yaml +1 -1
  38. ibm_watsonx_orchestrate/experimental/flow_builder/__init__.py +0 -0
  39. ibm_watsonx_orchestrate/experimental/flow_builder/flows/__init__.py +41 -0
  40. ibm_watsonx_orchestrate/experimental/flow_builder/flows/constants.py +17 -0
  41. ibm_watsonx_orchestrate/experimental/flow_builder/flows/data_map.py +91 -0
  42. ibm_watsonx_orchestrate/experimental/flow_builder/flows/decorators.py +143 -0
  43. ibm_watsonx_orchestrate/experimental/flow_builder/flows/events.py +72 -0
  44. ibm_watsonx_orchestrate/experimental/flow_builder/flows/flow.py +1288 -0
  45. ibm_watsonx_orchestrate/experimental/flow_builder/node.py +97 -0
  46. ibm_watsonx_orchestrate/experimental/flow_builder/resources/flow_status.openapi.yml +98 -0
  47. ibm_watsonx_orchestrate/experimental/flow_builder/types.py +492 -0
  48. ibm_watsonx_orchestrate/experimental/flow_builder/utils.py +113 -0
  49. ibm_watsonx_orchestrate/utils/utils.py +5 -2
  50. {ibm_watsonx_orchestrate-1.3.0.dist-info → ibm_watsonx_orchestrate-1.4.2.dist-info}/METADATA +4 -1
  51. {ibm_watsonx_orchestrate-1.3.0.dist-info → ibm_watsonx_orchestrate-1.4.2.dist-info}/RECORD +54 -32
  52. {ibm_watsonx_orchestrate-1.3.0.dist-info → ibm_watsonx_orchestrate-1.4.2.dist-info}/WHEEL +0 -0
  53. {ibm_watsonx_orchestrate-1.3.0.dist-info → ibm_watsonx_orchestrate-1.4.2.dist-info}/entry_points.txt +0 -0
  54. {ibm_watsonx_orchestrate-1.3.0.dist-info → ibm_watsonx_orchestrate-1.4.2.dist-info}/licenses/LICENSE +0 -0
@@ -49,6 +49,15 @@ def get_file_name(path: str):
49
49
  # return to_column_name(path.split("/")[-1].split(".")[0])
50
50
  return path.split("/")[-1]
51
51
 
52
def get_relative_file_path(path, dir):
    """Resolve a document path against the directory of the knowledge base file.

    Args:
        path: Document path as listed in the knowledge base definition.
        dir: Directory portion of the knowledge base file path, as computed by
            the callers with ``"/".join(file.split("/")[:-1])``. It is an empty
            string when the file was referenced by bare file name.

    Returns:
        ``path`` unchanged when it is absolute or when ``dir`` is empty,
        otherwise ``path`` joined onto ``dir`` (a leading ``./`` is dropped).
    """
    # Absolute paths are used verbatim.
    if path.startswith("/"):
        return path

    # Without a base directory, prefixing "/" would silently turn a relative
    # path into an absolute one (e.g. "doc.pdf" -> "/doc.pdf").
    if not dir:
        return path

    if path.startswith("./"):
        return f"{dir}{path.removeprefix('.')}"

    return f"{dir}/{path}"
59
+
60
+
52
61
  class KnowledgeBaseController:
53
62
  def __init__(self):
54
63
  self.client = None
@@ -68,8 +77,9 @@ class KnowledgeBaseController:
68
77
  kb.validate_documents_or_index_exists()
69
78
  if kb.documents:
70
79
  file_dir = "/".join(file.split("/")[:-1])
71
- files = [('files', (get_file_name(file_path), open(file_path if file_path.startswith("/") else (file_path if not file_dir else f"{file_dir}/{file_path}"), 'rb'))) for file_path in kb.documents]
80
+ files = [('files', (get_file_name(file_path), open(get_relative_file_path(file_path, file_dir), 'rb'))) for file_path in kb.documents]
72
81
 
82
+ kb.prioritize_built_in_index = True
73
83
  payload = kb.model_dump(exclude_none=True);
74
84
  payload.pop('documents');
75
85
 
@@ -78,6 +88,7 @@ class KnowledgeBaseController:
78
88
  if len(kb.conversational_search_tool.index_config) != 1:
79
89
  raise ValueError(f"Must provide exactly one conversational_search_tool.index_config. Provided {len(kb.conversational_search_tool.index_config)}.")
80
90
 
91
+
81
92
  if app_id:
82
93
  connections_client = get_connections_client()
83
94
  connection_id = None
@@ -90,6 +101,7 @@ class KnowledgeBaseController:
90
101
  connection_id = connections.connection_id
91
102
  kb.conversational_search_tool.index_config[0].connection_id = connection_id
92
103
 
104
+ kb.prioritize_built_in_index = False
93
105
  client.create(payload=kb.model_dump(exclude_none=True))
94
106
 
95
107
  logger.info(f"Successfully imported knowledge base '{kb.name}'")
@@ -126,13 +138,16 @@ class KnowledgeBaseController:
126
138
 
127
139
  if update_request.documents:
128
140
  file_dir = "/".join(file.split("/")[:-1])
129
- files = [('files', (get_file_name(file_path), open(file_path if file_path.startswith("/") else f"{file_dir}/{file_path}", 'rb'))) for file_path in update_request.documents]
141
+ files = [('files', (get_file_name(file_path), open(get_relative_file_path(file_path, file_dir), 'rb'))) for file_path in update_request.documents]
130
142
 
143
+ update_request.prioritize_built_in_index = True
131
144
  payload = update_request.model_dump(exclude_none=True);
132
145
  payload.pop('documents');
133
146
 
134
147
  self.get_client().update_with_documents(knowledge_base_id, payload=payload, files=files)
135
148
  else:
149
+ if update_request.conversational_search_tool and update_request.conversational_search_tool.index_config:
150
+ update_request.prioritize_built_in_index = False
136
151
  self.get_client().update(knowledge_base_id, update_request.model_dump(exclude_none=True))
137
152
 
138
153
  logEnding = f"with ID '{id}'" if id else f"'{name}'"
@@ -0,0 +1,180 @@
1
+ # https://portkey.ai/
2
+ from dotenv import dotenv_values
3
+
4
+ from ibm_watsonx_orchestrate.client.models.types import ProviderConfig, ModelProvider
5
+
6
+ def _bpp(name: str, provider_config=None) -> dict:
7
+ if provider_config is None:
8
+ provider_config = {}
9
+
10
+ name = name.upper()
11
+
12
+ provider_config.update({
13
+ f"{name}_API_KEY": 'api_key',
14
+ f"{name}_CUSTOM_HOST": 'custom_host',
15
+ f"{name}_URL_TO_FETCH": 'url_to_fetch',
16
+ f"{name}_FORWARD_HEADERS": lambda value, config: setattr(config, 'forward_headers', list(map(lambda v: v.strip(), value.split(';')))),
17
+ f"{name}_REQUEST_TIMEOUT": lambda value, config: setattr(config, 'request_timeout', int(value)),
18
+ f"{name}_TRANSFORM_TO_FORM_DATA": 'transform_to_form_data'
19
+ })
20
+
21
+ return provider_config
22
+
23
+ # model provider prefix => ENV_VAR key => provider config key
24
+ PROVIDER_PROPERTIES_LUT = {
25
+ ModelProvider.OPENAI: _bpp('OPENAI', {}),
26
+ # ModelProvider.A21: _bpp('A21', {}),
27
+ ModelProvider.ANTHROPIC: _bpp('ANTHROPIC', {
28
+ 'ANTHROPIC_BETA': 'anthropic_beta',
29
+ 'ANTHROPIC_VERSION': 'anthropic_version'
30
+ }),
31
+ # ModelProvider.ANYSCALE: _bpp('ANYSCALE', {}),
32
+ # ModelProvider.AZURE_OPENAI: _bpp('AZURE_OPENAI', {}),
33
+ # ModelProvider.AZURE_AI: _bpp('AZURE', {
34
+ # 'AZURE_AI_RESOURCE_NAME': 'azure_resource_name',
35
+ # 'AZURE_AI_DEPLOYMENT_ID': 'azure_deployment_id',
36
+ # 'AZURE_AI_API_VERSION': 'azure_api_version',
37
+ # 'AZURE_AI_AD_AUTH': 'ad_auth',
38
+ # 'AZURE_AI_AUTH_MODE': 'azure_auth_mode',
39
+ # 'AZURE_AI_MANAGED_CLIENT_ID': 'azure_managed_client_id',
40
+ # 'AZURE_AI_ENTRA_CLIENT_ID': 'azure_entra_client_id',
41
+ # 'AZURE_AI_ENTRA_CLIENT_SECRET': 'azure_entra_client_secret',
42
+ # 'AZURE_AI_ENTRA_TENANT_ID': 'azure_entra_tenant_id',
43
+ # 'AZURE_AI_AD_TOKEN': 'azure_ad_token',
44
+ # 'AZURE_AI_MODEL_NAME': 'azure_model_name',
45
+ # }),
46
+ # ModelProvider.BEDROCK: _bpp('BEDROCK', {
47
+ # 'AWS_SECRET_ACCESS_KEY': 'aws_secret_access_key',
48
+ # 'AWS_ACCESS_KEY_ID': 'aws_access_key_id',
49
+ # 'AWS_SESSION_TOKEN': 'aws_session_token',
50
+ # 'AWS_REGION': 'aws_region',
51
+ # 'AWS_AUTH_TYPE': 'aws_auth_type',
52
+ # 'AWS_ROLE_ARN': 'aws_role_arn',
53
+ # 'AWS_EXTERNAL_ID': 'aws_external_id',
54
+ # 'AWS_S3_BUCKET': 'aws_s3_bucket',
55
+ # 'AWS_S3_OBJECT_KEY': 'aws_s3_object_key',
56
+ # 'AWS_BEDROCK_MODEL': 'aws_bedrock_model',
57
+ # 'AWS_SERVER_SIDE_ENCRYPTION': 'aws_server_side_encryption',
58
+ # 'AWS_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID': 'aws_server_side_encryption_kms_key_id',
59
+ # }),
60
+ # ModelProvider.CEREBRAS: _bpp('CEREBRAS', {}),
61
+ # ModelProvider.COHERE: _bpp('COHERE', {}),
62
+ ModelProvider.GOOGLE: _bpp('GOOGLE', {}),
63
+ # ModelProvider.VERTEX_AI: _bpp('GOOGLE_VERTEX_AI', {
64
+ # 'GOOGLE_VERTEX_AI_REGION': 'vertex_region',
65
+ # 'GOOGLE_VERTEX_AI_PROJECT_ID': 'vertex_project_id',
66
+ # 'GOOGLE_VERTEX_AI_SERVICE_ACCOUNT_JSON': 'vertex_service_account_json',
67
+ # 'GOOGLE_VERTEX_AI_STORAGE_BUCKET_NAME': 'vertex_storage_bucket_name',
68
+ # 'GOOGLE_VERTEX_AI_MODEL_NAME': 'vertex_model_name',
69
+ # 'GOOGLE_VERTEX_AI_FILENAME': 'filename'
70
+ # }),
71
+ # ModelProvider.GROQ: _bpp('GROQ', {}),
72
+ # ModelProvider.HUGGINGFACE: _bpp('HUGGINGFACE', {
73
+ # 'HUGGINGFACE_BASE_URL': 'huggingfaceBaseUrl'
74
+ # }),
75
+ ModelProvider.MISTRAL_AI: _bpp('MISTRAL', {
76
+ 'MISTRAL_FIM_COMPLETION': 'mistral_fim_completion'
77
+ }),
78
+ # ModelProvider.JINA: _bpp('JINA', {}),
79
+ ModelProvider.OLLAMA: _bpp('OLLAMA', {}),
80
+ ModelProvider.OPENROUTER: _bpp('OPENROUTER', {}),
81
+ # ModelProvider.STABILITY_AI: _bpp('STABILITY', {
82
+ # 'STABILITY_CLIENT_ID': 'stability_client_id',
83
+ # 'STABILITY_CLIENT_USER_ID': 'stability_client_user_id',
84
+ # 'STABILITY_CLIENT_VERSION': 'stability_client_version'
85
+ # }),
86
+ # ModelProvider.TOGETHER_AI: _bpp('TOGETHER_AI', {}),
87
+ ModelProvider.WATSONX: _bpp('WATSONX', {
88
+ 'WATSONX_VERSION': 'watsonx_version',
89
+ 'WATSONX_SPACE_ID': 'watsonx_space_id',
90
+ 'WATSONX_PROJECT_ID': 'watsonx_project_id',
91
+ 'WATSONX_APIKEY': 'api_key'
92
+ })
93
+
94
+ # 'palm': _bpp('PALM', {}),
95
+ # 'nomic': _bpp('NOMIC', {}),
96
+ # 'perplexity-ai': _bpp('PERPLEXITY_AI', {}),
97
+ # 'segmind': _bpp('SEGMIND', {}),
98
+ # 'deepinfra': _bpp('DEEPINFRA', {}),
99
+ # 'novita-ai': _bpp('NOVITA_AI', {}),
100
+ # 'fireworks-ai': _bpp('FIREWORKS',{
101
+ # 'FIREWORKS_ACCOUNT_ID': 'fireworks_account_id'
102
+ # }),
103
+ # 'deepseek': _bpp('DEEPSEEK', {}),
104
+ # 'voyage': _bpp('VOYAGE', {}),
105
+ # 'moonshot': _bpp('MOONSHOT', {}),
106
+ # 'lingyi': _bpp('LINGYI', {}),
107
+ # 'zhipu': _bpp('ZHIPU', {}),
108
+ # 'monsterapi': _bpp('MONSTERAPI', {}),
109
+ # 'predibase': _bpp('PREDIBASE', {}),
110
+
111
+ # 'github': _bpp('GITHUB', {}),
112
+ # 'deepbricks': _bpp('DEEPBRICKS', {}),
113
+ # 'sagemaker': _bpp('AMZN_SAGEMAKER', {
114
+ # 'AMZN_SAGEMAKER_CUSTOM_ATTRIBUTES': 'amzn_sagemaker_custom_attributes',
115
+ # 'AMZN_SAGEMAKER_TARGET_MODEL': 'amzn_sagemaker_target_model',
116
+ # 'AMZN_SAGEMAKER_TARGET_VARIANT': 'amzn_sagemaker_target_variant',
117
+ # 'AMZN_SAGEMAKER_TARGET_CONTAINER_HOSTNAME': 'amzn_sagemaker_target_container_hostname',
118
+ # 'AMZN_SAGEMAKER_INFERENCE_ID': 'amzn_sagemaker_inference_id',
119
+ # 'AMZN_SAGEMAKER_ENABLE_EXPLANATIONS': 'amzn_sagemaker_enable_explanations',
120
+ # 'AMZN_SAGEMAKER_INFERENCE_COMPONENT': 'amzn_sagemaker_inference_component',
121
+ # 'AMZN_SAGEMAKER_SESSION_ID': 'amzn_sagemaker_session_id',
122
+ # 'AMZN_SAGEMAKER_MODEL_NAME': 'amzn_sagemaker_model_name'
123
+ # }),
124
+ # '@cf': _bpp('WORKERS_AI', { # workers ai
125
+ # 'WORKERS_AI_ACCOUNT_ID': 'workers_ai_account_id'
126
+ # }),
127
+ # 'snowflake': _bpp('SNOWFLAKE', { # no provider prefix found
128
+ # 'SNOWFLAKE_ACCOUNT': 'snowflake_account'
129
+ # })
130
+ }
131
# Reverse lookup: model provider => provider config key => ENV_VAR key.
PROVIDER_PROPERTIES_RLUT = {}
for provider in PROVIDER_PROPERTIES_LUT:
    PROVIDER_PROPERTIES_RLUT[provider] = {
        config_key: env_var
        for env_var, config_key in PROVIDER_PROPERTIES_LUT[provider].items()
    }

# Provider prefix => provider name; identity unless overridden below.
PROVIDER_LUT = dict(zip(PROVIDER_PROPERTIES_LUT, PROVIDER_PROPERTIES_LUT))
PROVIDER_LUT.update({
    # any overrides for the provider prefix to provider name can be provided here on PROVIDER_LUT
})

# Every provider requires 'api_key' except 'ollama', which requires nothing.
PROVIDER_REQUIRED_FIELDS = {
    k: ([] if k in ['ollama'] else ['api_key'])
    for k in PROVIDER_PROPERTIES_LUT
}
PROVIDER_REQUIRED_FIELDS.update({
    # Mark the required fields for a provider
})
146
+
147
def env_file_to_model_ProviderConfig(model_name: str, env_file_path: str) -> ProviderConfig | None:
    """Build a ProviderConfig for a model from the credentials in an env file.

    Args:
        model_name: Model name of the form ``[virtual-model/|virtual-policy/]<provider>/...``.
            The first path segment that is not one of the ``virtual-*``
            prefixes is taken as the provider.
        env_file_path: Path to the .env file holding the provider credentials.

    Returns:
        A ProviderConfig with ``provider`` set and every recognised
        environment variable applied to it.

    Raises:
        ValueError: If the provider is not in PROVIDER_LUT, or if an
            environment variable required for the provider is missing.
    """
    # Default to None so a name with no provider segment is reported as an
    # unsupported provider instead of leaking StopIteration.
    provider = next(
        (segment for segment in model_name.split('/') if segment not in ('virtual-policy', 'virtual-model')),
        None,
    )
    if provider not in PROVIDER_LUT:
        raise ValueError(f"Unsupported model provider {provider}")

    values = dotenv_values(str(env_file_path))

    if values is None:
        raise ValueError(f"No provider configuration in env file {env_file_path}")

    cfg = ProviderConfig()
    cfg.provider = PROVIDER_LUT[provider]

    cred_lut = PROVIDER_PROPERTIES_LUT[provider]

    consumed_credentials = set()
    for key, value in values.items():
        if key not in cred_lut:
            continue
        # dotenv_values may yield None for a key declared without a value;
        # such an entry must not count as a supplied credential.
        if value is None:
            continue

        target = cred_lut[key]
        if callable(target):
            # Converter entries (forward headers, request timeout) parse the
            # raw string and set the attribute on the config themselves.
            target(value, cfg)
        else:
            setattr(cfg, target, value)
            consumed_credentials.add(target)

    required_creds = PROVIDER_REQUIRED_FIELDS[provider]
    missing_credentials = [cred for cred in required_creds if cred not in consumed_credentials]

    if len(missing_credentials) > 0:
        raise ValueError(f"Missing environment variable(s) {', '.join(map(lambda c: PROVIDER_PROPERTIES_RLUT[provider][c], missing_credentials))} required for the provider {provider}")

    return cfg
@@ -1,30 +1,42 @@
1
1
  import logging
2
2
  import os
3
- import requests
4
3
  import sys
5
- import rich.highlighter
6
- import typer
4
+ from typing import List
5
+
6
+ import requests
7
7
  import rich
8
+ import rich.highlighter
8
9
  import typer
9
10
  from typing_extensions import Annotated
11
+
10
12
  from ibm_watsonx_orchestrate.cli.commands.server.server_command import get_default_env_file, merge_env
13
+ from ibm_watsonx_orchestrate.client.model_policies.model_policies_client import ModelPoliciesClient
14
+ from ibm_watsonx_orchestrate.client.model_policies.types import ModelPolicy, ModelPolicyInner, \
15
+ ModelPolicyRetry, ModelPolicyStrategy, ModelPolicyStrategyMode, ModelPolicyTarget
16
+ from ibm_watsonx_orchestrate.client.models.models_client import ModelsClient
17
+ from ibm_watsonx_orchestrate.client.models.types import CreateVirtualModel, ANTHROPIC_DEFAULT_MAX_TOKENS
18
+ from ibm_watsonx_orchestrate.client.utils import instantiate_client
11
19
 
12
20
  logger = logging.getLogger(__name__)
13
21
  models_app = typer.Typer(no_args_is_help=True)
22
+ models_policy_app = typer.Typer(no_args_is_help=True)
23
+ # models_app.add_typer(models_policy_app, name='policy', help='Add or remove pseudo models which route traffic between multiple downstream models')
14
24
 
15
25
  WATSONX_URL = os.getenv("WATSONX_URL")
16
26
 
17
27
  class ModelHighlighter(rich.highlighter.RegexHighlighter):
18
28
  base_style = "model."
19
- highlights = [r"(?P<name>watsonx\/.+\/.+):"]
29
+ highlights = [r"(?P<name>(watsonx|virtual[-]model|virtual[-]policy)\/.+\/.+):"]
20
30
 
21
- @models_app.command(name="list")
31
+ @models_app.command(name="list", help="List available models")
22
32
  def model_list(
23
33
  print_raw: Annotated[
24
34
  bool,
25
35
  typer.Option("--raw", "-r", help="Display the list of models in a non-tabular format"),
26
36
  ] = False,
27
37
  ):
38
+ models_client: ModelsClient = instantiate_client(ModelsClient)
39
+ model_policies_client: ModelPoliciesClient = instantiate_client(ModelPoliciesClient)
28
40
  global WATSONX_URL
29
41
  default_env_path = get_default_env_file()
30
42
  merged_env_dict = merge_env(
@@ -40,9 +52,18 @@ def model_list(
40
52
  logger.error("Error: WATSONX_URL is required in the environment.")
41
53
  sys.exit(1)
42
54
 
55
+ logger.info("Retrieving virtual-model models list...")
56
+ virtual_models = models_client.list()
57
+
58
+
59
+
60
+ logger.info("Retrieving virtual-policies models list...")
61
+ virtual_model_policies = model_policies_client.list()
62
+
43
63
  logger.info("Retrieving watsonx.ai models list...")
44
64
  found_models = _get_wxai_foundational_models()
45
65
 
66
+
46
67
  preferred_str = merged_env_dict.get('PREFERRED_MODELS', '')
47
68
  incompatible_str = merged_env_dict.get('INCOMPATIBLE_MODELS', '')
48
69
 
@@ -64,7 +85,7 @@ def model_list(
64
85
  continue
65
86
  filtered_models.append(model)
66
87
 
67
- # Sort to put preferred first
88
+ # Sort to put the preferred first
68
89
  def sort_key(model):
69
90
  model_id = model.get("model_id", "").lower()
70
91
  is_preferred = any(pref in model_id for pref in preferred_list)
@@ -76,6 +97,13 @@ def model_list(
76
97
  theme = rich.theme.Theme({"model.name": "bold cyan"})
77
98
  console = rich.console.Console(highlighter=ModelHighlighter(), theme=theme)
78
99
  console.print("[bold]Available Models:[/bold]")
100
+
101
+ for model in virtual_models:
102
+ console.print(f"- ✨️ {model.name}:", model.description or 'No description provided.')
103
+
104
+ for model in virtual_model_policies:
105
+ console.print(f"- ✨️ {model.name}:", 'No description provided.')
106
+
79
107
  for model in sorted_models:
80
108
  model_id = model.get("model_id", "N/A")
81
109
  short_desc = model.get("short_description", "No description provided.")
@@ -83,17 +111,23 @@ def model_list(
83
111
  marker = "★ " if any(pref in model_id.lower() for pref in preferred_list) else ""
84
112
  console.print(f"- [yellow]{marker}[/yellow]{full_model_name}")
85
113
 
86
- console.print("[yellow]★[/yellow] [italic dim]indicates a supported and preferred model[/italic dim]" )
114
+ console.print("[yellow]★[/yellow] [italic dim]indicates a supported and preferred model[/italic dim]\n[blue dim]✨️[/blue dim] [italic dim]indicates a model from a custom provider[/italic dim]" )
87
115
  else:
88
116
  table = rich.table.Table(
89
117
  show_header=True,
90
118
  title="[bold]Available Models[/bold]",
91
- caption="[yellow]★[/yellow] indicates a supported and preferred model",
119
+ caption="[yellow]★ [/yellow] indicates a supported and preferred model from watsonx\n[blue]✨️[/blue] indicates a model from a custom provider",
92
120
  show_lines=True)
93
121
  columns = ["Model", "Description"]
94
122
  for col in columns:
95
123
  table.add_column(col)
96
124
 
125
+ for model in virtual_models:
126
+ table.add_row(f"✨️ {model.name}", model.description or 'No description provided.')
127
+
128
+ for model in virtual_model_policies:
129
+ table.add_row(f"✨️ {model.name}", 'No description provided.')
130
+
97
131
  for model in sorted_models:
98
132
  model_id = model.get("model_id", "N/A")
99
133
  short_desc = model.get("short_description", "No description provided.")
@@ -102,6 +136,158 @@ def model_list(
102
136
 
103
137
  rich.print(table)
104
138
 
139
+
140
+
141
# CLI command: register a model from a custom provider as a virtual model.
# The provider configuration is read from the given env file using the name as
# typed; the name is then prefixed with "virtual-model/" if it is not already.
@models_app.command(name="add", help="Add an llm from a custom provider")
def models_add(
    name: Annotated[
        str,
        typer.Option("--name", "-n", help="The name of the model to add"),
    ],
    env_file: Annotated[
        str,
        typer.Option('--env-file', '-e', help='The path to an .env file containing the credentials for your llm provider'),
    ],
    description: Annotated[
        str,
        typer.Option('--description', '-d', help='The description of the model to add'),
    ] = None,
    display_name: Annotated[
        str,
        typer.Option('--display-name', help='What name should this llm appear as within the ui'),
    ] = None,
):
    from ibm_watsonx_orchestrate.cli.commands.models.env_file_model_provider_mapper import env_file_to_model_ProviderConfig  # lazily import this because the lut building is expensive

    models_client: ModelsClient = instantiate_client(ModelsClient)
    provider_config = env_file_to_model_ProviderConfig(model_name=name, env_file_path=env_file)

    # Normalise to the fully qualified virtual model name.
    qualified_name = name if name.startswith('virtual-model/') else f"virtual-model/{name}"

    # Anthropic has no default for max_tokens
    config = {"max_tokens": ANTHROPIC_DEFAULT_MAX_TOKENS} if "anthropic" in qualified_name else None

    model = CreateVirtualModel(
        name=qualified_name,
        display_name=display_name or qualified_name,
        description=description,
        tags=[],
        provider_config=provider_config,
        config=config,
    )

    models_client.create(model)
    logger.info(f"Successfully added the model '{qualified_name}'")
186
+
187
+
188
+
189
# CLI command: delete a custom-provider model, matched either by the exact
# name given or by that name prefixed with "virtual-model/".
@models_app.command(name="remove", help="Remove an llm from a custom provider")
def models_remove(
    name: Annotated[
        str,
        typer.Option("--name", "-n", help="The name of the model to remove"),
    ]
):
    models_client: ModelsClient = instantiate_client(ModelsClient)

    # First listed model whose name matches either accepted spelling.
    accepted_names = (name, f"virtual-model/{name}")
    model = next(
        (candidate for candidate in models_client.list() if candidate.name in accepted_names),
        None,
    )
    if not model:
        logger.error(f"No model found with the name '{name}'")
        sys.exit(1)

    models_client.delete(model_id=model.id)
    logger.info(f"Successfully removed the model '{name}'")
205
+
206
+
207
+ # @models_policy_app.command(name='add', help='Add a model policy')
208
+ # def models_policy_add(
209
+ # name: Annotated[
210
+ # str,
211
+ # typer.Option("--name", "-n", help="The name of the model to remove"),
212
+ # ],
213
+ # models: Annotated[
214
+ # List[str],
215
+ # typer.Option('--model', '-m', help='The name of the model to add'),
216
+ # ],
217
+ # strategy: Annotated[
218
+ # ModelPolicyStrategyMode,
219
+ # typer.Option('--strategy', '-s', help='How to spread traffic across models'),
220
+ # ],
221
+ # strategy_on_code: Annotated[
222
+ # List[int],
223
+ # typer.Option('--strategy-on-code', help='The http status to consider invoking the strategy'),
224
+ # ],
225
+ # retry_on_code: Annotated[
226
+ # List[int],
227
+ # typer.Option('--retry-on-code', help='The http status to consider retrying the llm call'),
228
+ # ],
229
+ # retry_attempts: Annotated[
230
+ # int,
231
+ # typer.Option('--retry-attempts', help='The number of attempts to retry'),
232
+ # ],
233
+ # display_name: Annotated[
234
+ # str,
235
+ # typer.Option('--display-name', help='What name should this llm appear as within the ui'),
236
+ # ] = None
237
+ # ):
238
+ # model_policies_client: ModelPoliciesClient = instantiate_client(ModelPoliciesClient)
239
+ # model_client: ModelsClient = instantiate_client(ModelsClient)
240
+ # model_lut = {m.name: m.id for m in model_client.list()}
241
+ # for m in models:
242
+ # if m not in model_lut:
243
+ # logger.error(f"No model found with the name '{m}'")
244
+ # exit(1)
245
+
246
+ # inner = ModelPolicyInner()
247
+ # inner.strategy = ModelPolicyStrategy(
248
+ # mode=strategy,
249
+ # on_status_codes=strategy_on_code
250
+ # )
251
+ # inner.targets = [ModelPolicyTarget(model_id=model_lut[m], weight=1) for m in models]
252
+ # if retry_on_code:
253
+ # inner.retry = ModelPolicyRetry(
254
+ # on_status_codes=retry_on_code,
255
+ # attempts=retry_attempts
256
+ # )
257
+
258
+ # if not display_name:
259
+ # display_name = name
260
+
261
+
262
+ # policy = ModelPolicy(
263
+ # name=name,
264
+ # display_name=display_name,
265
+ # policy=inner
266
+ # )
267
+ # model_policies_client.create(policy)
268
+ # logger.info(f"Successfully added the model policy '{name}'")
269
+
270
+
271
+
272
+ # @models_policy_app.command(name='remove', help='Remove a model policy')
273
+ # def models_policy_remove(
274
+ # name: Annotated[
275
+ # str,
276
+ # typer.Option("--name", "-n", help="The name of the model policy to remove"),
277
+ # ]
278
+ # ):
279
+ # model_policies_client: ModelPoliciesClient = instantiate_client(ModelPoliciesClient)
280
+ # model_policies = model_policies_client.list()
281
+
282
+ # policy = next(filter(lambda x: x.name == name or x.name == f"virtual-policy/{name}", model_policies), None)
283
+ # if not policy:
284
+ # logger.error(f"No model found with the name '{name}'")
285
+ # exit(1)
286
+
287
+ # model_policies_client.delete(model_policy_id=policy.id)
288
+ # logger.info(f"Successfully removed the model '{name}'")
289
+
290
+
105
291
  def _get_wxai_foundational_models():
106
292
  foundation_models_url = WATSONX_URL + "/ml/v1/foundation_model_specs?version=2024-05-01"
107
293