ibm-watsonx-orchestrate 1.4.2__py3-none-any.whl → 1.5.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. ibm_watsonx_orchestrate/__init__.py +1 -1
  2. ibm_watsonx_orchestrate/agent_builder/agents/types.py +10 -1
  3. ibm_watsonx_orchestrate/agent_builder/knowledge_bases/types.py +14 -0
  4. ibm_watsonx_orchestrate/agent_builder/model_policies/__init__.py +1 -0
  5. ibm_watsonx_orchestrate/{client → agent_builder}/model_policies/types.py +7 -8
  6. ibm_watsonx_orchestrate/agent_builder/models/__init__.py +1 -0
  7. ibm_watsonx_orchestrate/{client → agent_builder}/models/types.py +57 -9
  8. ibm_watsonx_orchestrate/agent_builder/tools/python_tool.py +46 -3
  9. ibm_watsonx_orchestrate/agent_builder/tools/types.py +47 -1
  10. ibm_watsonx_orchestrate/cli/commands/agents/agents_command.py +17 -0
  11. ibm_watsonx_orchestrate/cli/commands/agents/agents_controller.py +86 -39
  12. ibm_watsonx_orchestrate/cli/commands/models/model_provider_mapper.py +191 -0
  13. ibm_watsonx_orchestrate/cli/commands/models/models_command.py +140 -258
  14. ibm_watsonx_orchestrate/cli/commands/models/models_controller.py +437 -0
  15. ibm_watsonx_orchestrate/cli/commands/server/server_command.py +2 -1
  16. ibm_watsonx_orchestrate/cli/commands/tools/tools_controller.py +1 -1
  17. ibm_watsonx_orchestrate/client/connections/__init__.py +2 -1
  18. ibm_watsonx_orchestrate/client/connections/utils.py +30 -0
  19. ibm_watsonx_orchestrate/client/model_policies/model_policies_client.py +23 -4
  20. ibm_watsonx_orchestrate/client/models/models_client.py +23 -3
  21. ibm_watsonx_orchestrate/client/toolkit/toolkit_client.py +13 -8
  22. ibm_watsonx_orchestrate/client/tools/tool_client.py +2 -1
  23. ibm_watsonx_orchestrate/docker/compose-lite.yml +2 -0
  24. ibm_watsonx_orchestrate/docker/default.env +10 -11
  25. ibm_watsonx_orchestrate/experimental/flow_builder/data_map.py +19 -0
  26. ibm_watsonx_orchestrate/experimental/flow_builder/flows/__init__.py +4 -3
  27. ibm_watsonx_orchestrate/experimental/flow_builder/flows/constants.py +3 -1
  28. ibm_watsonx_orchestrate/experimental/flow_builder/flows/decorators.py +3 -2
  29. ibm_watsonx_orchestrate/experimental/flow_builder/flows/flow.py +245 -223
  30. ibm_watsonx_orchestrate/experimental/flow_builder/node.py +34 -15
  31. ibm_watsonx_orchestrate/experimental/flow_builder/resources/flow_status.openapi.yml +7 -39
  32. ibm_watsonx_orchestrate/experimental/flow_builder/types.py +285 -12
  33. ibm_watsonx_orchestrate/experimental/flow_builder/utils.py +3 -1
  34. {ibm_watsonx_orchestrate-1.4.2.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/METADATA +1 -1
  35. {ibm_watsonx_orchestrate-1.4.2.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/RECORD +38 -35
  36. ibm_watsonx_orchestrate/cli/commands/models/env_file_model_provider_mapper.py +0 -180
  37. ibm_watsonx_orchestrate/experimental/flow_builder/flows/data_map.py +0 -91
  38. {ibm_watsonx_orchestrate-1.4.2.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/WHEEL +0 -0
  39. {ibm_watsonx_orchestrate-1.4.2.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/entry_points.txt +0 -0
  40. {ibm_watsonx_orchestrate-1.4.2.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/licenses/LICENSE +0 -0
@@ -11,9 +11,12 @@ import logging
11
11
  from pathlib import Path
12
12
  from copy import deepcopy
13
13
 
14
- from typing import Iterable, List
14
+ from typing import Iterable, List, TypeVar
15
+ from ibm_watsonx_orchestrate.agent_builder.agents.types import AgentStyle
16
+ from ibm_watsonx_orchestrate.agent_builder.tools.types import ToolSpec
15
17
  from ibm_watsonx_orchestrate.cli.commands.tools.tools_controller import import_python_tool, ToolsController
16
18
  from ibm_watsonx_orchestrate.cli.commands.knowledge_bases.knowledge_bases_controller import import_python_knowledge_base
19
+ from ibm_watsonx_orchestrate.cli.commands.models.models_controller import import_python_model
17
20
 
18
21
  from ibm_watsonx_orchestrate.agent_builder.agents import (
19
22
  Agent,
@@ -34,10 +37,14 @@ from ibm_watsonx_orchestrate.utils.utils import check_file_in_zip
34
37
 
35
38
  logger = logging.getLogger(__name__)
36
39
 
40
+ # Helper generic type for any agent
41
+ AnyAgentT = TypeVar("AnyAgentT", bound=Agent | ExternalAgent | AssistantAgent)
42
+
37
43
  def import_python_agent(file: str) -> List[Agent | ExternalAgent | AssistantAgent]:
38
44
  # Import tools
39
45
  import_python_tool(file)
40
46
  import_python_knowledge_base(file)
47
+ import_python_model(file)
41
48
 
42
49
  file_path = Path(file)
43
50
  file_directory = file_path.parent
@@ -90,6 +97,8 @@ def parse_create_native_args(name: str, kind: AgentKind, description: str | None
90
97
  "description": description,
91
98
  "llm": args.get("llm"),
92
99
  "style": args.get("style"),
100
+ "custom_join_tool": args.get("custom_join_tool"),
101
+ "structured_output": args.get("structured_output"),
93
102
  }
94
103
 
95
104
  collaborators = args.get("collaborators", [])
@@ -202,7 +211,7 @@ class AgentsController:
202
211
  return self.knowledge_base_client
203
212
 
204
213
  @staticmethod
205
- def import_agent(file: str, app_id: str) -> Iterable:
214
+ def import_agent(file: str, app_id: str) -> List[Agent | ExternalAgent | AssistantAgent]:
206
215
  agents = parse_file(file)
207
216
  for agent in agents:
208
217
  if app_id and agent.kind != AgentKind.NATIVE and agent.kind != AgentKind.ASSISTANT:
@@ -216,7 +225,9 @@ class AgentsController:
216
225
  ) -> Agent | ExternalAgent | AssistantAgent:
217
226
  match kind:
218
227
  case AgentKind.NATIVE:
219
- agent_details = parse_create_native_args(name, kind=kind, description=description, **kwargs)
228
+ agent_details = parse_create_native_args(
229
+ name, kind=kind, description=description, **kwargs
230
+ )
220
231
  agent = Agent.model_validate(agent_details)
221
232
  AgentsController().persist_record(agent=agent, **kwargs)
222
233
  case AgentKind.EXTERNAL:
@@ -296,12 +307,17 @@ class AgentsController:
296
307
  ref_agent.collaborators = ref_collaborators
297
308
 
298
309
  return ref_agent
299
-
310
+
300
311
  def dereference_tools(self, agent: Agent) -> Agent:
301
312
  tool_client = self.get_tool_client()
302
313
 
303
314
  deref_agent = deepcopy(agent)
304
- matching_tools = tool_client.get_drafts_by_names(deref_agent.tools)
315
+
316
+ # If agent has style set to "planner" and have join_tool defined, then we need to include that tool as well
317
+ if agent.style == AgentStyle.PLANNER and agent.custom_join_tool:
318
+ matching_tools = tool_client.get_drafts_by_names(deref_agent.tools + [deref_agent.custom_join_tool])
319
+ else:
320
+ matching_tools = tool_client.get_drafts_by_names(deref_agent.tools)
305
321
 
306
322
  name_id_lookup = {}
307
323
  for tool in matching_tools:
@@ -318,6 +334,13 @@ class AgentsController:
318
334
  sys.exit(1)
319
335
  deref_tools.append(id)
320
336
  deref_agent.tools = deref_tools
337
+
338
+ if agent.style == AgentStyle.PLANNER and agent.custom_join_tool:
339
+ join_tool_id = name_id_lookup.get(agent.custom_join_tool)
340
+ if not join_tool_id:
341
+ logger.error(f"Failed to find custom join tool. No tools found with the name '{agent.custom_join_tool}'")
342
+ sys.exit(1)
343
+ deref_agent.custom_join_tool = join_tool_id
321
344
 
322
345
  return deref_agent
323
346
 
@@ -325,7 +348,12 @@ class AgentsController:
325
348
  tool_client = self.get_tool_client()
326
349
 
327
350
  ref_agent = deepcopy(agent)
328
- matching_tools = tool_client.get_drafts_by_ids(ref_agent.tools)
351
+
352
+ # If agent has style set to "planner" and have join_tool defined, then we need to include that tool as well
353
+ if agent.style == AgentStyle.PLANNER and agent.custom_join_tool:
354
+ matching_tools = tool_client.get_drafts_by_ids(ref_agent.tools + [ref_agent.custom_join_tool])
355
+ else:
356
+ matching_tools = tool_client.get_drafts_by_ids(ref_agent.tools)
329
357
 
330
358
  id_name_lookup = {}
331
359
  for tool in matching_tools:
@@ -342,6 +370,13 @@ class AgentsController:
342
370
  sys.exit(1)
343
371
  ref_tools.append(name)
344
372
  ref_agent.tools = ref_tools
373
+
374
+ if agent.style == AgentStyle.PLANNER and agent.custom_join_tool:
375
+ join_tool_name = id_name_lookup.get(agent.custom_join_tool)
376
+ if not join_tool_name:
377
+ logger.error(f"Failed to find custom join tool. No tools found with the id '{agent.custom_join_tool}'")
378
+ sys.exit(1)
379
+ ref_agent.custom_join_tool = join_tool_name
345
380
 
346
381
  return ref_agent
347
382
 
@@ -409,7 +444,7 @@ class AgentsController:
409
444
  def dereference_native_agent_dependencies(self, agent: Agent) -> Agent:
410
445
  if agent.collaborators and len(agent.collaborators):
411
446
  agent = self.dereference_collaborators(agent)
412
- if agent.tools and len(agent.tools):
447
+ if (agent.tools and len(agent.tools)) or (agent.style == AgentStyle.PLANNER and agent.custom_join_tool):
413
448
  agent = self.dereference_tools(agent)
414
449
  if agent.knowledge_base and len(agent.knowledge_base):
415
450
  agent = self.dereference_knowledge_bases(agent)
@@ -419,7 +454,7 @@ class AgentsController:
419
454
  def reference_native_agent_dependencies(self, agent: Agent) -> Agent:
420
455
  if agent.collaborators and len(agent.collaborators):
421
456
  agent = self.reference_collaborators(agent)
422
- if agent.tools and len(agent.tools):
457
+ if (agent.tools and len(agent.tools)) or (agent.style == AgentStyle.PLANNER and agent.custom_join_tool):
423
458
  agent = self.reference_tools(agent)
424
459
  if agent.knowledge_base and len(agent.knowledge_base):
425
460
  agent = self.reference_knowledge_bases(agent)
@@ -443,21 +478,21 @@ class AgentsController:
443
478
  return agent
444
479
 
445
480
  # Convert all names used in an agent to the corresponding ids
446
- def dereference_agent_dependencies(self, agent: Agent | ExternalAgent | AssistantAgent ) -> Agent | ExternalAgent | AssistantAgent:
481
+ def dereference_agent_dependencies(self, agent: AnyAgentT) -> AnyAgentT:
447
482
  if isinstance(agent, Agent):
448
483
  return self.dereference_native_agent_dependencies(agent)
449
484
  if isinstance(agent, ExternalAgent) or isinstance(agent, AssistantAgent):
450
485
  return self.dereference_external_or_assistant_agent_dependencies(agent)
451
-
486
+
452
487
  # Convert all ids used in an agent to the corresponding names
453
- def reference_agent_dependencies(self, agent: Agent | ExternalAgent | AssistantAgent ) -> Agent | ExternalAgent | AssistantAgent:
488
+ def reference_agent_dependencies(self, agent: AnyAgentT) -> AnyAgentT:
454
489
  if isinstance(agent, Agent):
455
490
  return self.reference_native_agent_dependencies(agent)
456
491
  if isinstance(agent, ExternalAgent) or isinstance(agent, AssistantAgent):
457
492
  return self.reference_external_or_assistant_agent_dependencies(agent)
458
493
 
459
494
  def publish_or_update_agents(
460
- self, agents: Iterable[Agent]
495
+ self, agents: Iterable[Agent | ExternalAgent | AssistantAgent]
461
496
  ):
462
497
  for agent in agents:
463
498
  agent_name = agent.name
@@ -476,6 +511,18 @@ class AgentsController:
476
511
  all_existing_agents = existing_external_clients + existing_native_agents + existing_assistant_clients
477
512
  agent = self.dereference_agent_dependencies(agent)
478
513
 
514
+ if isinstance(agent, Agent) and agent.style == AgentStyle.PLANNER and isinstance(agent.custom_join_tool, str):
515
+ tool_client = self.get_tool_client()
516
+
517
+ join_tool_spec = ToolSpec.model_validate(
518
+ tool_client.get_draft_by_id(agent.custom_join_tool)
519
+ )
520
+ if not join_tool_spec.is_custom_join_tool():
521
+ logger.error(
522
+ f"Tool '{join_tool_spec.name}' configured as the custom join tool is not a valid join tool. A custom join tool must be a Python tool with specific input and output schema."
523
+ )
524
+ sys.exit(1)
525
+
479
526
  agent_kind = agent.kind
480
527
 
481
528
  if len(all_existing_agents) > 1:
@@ -762,32 +809,32 @@ class AgentsController:
762
809
  rich.print(assistants_table)
763
810
 
764
811
  def remove_agent(self, name: str, kind: AgentKind):
765
- try:
766
- if kind == AgentKind.NATIVE:
767
- client = self.get_native_client()
768
- elif kind == AgentKind.EXTERNAL:
769
- client = self.get_external_client()
770
- elif kind == AgentKind.ASSISTANT:
771
- client = self.get_assistant_client()
772
- else:
773
- raise ValueError("'kind' must be 'native'")
774
-
775
- draft_agents = client.get_draft_by_name(name)
776
- if len(draft_agents) > 1:
777
- logger.error(f"Multiple '{kind}' agents found with name '{name}'. Failed to delete agent")
778
- sys.exit(1)
779
- if len(draft_agents) > 0:
780
- draft_agent = draft_agents[0]
781
- agent_id = draft_agent.get("id")
782
- client.delete(agent_id=agent_id)
783
-
784
- logger.info(f"Successfully removed agent {name}")
785
- else:
786
- logger.warning(f"No agent named '{name}' found")
787
- except requests.HTTPError as e:
788
- logger.error(e.response.text)
789
- exit(1)
790
-
812
+ try:
813
+ if kind == AgentKind.NATIVE:
814
+ client = self.get_native_client()
815
+ elif kind == AgentKind.EXTERNAL:
816
+ client = self.get_external_client()
817
+ elif kind == AgentKind.ASSISTANT:
818
+ client = self.get_assistant_client()
819
+ else:
820
+ raise ValueError("'kind' must be 'native'")
821
+
822
+ draft_agents = client.get_draft_by_name(name)
823
+ if len(draft_agents) > 1:
824
+ logger.error(f"Multiple '{kind}' agents found with name '{name}'. Failed to delete agent")
825
+ sys.exit(1)
826
+ if len(draft_agents) > 0:
827
+ draft_agent = draft_agents[0]
828
+ agent_id = draft_agent.get("id")
829
+ client.delete(agent_id=agent_id)
830
+
831
+ logger.info(f"Successfully removed agent {name}")
832
+ else:
833
+ logger.warning(f"No agent named '{name}' found")
834
+ except requests.HTTPError as e:
835
+ logger.error(e.response.text)
836
+ exit(1)
837
+
791
838
  def get_spec_file_content(self, agent: Agent | ExternalAgent | AssistantAgent):
792
839
  ref_agent = self.reference_agent_dependencies(agent)
793
840
  agent_spec = ref_agent.model_dump(mode='json', exclude_none=True)
@@ -810,7 +857,7 @@ class AgentsController:
810
857
 
811
858
  return agent
812
859
 
813
- def get_agent_by_id(self, id: str) -> Agent | ExternalAgent | AssistantAgent:
860
+ def get_agent_by_id(self, id: str) -> Agent | ExternalAgent | AssistantAgent | None:
814
861
  native_client = self.get_native_client()
815
862
  external_client = self.get_external_client()
816
863
  assistant_client = self.get_assistant_client()
@@ -0,0 +1,191 @@
1
+ # https://portkey.ai/
2
+ from dotenv import dotenv_values
3
+ import sys
4
+ import logging
5
+ from ibm_watsonx_orchestrate.agent_builder.models.types import ProviderConfig, ModelProvider
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+ _BASIC_PROVIDER_CONFIG_KEYS = {'provider', 'api_key', 'custom_host', 'url_to_fetch', 'forward_headers', 'request_timeout', 'transform_to_form_data'}
10
+
11
+ PROVIDER_EXTRA_PROPERTIES_LUT = {
12
+ ModelProvider.ANTHROPIC: {'anthropic_beta', 'anthropic_version'},
13
+ # ModelProvider.AZURE_AI: {
14
+ # 'azure_resource_name',
15
+ # 'azure_deployment_id',
16
+ # 'azure_api_version',
17
+ # 'ad_auth',
18
+ # 'azure_auth_mode',
19
+ # 'azure_managed_client_id',
20
+ # 'azure_entra_client_id',
21
+ # 'azure_entra_client_secret',
22
+ # 'azure_entra_tenant_id',
23
+ # 'azure_ad_token',
24
+ # 'azure_model_name'
25
+ # },
26
+ # ModelProvider.BEDROCK: {
27
+ # 'aws_secret_access_key',
28
+ # 'aws_access_key_id',
29
+ # 'aws_session_token',
30
+ # 'aws_region',
31
+ # 'aws_auth_type',
32
+ # 'aws_role_arn',
33
+ # 'aws_external_id',
34
+ # 'aws_s3_bucket',
35
+ # 'aws_s3_object_key',
36
+ # 'aws_bedrock_model',
37
+ # 'aws_server_side_encryption',
38
+ # 'aws_server_side_encryption_kms_key_id'
39
+ # },
40
+ # ModelProvider.VERTEX_AI: {
41
+ # 'vertex_region',
42
+ # 'vertex_project_id',
43
+ # 'vertex_service_account_json',
44
+ # 'vertex_storage_bucket_name',
45
+ # 'vertex_model_name',
46
+ # 'filename'
47
+ # },
48
+ # ModelProvider.HUGGINGFACE: {'huggingfaceBaseUrl'},
49
+ ModelProvider.MISTRAL_AI: {'mistral_fim_completion'},
50
+ # ModelProvider.STABILITY_AI: {'stability_client_id', 'stability_client_user_id', 'stability_client_version'},
51
+ ModelProvider.WATSONX: {'watsonx_version', 'watsonx_space_id', 'watsonx_project_id', 'api_key', 'watsonx_deployment_id', 'watsonx_cpd_url', 'watsonx_cpd_username', 'watsonx_cpd_password'},
52
+
53
+ # 'palm': _bpp('PALM', {}),
54
+ # 'nomic': _bpp('NOMIC', {}),
55
+ # 'perplexity-ai': _bpp('PERPLEXITY_AI', {}),
56
+ # 'segmind': _bpp('SEGMIND', {}),
57
+ # 'deepinfra': _bpp('DEEPINFRA', {}),
58
+ # 'novita-ai': _bpp('NOVITA_AI', {}),
59
+ # 'fireworks-ai': _bpp('FIREWORKS',{
60
+ # 'FIREWORKS_ACCOUNT_ID': 'fireworks_account_id'
61
+ # }),
62
+ # 'deepseek': _bpp('DEEPSEEK', {}),
63
+ # 'voyage': _bpp('VOYAGE', {}),
64
+ # 'moonshot': _bpp('MOONSHOT', {}),
65
+ # 'lingyi': _bpp('LINGYI', {}),
66
+ # 'zhipu': _bpp('ZHIPU', {}),
67
+ # 'monsterapi': _bpp('MONSTERAPI', {}),
68
+ # 'predibase': _bpp('PREDIBASE', {}),
69
+
70
+ # 'github': _bpp('GITHUB', {}),
71
+ # 'deepbricks': _bpp('DEEPBRICKS', {}),
72
+ # 'sagemaker': _bpp('AMZN_SAGEMAKER', {
73
+ # 'AMZN_SAGEMAKER_CUSTOM_ATTRIBUTES': 'amzn_sagemaker_custom_attributes',
74
+ # 'AMZN_SAGEMAKER_TARGET_MODEL': 'amzn_sagemaker_target_model',
75
+ # 'AMZN_SAGEMAKER_TARGET_VARIANT': 'amzn_sagemaker_target_variant',
76
+ # 'AMZN_SAGEMAKER_TARGET_CONTAINER_HOSTNAME': 'amzn_sagemaker_target_container_hostname',
77
+ # 'AMZN_SAGEMAKER_INFERENCE_ID': 'amzn_sagemaker_inference_id',
78
+ # 'AMZN_SAGEMAKER_ENABLE_EXPLANATIONS': 'amzn_sagemaker_enable_explanations',
79
+ # 'AMZN_SAGEMAKER_INFERENCE_COMPONENT': 'amzn_sagemaker_inference_component',
80
+ # 'AMZN_SAGEMAKER_SESSION_ID': 'amzn_sagemaker_session_id',
81
+ # 'AMZN_SAGEMAKER_MODEL_NAME': 'amzn_sagemaker_model_name'
82
+ # }),
83
+ # '@cf': _bpp('WORKERS_AI', { # workers ai
84
+ # 'WORKERS_AI_ACCOUNT_ID': 'workers_ai_account_id'
85
+ # }),
86
+ # 'snowflake': _bpp('SNOWFLAKE', { # no provider prefix found
87
+ # 'SNOWFLAKE_ACCOUNT': 'snowflake_account'
88
+ # })
89
+ }
90
+
91
+ PROVIDER_REQUIRED_FIELDS = {k:['api_key'] for k in ModelProvider}
92
+ # Update required fields for each provider
93
+ # Use sets to denote when a requirement is 'or'
94
+ PROVIDER_REQUIRED_FIELDS.update({
95
+ ModelProvider.WATSONX: PROVIDER_REQUIRED_FIELDS[ModelProvider.WATSONX] + [{'watsonx_space_id', 'watsonx_project_id', 'watsonx_deployment_id'}],
96
+ ModelProvider.OLLAMA: PROVIDER_REQUIRED_FIELDS[ModelProvider.OLLAMA] + ['custom_host']
97
+ })
98
+
99
+ # def env_file_to_model_ProviderConfig(model_name: str, env_file_path: str) -> ProviderConfig | None:
100
+ # provider = next(filter(lambda x: x not in ('virtual-policy', 'virtual-model'), model_name.split('/')))
101
+ # if provider not in ModelProvider:
102
+ # logger.error(f"Unsupported model provider {provider}")
103
+ # sys.exit(1)
104
+
105
+ # values = dotenv_values(str(env_file_path))
106
+
107
+ # if values is None:
108
+ # logger.error(f"No provider configuration in env file {env_file_path}")
109
+ # sys.exit(1)
110
+
111
+ # cfg = ProviderConfig()
112
+ # cfg.provider = PROVIDER_LUT[provider]
113
+
114
+ # cred_lut = PROVIDER_PROPERTIES_LUT[provider]
115
+
116
+
117
+ # consumed_credentials = []
118
+ # # Ollama requires some apikey but its content don't matter
119
+ # # Default it to 'ollama' to avoid users needing to specify
120
+ # if cfg.provider == ModelProvider.OLLAMA:
121
+ # consumed_credentials.append('api_key')
122
+ # setattr(cfg, 'api_key', ModelProvider.OLLAMA)
123
+
124
+ # for key, value in values.items():
125
+ # if key in cred_lut:
126
+ # k = cred_lut[key]
127
+ # consumed_credentials.append(k)
128
+ # setattr(cfg, k, value)
129
+
130
+ # return cfg
131
+
132
+ def _validate_provider(provider: str | ModelProvider) -> None:
133
+ if not ModelProvider.has_value(provider):
134
+ logger.error(f"Unsupported model provider {provider}")
135
+ sys.exit(1)
136
+
137
+ def _validate_extra_fields(provider: ModelProvider, cfg: ProviderConfig) -> None:
138
+ accepted_fields = _BASIC_PROVIDER_CONFIG_KEYS.copy()
139
+ extra_accepted_fields = PROVIDER_EXTRA_PROPERTIES_LUT.get(provider)
140
+ if extra_accepted_fields:
141
+ accepted_fields = accepted_fields.union(extra_accepted_fields)
142
+
143
+ for attr in cfg.__dict__:
144
+ if attr.startswith("__"):
145
+ continue
146
+
147
+ if cfg.__dict__.get(attr) is not None and attr not in accepted_fields:
148
+ logger.warning(f"The config option '{attr}' is not used by provider '{provider}'")
149
+
150
+ def _validate_requirements(provider: ModelProvider, cfg: ProviderConfig, app_id: str = None) -> None:
151
+ provided_credentials = set([k for k,v in dict(cfg).items() if v is not None])
152
+ required_creds = PROVIDER_REQUIRED_FIELDS[provider]
153
+ missing_credentials = []
154
+ for cred in required_creds:
155
+ if isinstance(cred, set):
156
+ if not any(c in provided_credentials for c in cred):
157
+ missing_credentials.append(cred)
158
+ else:
159
+ if cred not in provided_credentials:
160
+ missing_credentials.append(cred)
161
+
162
+ if len(missing_credentials) > 0:
163
+ if not app_id:
164
+ missing_credentials_string = f"Missing configuration variable(s) required for the provider {provider}:"
165
+ else:
166
+ missing_credentials_string = f"The following configuration variable(s) for the provider {provider} are not in the spec provider_config:"
167
+ for cred in missing_credentials:
168
+ if isinstance(cred, set):
169
+ cred_str = ' or '.join(list(cred))
170
+ else:
171
+ cred_str = cred
172
+ missing_credentials_string += f"\n\t - {cred_str}"
173
+
174
+ if not app_id:
175
+ logger.error(missing_credentials_string)
176
+ logger.error("Please provide the above values in the provider config. For secret values (e.g. 'api_key') create a key_value connection `orchestrate connections add` then bind it to the model with `--app-id`")
177
+ sys.exit(1)
178
+ else:
179
+ logger.info(missing_credentials_string)
180
+ logger.info(f"Please ensure these values are set in the connection '{app_id}'.")
181
+
182
+
183
+ def validate_ProviderConfig(cfg: ProviderConfig, app_id: str)-> None:
184
+ if not cfg:
185
+ return
186
+
187
+ provider = cfg.provider
188
+
189
+ _validate_provider(provider)
190
+ _validate_extra_fields(provider=provider, cfg=cfg)
191
+ _validate_requirements(provider=provider, cfg=cfg, app_id=app_id)