ibm-watsonx-orchestrate 1.4.2__py3-none-any.whl → 1.5.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ibm_watsonx_orchestrate/__init__.py +1 -1
- ibm_watsonx_orchestrate/agent_builder/agents/types.py +10 -1
- ibm_watsonx_orchestrate/agent_builder/knowledge_bases/types.py +14 -0
- ibm_watsonx_orchestrate/agent_builder/model_policies/__init__.py +1 -0
- ibm_watsonx_orchestrate/{client → agent_builder}/model_policies/types.py +7 -8
- ibm_watsonx_orchestrate/agent_builder/models/__init__.py +1 -0
- ibm_watsonx_orchestrate/{client → agent_builder}/models/types.py +57 -9
- ibm_watsonx_orchestrate/agent_builder/tools/python_tool.py +46 -3
- ibm_watsonx_orchestrate/agent_builder/tools/types.py +47 -1
- ibm_watsonx_orchestrate/cli/commands/agents/agents_command.py +17 -0
- ibm_watsonx_orchestrate/cli/commands/agents/agents_controller.py +86 -39
- ibm_watsonx_orchestrate/cli/commands/models/model_provider_mapper.py +191 -0
- ibm_watsonx_orchestrate/cli/commands/models/models_command.py +140 -258
- ibm_watsonx_orchestrate/cli/commands/models/models_controller.py +437 -0
- ibm_watsonx_orchestrate/cli/commands/server/server_command.py +2 -1
- ibm_watsonx_orchestrate/cli/commands/tools/tools_controller.py +1 -1
- ibm_watsonx_orchestrate/client/connections/__init__.py +2 -1
- ibm_watsonx_orchestrate/client/connections/utils.py +30 -0
- ibm_watsonx_orchestrate/client/model_policies/model_policies_client.py +23 -4
- ibm_watsonx_orchestrate/client/models/models_client.py +23 -3
- ibm_watsonx_orchestrate/client/toolkit/toolkit_client.py +13 -8
- ibm_watsonx_orchestrate/client/tools/tool_client.py +2 -1
- ibm_watsonx_orchestrate/docker/compose-lite.yml +2 -0
- ibm_watsonx_orchestrate/docker/default.env +10 -11
- ibm_watsonx_orchestrate/experimental/flow_builder/data_map.py +19 -0
- ibm_watsonx_orchestrate/experimental/flow_builder/flows/__init__.py +4 -3
- ibm_watsonx_orchestrate/experimental/flow_builder/flows/constants.py +3 -1
- ibm_watsonx_orchestrate/experimental/flow_builder/flows/decorators.py +3 -2
- ibm_watsonx_orchestrate/experimental/flow_builder/flows/flow.py +245 -223
- ibm_watsonx_orchestrate/experimental/flow_builder/node.py +34 -15
- ibm_watsonx_orchestrate/experimental/flow_builder/resources/flow_status.openapi.yml +7 -39
- ibm_watsonx_orchestrate/experimental/flow_builder/types.py +285 -12
- ibm_watsonx_orchestrate/experimental/flow_builder/utils.py +3 -1
- {ibm_watsonx_orchestrate-1.4.2.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/METADATA +1 -1
- {ibm_watsonx_orchestrate-1.4.2.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/RECORD +38 -35
- ibm_watsonx_orchestrate/cli/commands/models/env_file_model_provider_mapper.py +0 -180
- ibm_watsonx_orchestrate/experimental/flow_builder/flows/data_map.py +0 -91
- {ibm_watsonx_orchestrate-1.4.2.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/WHEEL +0 -0
- {ibm_watsonx_orchestrate-1.4.2.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/entry_points.txt +0 -0
- {ibm_watsonx_orchestrate-1.4.2.dist-info → ibm_watsonx_orchestrate-1.5.0b1.dist-info}/licenses/LICENSE +0 -0
@@ -11,9 +11,12 @@ import logging
|
|
11
11
|
from pathlib import Path
|
12
12
|
from copy import deepcopy
|
13
13
|
|
14
|
-
from typing import Iterable, List
|
14
|
+
from typing import Iterable, List, TypeVar
|
15
|
+
from ibm_watsonx_orchestrate.agent_builder.agents.types import AgentStyle
|
16
|
+
from ibm_watsonx_orchestrate.agent_builder.tools.types import ToolSpec
|
15
17
|
from ibm_watsonx_orchestrate.cli.commands.tools.tools_controller import import_python_tool, ToolsController
|
16
18
|
from ibm_watsonx_orchestrate.cli.commands.knowledge_bases.knowledge_bases_controller import import_python_knowledge_base
|
19
|
+
from ibm_watsonx_orchestrate.cli.commands.models.models_controller import import_python_model
|
17
20
|
|
18
21
|
from ibm_watsonx_orchestrate.agent_builder.agents import (
|
19
22
|
Agent,
|
@@ -34,10 +37,14 @@ from ibm_watsonx_orchestrate.utils.utils import check_file_in_zip
|
|
34
37
|
|
35
38
|
logger = logging.getLogger(__name__)
|
36
39
|
|
40
|
+
# Helper generic type for any agent
|
41
|
+
AnyAgentT = TypeVar("AnyAgentT", bound=Agent | ExternalAgent | AssistantAgent)
|
42
|
+
|
37
43
|
def import_python_agent(file: str) -> List[Agent | ExternalAgent | AssistantAgent]:
|
38
44
|
# Import tools
|
39
45
|
import_python_tool(file)
|
40
46
|
import_python_knowledge_base(file)
|
47
|
+
import_python_model(file)
|
41
48
|
|
42
49
|
file_path = Path(file)
|
43
50
|
file_directory = file_path.parent
|
@@ -90,6 +97,8 @@ def parse_create_native_args(name: str, kind: AgentKind, description: str | None
|
|
90
97
|
"description": description,
|
91
98
|
"llm": args.get("llm"),
|
92
99
|
"style": args.get("style"),
|
100
|
+
"custom_join_tool": args.get("custom_join_tool"),
|
101
|
+
"structured_output": args.get("structured_output"),
|
93
102
|
}
|
94
103
|
|
95
104
|
collaborators = args.get("collaborators", [])
|
@@ -202,7 +211,7 @@ class AgentsController:
|
|
202
211
|
return self.knowledge_base_client
|
203
212
|
|
204
213
|
@staticmethod
|
205
|
-
def import_agent(file: str, app_id: str) ->
|
214
|
+
def import_agent(file: str, app_id: str) -> List[Agent | ExternalAgent | AssistantAgent]:
|
206
215
|
agents = parse_file(file)
|
207
216
|
for agent in agents:
|
208
217
|
if app_id and agent.kind != AgentKind.NATIVE and agent.kind != AgentKind.ASSISTANT:
|
@@ -216,7 +225,9 @@ class AgentsController:
|
|
216
225
|
) -> Agent | ExternalAgent | AssistantAgent:
|
217
226
|
match kind:
|
218
227
|
case AgentKind.NATIVE:
|
219
|
-
agent_details = parse_create_native_args(
|
228
|
+
agent_details = parse_create_native_args(
|
229
|
+
name, kind=kind, description=description, **kwargs
|
230
|
+
)
|
220
231
|
agent = Agent.model_validate(agent_details)
|
221
232
|
AgentsController().persist_record(agent=agent, **kwargs)
|
222
233
|
case AgentKind.EXTERNAL:
|
@@ -296,12 +307,17 @@ class AgentsController:
|
|
296
307
|
ref_agent.collaborators = ref_collaborators
|
297
308
|
|
298
309
|
return ref_agent
|
299
|
-
|
310
|
+
|
300
311
|
def dereference_tools(self, agent: Agent) -> Agent:
|
301
312
|
tool_client = self.get_tool_client()
|
302
313
|
|
303
314
|
deref_agent = deepcopy(agent)
|
304
|
-
|
315
|
+
|
316
|
+
# If agent has style set to "planner" and have join_tool defined, then we need to include that tool as well
|
317
|
+
if agent.style == AgentStyle.PLANNER and agent.custom_join_tool:
|
318
|
+
matching_tools = tool_client.get_drafts_by_names(deref_agent.tools + [deref_agent.custom_join_tool])
|
319
|
+
else:
|
320
|
+
matching_tools = tool_client.get_drafts_by_names(deref_agent.tools)
|
305
321
|
|
306
322
|
name_id_lookup = {}
|
307
323
|
for tool in matching_tools:
|
@@ -318,6 +334,13 @@ class AgentsController:
|
|
318
334
|
sys.exit(1)
|
319
335
|
deref_tools.append(id)
|
320
336
|
deref_agent.tools = deref_tools
|
337
|
+
|
338
|
+
if agent.style == AgentStyle.PLANNER and agent.custom_join_tool:
|
339
|
+
join_tool_id = name_id_lookup.get(agent.custom_join_tool)
|
340
|
+
if not join_tool_id:
|
341
|
+
logger.error(f"Failed to find custom join tool. No tools found with the name '{agent.custom_join_tool}'")
|
342
|
+
sys.exit(1)
|
343
|
+
deref_agent.custom_join_tool = join_tool_id
|
321
344
|
|
322
345
|
return deref_agent
|
323
346
|
|
@@ -325,7 +348,12 @@ class AgentsController:
|
|
325
348
|
tool_client = self.get_tool_client()
|
326
349
|
|
327
350
|
ref_agent = deepcopy(agent)
|
328
|
-
|
351
|
+
|
352
|
+
# If agent has style set to "planner" and have join_tool defined, then we need to include that tool as well
|
353
|
+
if agent.style == AgentStyle.PLANNER and agent.custom_join_tool:
|
354
|
+
matching_tools = tool_client.get_drafts_by_ids(ref_agent.tools + [ref_agent.custom_join_tool])
|
355
|
+
else:
|
356
|
+
matching_tools = tool_client.get_drafts_by_ids(ref_agent.tools)
|
329
357
|
|
330
358
|
id_name_lookup = {}
|
331
359
|
for tool in matching_tools:
|
@@ -342,6 +370,13 @@ class AgentsController:
|
|
342
370
|
sys.exit(1)
|
343
371
|
ref_tools.append(name)
|
344
372
|
ref_agent.tools = ref_tools
|
373
|
+
|
374
|
+
if agent.style == AgentStyle.PLANNER and agent.custom_join_tool:
|
375
|
+
join_tool_name = id_name_lookup.get(agent.custom_join_tool)
|
376
|
+
if not join_tool_name:
|
377
|
+
logger.error(f"Failed to find custom join tool. No tools found with the id '{agent.custom_join_tool}'")
|
378
|
+
sys.exit(1)
|
379
|
+
ref_agent.custom_join_tool = join_tool_name
|
345
380
|
|
346
381
|
return ref_agent
|
347
382
|
|
@@ -409,7 +444,7 @@ class AgentsController:
|
|
409
444
|
def dereference_native_agent_dependencies(self, agent: Agent) -> Agent:
|
410
445
|
if agent.collaborators and len(agent.collaborators):
|
411
446
|
agent = self.dereference_collaborators(agent)
|
412
|
-
if agent.tools and len(agent.tools):
|
447
|
+
if (agent.tools and len(agent.tools)) or (agent.style == AgentStyle.PLANNER and agent.custom_join_tool):
|
413
448
|
agent = self.dereference_tools(agent)
|
414
449
|
if agent.knowledge_base and len(agent.knowledge_base):
|
415
450
|
agent = self.dereference_knowledge_bases(agent)
|
@@ -419,7 +454,7 @@ class AgentsController:
|
|
419
454
|
def reference_native_agent_dependencies(self, agent: Agent) -> Agent:
|
420
455
|
if agent.collaborators and len(agent.collaborators):
|
421
456
|
agent = self.reference_collaborators(agent)
|
422
|
-
if agent.tools and len(agent.tools):
|
457
|
+
if (agent.tools and len(agent.tools)) or (agent.style == AgentStyle.PLANNER and agent.custom_join_tool):
|
423
458
|
agent = self.reference_tools(agent)
|
424
459
|
if agent.knowledge_base and len(agent.knowledge_base):
|
425
460
|
agent = self.reference_knowledge_bases(agent)
|
@@ -443,21 +478,21 @@ class AgentsController:
|
|
443
478
|
return agent
|
444
479
|
|
445
480
|
# Convert all names used in an agent to the corresponding ids
|
446
|
-
def dereference_agent_dependencies(self, agent:
|
481
|
+
def dereference_agent_dependencies(self, agent: AnyAgentT) -> AnyAgentT:
    """Convert every name referenced by ``agent`` into the matching id.

    Native ``Agent`` instances are checked first and routed to
    ``dereference_native_agent_dependencies``. External and assistant agents
    go to ``dereference_external_or_assistant_agent_dependencies``. Any other
    object matches neither branch and ``None`` is returned.
    """
    if isinstance(agent, Agent):
        return self.dereference_native_agent_dependencies(agent)
    if isinstance(agent, (ExternalAgent, AssistantAgent)):
        return self.dereference_external_or_assistant_agent_dependencies(agent)
    return None
|
451
|
-
|
486
|
+
|
452
487
|
# Convert all ids used in an agent to the corresponding names
|
453
|
-
def reference_agent_dependencies(self, agent:
|
488
|
+
def reference_agent_dependencies(self, agent: AnyAgentT) -> AnyAgentT:
    """Convert every id referenced by ``agent`` back into the matching name.

    This is the inverse of ``dereference_agent_dependencies``. Native agents
    are checked first; external and assistant agents share one handler.
    Objects of any other type fall through and ``None`` is returned.
    """
    if isinstance(agent, Agent):
        return self.reference_native_agent_dependencies(agent)
    if isinstance(agent, (ExternalAgent, AssistantAgent)):
        return self.reference_external_or_assistant_agent_dependencies(agent)
    return None
|
458
493
|
|
459
494
|
def publish_or_update_agents(
|
460
|
-
self, agents: Iterable[Agent]
|
495
|
+
self, agents: Iterable[Agent | ExternalAgent | AssistantAgent]
|
461
496
|
):
|
462
497
|
for agent in agents:
|
463
498
|
agent_name = agent.name
|
@@ -476,6 +511,18 @@ class AgentsController:
|
|
476
511
|
all_existing_agents = existing_external_clients + existing_native_agents + existing_assistant_clients
|
477
512
|
agent = self.dereference_agent_dependencies(agent)
|
478
513
|
|
514
|
+
if isinstance(agent, Agent) and agent.style == AgentStyle.PLANNER and isinstance(agent.custom_join_tool, str):
|
515
|
+
tool_client = self.get_tool_client()
|
516
|
+
|
517
|
+
join_tool_spec = ToolSpec.model_validate(
|
518
|
+
tool_client.get_draft_by_id(agent.custom_join_tool)
|
519
|
+
)
|
520
|
+
if not join_tool_spec.is_custom_join_tool():
|
521
|
+
logger.error(
|
522
|
+
f"Tool '{join_tool_spec.name}' configured as the custom join tool is not a valid join tool. A custom join tool must be a Python tool with specific input and output schema."
|
523
|
+
)
|
524
|
+
sys.exit(1)
|
525
|
+
|
479
526
|
agent_kind = agent.kind
|
480
527
|
|
481
528
|
if len(all_existing_agents) > 1:
|
@@ -762,32 +809,32 @@ class AgentsController:
|
|
762
809
|
rich.print(assistants_table)
|
763
810
|
|
764
811
|
def remove_agent(self, name: str, kind: AgentKind):
    """Delete the draft agent called ``name`` using the client for ``kind``.

    The agent is looked up by name and then deleted by its ``id``.
    - No agent with that name: a warning is logged and nothing is deleted.
    - More than one agent with that name: the process exits with status 1.
    - The backend returns an HTTP error: the response text is logged and the
      process exits with status 1.

    Raises:
        ValueError: if ``kind`` is not one of the native, external or
            assistant agent kinds.
    """
    try:
        if kind == AgentKind.NATIVE:
            client = self.get_native_client()
        elif kind == AgentKind.EXTERNAL:
            client = self.get_external_client()
        elif kind == AgentKind.ASSISTANT:
            client = self.get_assistant_client()
        else:
            # List every supported kind so the user can see the valid choices.
            raise ValueError("'kind' must be 'native', 'external' or 'assistant'")

        draft_agents = client.get_draft_by_name(name)
        if len(draft_agents) > 1:
            # Deleting by name would be ambiguous; refuse rather than guess.
            logger.error(f"Multiple '{kind}' agents found with name '{name}'. Failed to delete agent")
            sys.exit(1)
        if draft_agents:
            agent_id = draft_agents[0].get("id")
            client.delete(agent_id=agent_id)

            logger.info(f"Successfully removed agent {name}")
        else:
            logger.warning(f"No agent named '{name}' found")
    except requests.HTTPError as e:
        logger.error(e.response.text)
        # Use sys.exit like the rest of the controller; the bare `exit` builtin
        # is injected by `site` and is not guaranteed to exist.
        sys.exit(1)
|
837
|
+
|
791
838
|
def get_spec_file_content(self, agent: Agent | ExternalAgent | AssistantAgent):
|
792
839
|
ref_agent = self.reference_agent_dependencies(agent)
|
793
840
|
agent_spec = ref_agent.model_dump(mode='json', exclude_none=True)
|
@@ -810,7 +857,7 @@ class AgentsController:
|
|
810
857
|
|
811
858
|
return agent
|
812
859
|
|
813
|
-
def get_agent_by_id(self, id: str) -> Agent | ExternalAgent | AssistantAgent:
|
860
|
+
def get_agent_by_id(self, id: str) -> Agent | ExternalAgent | AssistantAgent | None:
|
814
861
|
native_client = self.get_native_client()
|
815
862
|
external_client = self.get_external_client()
|
816
863
|
assistant_client = self.get_assistant_client()
|
@@ -0,0 +1,191 @@
|
|
1
|
+
# https://portkey.ai/
|
2
|
+
from dotenv import dotenv_values
|
3
|
+
import sys
|
4
|
+
import logging
|
5
|
+
from ibm_watsonx_orchestrate.agent_builder.models.types import ProviderConfig, ModelProvider
|
6
|
+
|
7
|
+
logger = logging.getLogger(__name__)
|
8
|
+
|
9
|
+
# Provider-config options that every provider accepts.
_BASIC_PROVIDER_CONFIG_KEYS = {
    'provider',
    'api_key',
    'custom_host',
    'url_to_fetch',
    'forward_headers',
    'request_timeout',
    'transform_to_form_data',
}

# Options a specific provider accepts in addition to _BASIC_PROVIDER_CONFIG_KEYS.
# Providers missing from this table accept only the basic options.
PROVIDER_EXTRA_PROPERTIES_LUT = {
    ModelProvider.ANTHROPIC: {'anthropic_beta', 'anthropic_version'},
    ModelProvider.MISTRAL_AI: {'mistral_fim_completion'},
    ModelProvider.WATSONX: {
        'watsonx_version',
        'watsonx_space_id',
        'watsonx_project_id',
        'api_key',
        'watsonx_deployment_id',
        'watsonx_cpd_url',
        'watsonx_cpd_username',
        'watsonx_cpd_password',
    },
    # Not enabled yet. Candidate extra options for other providers:
    #   azure-ai: azure_resource_name, azure_deployment_id, azure_api_version,
    #             ad_auth, azure_auth_mode, azure_managed_client_id,
    #             azure_entra_client_id, azure_entra_client_secret,
    #             azure_entra_tenant_id, azure_ad_token, azure_model_name
    #   bedrock: aws_secret_access_key, aws_access_key_id, aws_session_token,
    #            aws_region, aws_auth_type, aws_role_arn, aws_external_id,
    #            aws_s3_bucket, aws_s3_object_key, aws_bedrock_model,
    #            aws_server_side_encryption,
    #            aws_server_side_encryption_kms_key_id
    #   vertex-ai: vertex_region, vertex_project_id, vertex_service_account_json,
    #              vertex_storage_bucket_name, vertex_model_name, filename
    #   huggingface: huggingfaceBaseUrl
    #   stability-ai: stability_client_id, stability_client_user_id,
    #                 stability_client_version
    #   fireworks-ai: fireworks_account_id
    #   sagemaker: amzn_sagemaker_custom_attributes, amzn_sagemaker_target_model,
    #              amzn_sagemaker_target_variant,
    #              amzn_sagemaker_target_container_hostname,
    #              amzn_sagemaker_inference_id,
    #              amzn_sagemaker_enable_explanations,
    #              amzn_sagemaker_inference_component,
    #              amzn_sagemaker_session_id, amzn_sagemaker_model_name
    #   workers-ai (@cf): workers_ai_account_id
    #   snowflake: snowflake_account
    #   No extra options: palm, nomic, perplexity-ai, segmind, deepinfra,
    #   novita-ai, deepseek, voyage, moonshot, lingyi, zhipu, monsterapi,
    #   predibase, github, deepbricks
}

# Required provider-config options per provider.
# A plain string entry must be present. A set entry is an "any of" requirement:
# at least one of its members must be present.
# Every provider starts with 'api_key'; the comprehension builds a separate list
# for each provider, so appending below affects only that provider.
PROVIDER_REQUIRED_FIELDS = {provider: ['api_key'] for provider in ModelProvider}
PROVIDER_REQUIRED_FIELDS[ModelProvider.WATSONX].append(
    {'watsonx_space_id', 'watsonx_project_id', 'watsonx_deployment_id'}
)
PROVIDER_REQUIRED_FIELDS[ModelProvider.OLLAMA].append('custom_host')
|
98
|
+
|
99
|
+
# def env_file_to_model_ProviderConfig(model_name: str, env_file_path: str) -> ProviderConfig | None:
|
100
|
+
# provider = next(filter(lambda x: x not in ('virtual-policy', 'virtual-model'), model_name.split('/')))
|
101
|
+
# if provider not in ModelProvider:
|
102
|
+
# logger.error(f"Unsupported model provider {provider}")
|
103
|
+
# sys.exit(1)
|
104
|
+
|
105
|
+
# values = dotenv_values(str(env_file_path))
|
106
|
+
|
107
|
+
# if values is None:
|
108
|
+
# logger.error(f"No provider configuration in env file {env_file_path}")
|
109
|
+
# sys.exit(1)
|
110
|
+
|
111
|
+
# cfg = ProviderConfig()
|
112
|
+
# cfg.provider = PROVIDER_LUT[provider]
|
113
|
+
|
114
|
+
# cred_lut = PROVIDER_PROPERTIES_LUT[provider]
|
115
|
+
|
116
|
+
|
117
|
+
# consumed_credentials = []
|
118
|
+
# # Ollama requires some apikey but its content don't matter
|
119
|
+
# # Default it to 'ollama' to avoid users needing to specify
|
120
|
+
# if cfg.provider == ModelProvider.OLLAMA:
|
121
|
+
# consumed_credentials.append('api_key')
|
122
|
+
# setattr(cfg, 'api_key', ModelProvider.OLLAMA)
|
123
|
+
|
124
|
+
# for key, value in values.items():
|
125
|
+
# if key in cred_lut:
|
126
|
+
# k = cred_lut[key]
|
127
|
+
# consumed_credentials.append(k)
|
128
|
+
# setattr(cfg, k, value)
|
129
|
+
|
130
|
+
# return cfg
|
131
|
+
|
132
|
+
def _validate_provider(provider: str | ModelProvider) -> None:
    """Exit with status 1 unless ``provider`` is a value of ``ModelProvider``.

    The check is delegated to ``ModelProvider.has_value``. On failure an
    error is logged before exiting.
    """
    if ModelProvider.has_value(provider):
        return
    logger.error(f"Unsupported model provider {provider}")
    sys.exit(1)
|
136
|
+
|
137
|
+
def _validate_extra_fields(provider: ModelProvider, cfg: ProviderConfig) -> None:
    """Warn about each set option in ``cfg`` that ``provider`` does not use.

    The options a provider uses are the basic keys plus any provider-specific
    keys from ``PROVIDER_EXTRA_PROPERTIES_LUT``. Options whose value is
    ``None`` are ignored. Nothing is raised; this only logs warnings.
    """
    extras = PROVIDER_EXTRA_PROPERTIES_LUT.get(provider) or set()
    # The union builds a new set, so the shared base set is never mutated.
    allowed = _BASIC_PROVIDER_CONFIG_KEYS | extras

    for option, value in cfg.__dict__.items():
        # Skip dunder entries and unset options.
        if option.startswith("__") or value is None:
            continue
        if option not in allowed:
            logger.warning(f"The config option '{option}' is not used by provider '{provider}'")
|
149
|
+
|
150
|
+
def _validate_requirements(provider: ModelProvider, cfg: ProviderConfig, app_id: str | None = None) -> None:
    """Check that ``cfg`` sets every option ``provider`` requires.

    Requirements come from ``PROVIDER_REQUIRED_FIELDS[provider]``. A string
    entry must be set (not ``None``) in ``cfg``. A set entry is satisfied
    when any one of its members is set.

    Behaviour when something is missing:
    - Without ``app_id``: the missing options are logged as errors and the
      process exits with status 1.
    - With ``app_id``: the values may come from that connection, so the
      missing options are only logged at info level.

    Args:
        provider: Provider whose requirements are checked.
        cfg: Provider configuration; ``dict(cfg)`` must yield (key, value) pairs.
        app_id: Optional connection id that may supply the missing values.
    """
    provided = {key for key, value in dict(cfg).items() if value is not None}

    missing = []
    for requirement in PROVIDER_REQUIRED_FIELDS[provider]:
        alternatives = requirement if isinstance(requirement, set) else {requirement}
        if provided.isdisjoint(alternatives):
            missing.append(requirement)

    if not missing:
        return

    if app_id:
        header = f"The following configuration variable(s) for the provider {provider} are not in the spec provider_config:"
    else:
        header = f"Missing configuration variable(s) required for the provider {provider}:"

    lines = [header]
    for requirement in missing:
        if isinstance(requirement, set):
            # Sort the alternatives: set iteration order depends on string-hash
            # randomisation, so an unsorted join would change between runs.
            text = ' or '.join(sorted(requirement))
        else:
            text = requirement
        lines.append(f"\t - {text}")
    message = "\n".join(lines)

    if app_id:
        logger.info(message)
        logger.info(f"Please ensure these values are set in the connection '{app_id}'.")
    else:
        logger.error(message)
        logger.error("Please provide the above values in the provider config. For secret values (e.g. 'api_key') create a key_value connection `orchestrate connections add` then bind it to the model with `--app-id`")
        sys.exit(1)
|
181
|
+
|
182
|
+
|
183
|
+
def validate_ProviderConfig(cfg: ProviderConfig, app_id: str) -> None:
    """Validate a model provider configuration; do nothing if ``cfg`` is falsy.

    Three checks run in order:
    1. The provider is supported; otherwise the process exits.
    2. A warning is logged for each set option the provider does not use.
    3. Required options are present; a missing one exits the process unless
       ``app_id`` names a connection expected to supply it.
    """
    if cfg:
        provider = cfg.provider
        _validate_provider(provider)
        _validate_extra_fields(provider=provider, cfg=cfg)
        _validate_requirements(provider=provider, cfg=cfg, app_id=app_id)
|