ibm-watsonx-orchestrate 1.8.0b0__py3-none-any.whl → 1.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. ibm_watsonx_orchestrate/__init__.py +2 -1
  2. ibm_watsonx_orchestrate/agent_builder/agents/types.py +12 -0
  3. ibm_watsonx_orchestrate/agent_builder/connections/types.py +14 -2
  4. ibm_watsonx_orchestrate/agent_builder/tools/openapi_tool.py +61 -11
  5. ibm_watsonx_orchestrate/agent_builder/tools/types.py +7 -2
  6. ibm_watsonx_orchestrate/cli/commands/agents/agents_controller.py +3 -3
  7. ibm_watsonx_orchestrate/cli/commands/channels/types.py +15 -2
  8. ibm_watsonx_orchestrate/cli/commands/channels/webchat/channels_webchat_controller.py +7 -7
  9. ibm_watsonx_orchestrate/cli/commands/connections/connections_command.py +14 -6
  10. ibm_watsonx_orchestrate/cli/commands/connections/connections_controller.py +6 -8
  11. ibm_watsonx_orchestrate/cli/commands/copilot/copilot_controller.py +111 -36
  12. ibm_watsonx_orchestrate/cli/commands/copilot/copilot_server_controller.py +23 -7
  13. ibm_watsonx_orchestrate/cli/commands/environment/types.py +1 -1
  14. ibm_watsonx_orchestrate/cli/commands/evaluations/evaluations_command.py +102 -37
  15. ibm_watsonx_orchestrate/cli/commands/evaluations/evaluations_controller.py +20 -2
  16. ibm_watsonx_orchestrate/cli/commands/knowledge_bases/knowledge_bases_controller.py +10 -8
  17. ibm_watsonx_orchestrate/cli/commands/models/models_controller.py +5 -8
  18. ibm_watsonx_orchestrate/cli/commands/server/server_command.py +2 -10
  19. ibm_watsonx_orchestrate/client/connections/connections_client.py +5 -30
  20. ibm_watsonx_orchestrate/client/copilot/cpe/copilot_cpe_client.py +2 -1
  21. ibm_watsonx_orchestrate/client/utils.py +22 -20
  22. ibm_watsonx_orchestrate/docker/compose-lite.yml +12 -5
  23. ibm_watsonx_orchestrate/docker/default.env +13 -12
  24. ibm_watsonx_orchestrate/flow_builder/flows/__init__.py +8 -5
  25. ibm_watsonx_orchestrate/flow_builder/flows/flow.py +47 -7
  26. ibm_watsonx_orchestrate/flow_builder/node.py +7 -1
  27. ibm_watsonx_orchestrate/flow_builder/types.py +168 -66
  28. ibm_watsonx_orchestrate/flow_builder/utils.py +0 -1
  29. {ibm_watsonx_orchestrate-1.8.0b0.dist-info → ibm_watsonx_orchestrate-1.8.1.dist-info}/METADATA +2 -4
  30. {ibm_watsonx_orchestrate-1.8.0b0.dist-info → ibm_watsonx_orchestrate-1.8.1.dist-info}/RECORD +33 -34
  31. ibm_watsonx_orchestrate/agent_builder/utils/pydantic_utils.py +0 -149
  32. {ibm_watsonx_orchestrate-1.8.0b0.dist-info → ibm_watsonx_orchestrate-1.8.1.dist-info}/WHEEL +0 -0
  33. {ibm_watsonx_orchestrate-1.8.0b0.dist-info → ibm_watsonx_orchestrate-1.8.1.dist-info}/entry_points.txt +0 -0
  34. {ibm_watsonx_orchestrate-1.8.0b0.dist-info → ibm_watsonx_orchestrate-1.8.1.dist-info}/licenses/LICENSE +0 -0
@@ -71,6 +71,7 @@ class KnowledgeBaseController:
71
71
  client = self.get_client()
72
72
 
73
73
  knowledge_bases = parse_file(file=file)
74
+
74
75
  existing_knowledge_bases = client.get_by_names([kb.name for kb in knowledge_bases])
75
76
 
76
77
  for kb in knowledge_bases:
@@ -137,23 +138,24 @@ class KnowledgeBaseController:
137
138
 
138
139
  def update_knowledge_base(
139
140
  self, knowledge_base_id: str, kb: KnowledgeBase, file_dir: str
140
- ) -> None:
141
- filtered_files = []
142
-
141
+ ) -> None:
143
142
  if kb.documents:
144
143
  status = self.get_client().status(knowledge_base_id)
145
144
  existing_docs = [doc.get("metadata", {}).get("original_file_name", "") for doc in status.get("documents", [])]
146
145
 
146
+ removed_docs = existing_docs[:]
147
147
  for filepath in kb.documents:
148
148
  filename = get_file_name(filepath)
149
149
 
150
150
  if filename in existing_docs:
151
- logger.warning(f'Document \"{filename}\" already exists in knowledge base, skipping.')
152
- else:
153
- filtered_files.append(filepath)
151
+ logger.warning(f'Document \"{filename}\" already exists in knowledge base. Updating...')
152
+ removed_docs.remove(filename)
153
+
154
+ for filename in removed_docs:
155
+ logger.warning(f'Document \"{filename}\" removed from knowledge base.')
156
+
154
157
 
155
- if filtered_files:
156
- files = [('files', (get_file_name(file_path), open(get_relative_file_path(file_path, file_dir), 'rb'))) for file_path in filtered_files]
158
+ files = [('files', (get_file_name(file_path), open(get_relative_file_path(file_path, file_dir), 'rb'))) for file_path in kb.documents]
157
159
 
158
160
  kb.prioritize_built_in_index = True
159
161
  payload = kb.model_dump(exclude_none=True);
@@ -167,15 +167,12 @@ class ModelsController:
167
167
  logger.error("Error: WATSONX_URL is required in the environment.")
168
168
  sys.exit(1)
169
169
 
170
- if is_cpd_env(models_client.base_url):
171
- virtual_models = []
172
- virtual_model_policies = []
173
- else:
174
- logger.info("Retrieving virtual-model models list...")
175
- virtual_models = models_client.list()
170
+
171
+ logger.info("Retrieving virtual-model models list...")
172
+ virtual_models = models_client.list()
176
173
 
177
- logger.info("Retrieving virtual-policies models list...")
178
- virtual_model_policies = model_policies_client.list()
174
+ logger.info("Retrieving virtual-policies models list...")
175
+ virtual_model_policies = model_policies_client.list()
179
176
 
180
177
  logger.info("Retrieving watsonx.ai models list...")
181
178
  found_models = _get_wxai_foundational_models()
@@ -46,15 +46,7 @@ _ALWAYS_UNSET: set[str] = {
46
46
 
47
47
  def define_saas_wdu_runtime(value: str = "none") -> None:
48
48
  cfg = Config()
49
-
50
- current_config_file_values = cfg.get(USER_ENV_CACHE_HEADER)
51
- current_config_file_values["SAAS_WDU_RUNTIME"] = value
52
-
53
- cfg.save(
54
- {
55
- USER_ENV_CACHE_HEADER: current_config_file_values
56
- }
57
- )
49
+ cfg.write(USER_ENV_CACHE_HEADER,"SAAS_WDU_RUNTIME",value)
58
50
 
59
51
  def ensure_docker_installed() -> None:
60
52
  try:
@@ -885,7 +877,7 @@ def server_start(
885
877
  if experimental_with_langfuse:
886
878
  logger.info(f"You can access the observability platform Langfuse at http://localhost:3010, username: orchestrate@ibm.com, password: orchestrate")
887
879
  if with_doc_processing:
888
- logger.info(f"Document processing capabilities are now available for use in Flows (both ADK and runtime). Note: This option is currently available only in the Developer edition.")
880
+ logger.info(f"Document processing in Flows (Public Preview) has been enabled.")
889
881
 
890
882
  @server_app.command(name="stop")
891
883
  def server_stop(
@@ -63,12 +63,7 @@ class ConnectionsClient(BaseAPIClient):
63
63
  # GET /api/v1/connections/applications/{app_id}
64
64
  def get(self, app_id: str) -> GetConnectionResponse | None:
65
65
  try:
66
- path = (
67
- f"/connections/applications/{app_id}"
68
- if is_cpd_env(self.base_url)
69
- else f"/connections/applications?app_id={app_id}"
70
- )
71
- return GetConnectionResponse.model_validate(self._get(path))
66
+ return GetConnectionResponse.model_validate(self._get(f"/connections/applications?app_id={app_id}"))
72
67
  except ClientAPIException as e:
73
68
  if e.response.status_code == 404:
74
69
  return None
@@ -78,12 +73,7 @@ class ConnectionsClient(BaseAPIClient):
78
73
  # GET api/v1/connections/applications
79
74
  def list(self) -> List[ListConfigsResponse]:
80
75
  try:
81
- path = (
82
- f"/connections/applications"
83
- if is_cpd_env(self.base_url)
84
- else f"/connections/applications?include_details=true"
85
- )
86
- res = self._get(path)
76
+ res = self._get(f"/connections/applications?include_details=true")
87
77
  import json
88
78
  json.dumps(res)
89
79
  return [ListConfigsResponse.model_validate(conn) for conn in res.get("applications", [])]
@@ -135,19 +125,9 @@ class ConnectionsClient(BaseAPIClient):
135
125
  def get_credentials(self, app_id: str, env: ConnectionEnvironment, use_app_credentials: bool) -> dict:
136
126
  try:
137
127
  if use_app_credentials:
138
- path = (
139
- f"/connections/applications/{app_id}/credentials?env={env}"
140
- if is_cpd_env(self.base_url)
141
- else f"/connections/applications/{app_id}/credentials/{env}"
142
- )
143
- return self._get(path)
128
+ return self._get(f"/connections/applications/{app_id}/credentials/{env}")
144
129
  else:
145
- path = (
146
- f"/connections/applications/{app_id}/configs/runtime_credentials?env={env}"
147
- if is_cpd_env(self.base_url)
148
- else f"/connections/applications/runtime_credentials?app_id={app_id}&env={env}"
149
- )
150
- return self._get(path)
130
+ return self._get(f"/connections/applications/runtime_credentials?app_id={app_id}&env={env}")
151
131
  except ClientAPIException as e:
152
132
  if e.response.status_code == 404:
153
133
  return None
@@ -177,12 +157,7 @@ class ConnectionsClient(BaseAPIClient):
177
157
  if conn_id is None:
178
158
  return ""
179
159
  try:
180
- path = (
181
- f"/connections/applications/id/{conn_id}"
182
- if is_cpd_env(self.base_url)
183
- else f"/connections/applications?connection_id={conn_id}"
184
- )
185
- app_details = self._get(path)
160
+ app_details = self._get(f"/connections/applications?connection_id={conn_id}")
186
161
  return app_details.get("app_id")
187
162
  except ClientAPIException as e:
188
163
  if e.response.status_code == 404:
@@ -21,10 +21,11 @@ class CPEClient(BaseAPIClient):
21
21
  }
22
22
 
23
23
 
24
- def submit_pre_cpe_chat(self, user_message: str | None =None, tools: Dict[str, Any] = None) -> dict:
24
+ def submit_pre_cpe_chat(self, user_message: str | None =None, tools: Dict[str, Any] = None, agents: Dict[str, Any] = None) -> dict:
25
25
  payload = {
26
26
  "message": user_message,
27
27
  "tools": tools,
28
+ "agents": agents,
28
29
  "chat_id": self.chat_id,
29
30
  "chat_model_name": self.chat_model_name
30
31
  }
@@ -16,6 +16,7 @@ from ibm_watsonx_orchestrate.cli.config import (
16
16
  from threading import Lock
17
17
  from ibm_watsonx_orchestrate.client.base_api_client import BaseAPIClient
18
18
  from ibm_watsonx_orchestrate.utils.utils import yaml_safe_load
19
+ from ibm_watsonx_orchestrate.cli.commands.channels.types import RuntimeEnvironmentType
19
20
  import logging
20
21
  from typing import TypeVar
21
22
  import os
@@ -32,18 +33,6 @@ def get_current_env_url() -> str:
32
33
  active_env = cfg.read(CONTEXT_SECTION_HEADER, CONTEXT_ACTIVE_ENV_OPT)
33
34
  return cfg.get(ENVIRONMENTS_SECTION_HEADER, active_env, ENV_WXO_URL_OPT)
34
35
 
35
- def get_cpd_instance_id_from_url(url: str | None = None) -> str:
36
- if url is None:
37
- url = get_current_env_url()
38
-
39
- if not is_cpd_env(url):
40
- logger.error(f"The host {url} is not a CPD instance")
41
- sys.exit(1)
42
-
43
- url_fragments = url.split('/')
44
- return url_fragments[-1] if url_fragments[-1] else url_fragments[-2]
45
-
46
-
47
36
  def is_local_dev(url: str | None = None) -> bool:
48
37
  if url is None:
49
38
  url = get_current_env_url()
@@ -62,11 +51,11 @@ def is_local_dev(url: str | None = None) -> bool:
62
51
 
63
52
  return False
64
53
 
65
- def is_cpd_env(url: str | None = None) -> bool:
54
+ def is_ga_platform(url: str | None = None) -> bool:
66
55
  if url is None:
67
56
  url = get_current_env_url()
68
57
 
69
- if url.lower().startswith("https://cpd"):
58
+ if url.__contains__("orchestrate.ibm.com"):
70
59
  return True
71
60
  return False
72
61
 
@@ -81,24 +70,37 @@ def is_ibm_cloud_platform(url:str | None = None) -> bool:
81
70
  return True
82
71
  return False
83
72
 
84
- def is_ga_platform(url: str | None = None) -> bool:
73
+ def is_cpd_env(url: str | None = None) -> bool:
85
74
  if url is None:
86
75
  url = get_current_env_url()
87
76
 
88
- if url.__contains__("orchestrate.ibm.com"):
77
+ if url.lower().startswith("https://cpd"):
89
78
  return True
90
79
  return False
91
80
 
81
+ def get_cpd_instance_id_from_url(url: str | None = None) -> str:
82
+ if url is None:
83
+ url = get_current_env_url()
84
+
85
+ if not is_cpd_env(url):
86
+ logger.error(f"The host {url} is not a CPD instance")
87
+ sys.exit(1)
88
+
89
+ url_fragments = url.split('/')
90
+ return url_fragments[-1] if url_fragments[-1] else url_fragments[-2]
91
+
92
+
93
+
92
94
 
93
95
  def get_environment() -> str:
94
96
  if is_local_dev():
95
- return "local"
97
+ return RuntimeEnvironmentType.LOCAL
96
98
  if is_cpd_env():
97
- return "cpd"
99
+ return RuntimeEnvironmentType.CPD
98
100
  if is_ibm_cloud_platform():
99
- return "ibmcloud"
101
+ return RuntimeEnvironmentType.IBM_CLOUD
100
102
  if is_ga_platform():
101
- return "ga"
103
+ return RuntimeEnvironmentType.AWS
102
104
  return None
103
105
 
104
106
  def check_token_validity(token: str) -> bool:
@@ -46,7 +46,7 @@ services:
46
46
  WXO_SERVER_URL: http://wxo-server:4321
47
47
  MAX_POOL: 60
48
48
  DEPLOYMENT_MODE: laptop
49
- SUFFIXLIST: '["global_05d7ba72", "ibm_184bdbd3"]'
49
+ SUFFIXLIST: ${CM_SUFFIXLIST:-[]}
50
50
  ports:
51
51
  - 3001:3001
52
52
 
@@ -302,7 +302,7 @@ services:
302
302
  JWT_PRIVATE_CONFIGS_PATH: "/"
303
303
  VECTOR_STORE_PROVIDER: ${VECTOR_STORE_PROVIDER:-milvus}
304
304
  CELERY_RESULTS_TTL: "3600"
305
- EVENT_BROKER_TTL: "-1"
305
+ EVENT_BROKER_TTL: "3600"
306
306
  TAVILY_API_KEY: ${TAVILY_API_KEY:-dummy_tavily_api_key}
307
307
  WATSONX_APIKEY: ${WATSONX_APIKEY}
308
308
  WATSONX_URL: ${WATSONX_URL}
@@ -404,7 +404,7 @@ services:
404
404
  WXAI_API_KEY: ${WXAI_API_KEY}
405
405
  VECTOR_STORE_PROVIDER: ${VECTOR_STORE_PROVIDER:-milvus}
406
406
  CELERY_RESULTS_TTL: "3600"
407
- EVENT_BROKER_TTL: "-1"
407
+ EVENT_BROKER_TTL: "3600"
408
408
  TAVILY_API_KEY: ${TAVILY_API_KEY:-dummy_tavily_api_key}
409
409
  WATSONX_APIKEY: ${WATSONX_APIKEY}
410
410
  WATSONX_URL: ${WATSONX_URL}
@@ -789,7 +789,7 @@ services:
789
789
  WO_AUTH_TYPE: ${WO_AUTH_TYPE}
790
790
  healthcheck:
791
791
  test: curl -k http://localhost:9044/readiness --fail
792
- interval: 5s
792
+ interval: 30s
793
793
  timeout: 60s
794
794
  retries: 5
795
795
  ports:
@@ -804,6 +804,13 @@ services:
804
804
  environment:
805
805
  WATSONX_APIKEY: ${WATSONX_APIKEY}
806
806
  WATSONX_SPACE_ID: ${WATSONX_SPACE_ID}
807
+ WO_API_KEY: ${WO_API_KEY}
808
+ WO_USERNAME: ${WO_USERNAME}
809
+ WO_PASSWORD: ${WO_PASSWORD}
810
+ WO_INSTANCE: ${WO_INSTANCE}
811
+ USE_SAAS_ML_TOOLS_RUNTIME: ${USE_SAAS_ML_TOOLS_RUNTIME}
812
+ WO_AUTH_TYPE: ${WO_AUTH_TYPE}
813
+ AUTHORIZATION_URL: ${AUTHORIZATION_URL}
807
814
  ports:
808
815
  - 8081:8080
809
816
 
@@ -909,7 +916,7 @@ services:
909
916
  ENRICHMENT_BATCH_SIZE: "1000"
910
917
  CIPHER_AES_REALM_KEY: "dGVzdHRlc3R0ZXN0dGVzdA=="
911
918
  SIDECAR_METERED_ENABLED: "false"
912
- DPI_DEBUG: true
919
+ DPI_DEBUG: "false"
913
920
  DPI_WO_WDU_SERVER_ENDPOINT: https://wxo-doc-processing-service:8080
914
921
  # DPI_RAG_SERVER_ENDPOINT: https://wxo-doc-processing-llm-service:8083
915
922
  DISABLE_TLS: true
@@ -42,7 +42,7 @@ LANGFUSE_PRIVATE_KEY=sk-lf-7bc4da63-7b2b-40c0-b5eb-1e0cf64f9af2
42
42
 
43
43
  CELERY_WORKER_CONCURRENCY=12
44
44
  CELERY_RESULTS_TTL="3600"
45
- EVENT_BROKER_TTL="-1"
45
+ EVENT_BROKER_TTL="3600"
46
46
 
47
47
  # START -- IMAGE REGISTRIES AND TAGS
48
48
  # The registry URL to pull the private images from, including the name of the repository in the registry.
@@ -53,13 +53,13 @@ EVENT_BROKER_TTL="-1"
53
53
  REGISTRY_URL=
54
54
 
55
55
 
56
- SERVER_TAG=02-07-2025
56
+ SERVER_TAG=22-07-2025
57
57
  SERVER_REGISTRY=
58
58
 
59
- WORKER_TAG=02-07-2025
59
+ WORKER_TAG=22-07-2025
60
60
  WORKER_REGISTRY=
61
61
 
62
- AI_GATEWAY_TAG=01-07-2025
62
+ AI_GATEWAY_TAG=21-07-2025
63
63
  AI_GATEWAY_REGISTRY=
64
64
 
65
65
  AGENT_GATEWAY_TAG=07-07-2025
@@ -68,26 +68,26 @@ AGENT_GATEWAY_REGISTRY=
68
68
  DB_REGISTRY=
69
69
  # If you build multiarch set all three of these to the same, we have a pr against main
70
70
  # to not have this separation, but we can merge it later
71
- DBTAG=24-06-2025-v1
71
+ DBTAG=22-07-2025
72
72
  AMDDBTAG=24-06-2025-v1
73
73
  ARM64DBTAG=24-06-2025-v1
74
74
 
75
75
  UI_REGISTRY=
76
- UITAG=27-06-2025
76
+ UITAG=23-07-2025
77
77
 
78
78
  CM_REGISTRY=
79
- CM_TAG=27-06-2025
79
+ CM_TAG=24-07-2025
80
80
 
81
- TRM_TAG=08-07-2025
81
+ TRM_TAG=23-07-2025-3c60549f0bac275de3e5736265a3fd49cdd3a203
82
82
  TRM_REGISTRY=
83
83
 
84
- TR_TAG=08-07-2025
84
+ TR_TAG=23-07-2025-3c60549f0bac275de3e5736265a3fd49cdd3a203
85
85
  TR_REGISTRY=
86
86
 
87
- BUILDER_TAG=02-07-2025
87
+ BUILDER_TAG=22-07-2025-v1
88
88
  BUILDER_REGISTRY=
89
89
 
90
- FLOW_RUNTIME_TAG=10-07-2025
90
+ FLOW_RUNTIME_TAG=15-07-2025
91
91
  FLOW_RUMTIME_REGISTRY=
92
92
 
93
93
 
@@ -100,7 +100,7 @@ JAEGER_PROXY_REGISTRY=
100
100
  SOCKET_HANDLER_TAG=29-05-2025
101
101
  SOCKET_HANDLER_REGISTRY=
102
102
 
103
- CPE_TAG=08-07-2025
103
+ CPE_TAG=17-07-2025
104
104
  CPE_REGISTRY=
105
105
 
106
106
  # IBM Document Processing
@@ -171,6 +171,7 @@ WO_INSTANCE=
171
171
  AUTHORIZATION_URL=
172
172
  WO_AUTH_TYPE=
173
173
  PYTHONPATH=
174
+ CM_SUFFIXLIST=
174
175
 
175
176
  # Use your machine's local IP address for external async tool communication.
176
177
  CALLBACK_HOST_URL=
@@ -1,6 +1,8 @@
1
1
  from .constants import START, END, RESERVED
2
- from ..types import FlowContext, TaskData, TaskEventType, DocumentContent
3
- from ..node import UserNode, AgentNode, StartNode, EndNode, PromptNode, ToolNode
2
+
3
+ from ..types import FlowContext, TaskData, TaskEventType, File, DecisionsCondition, DecisionsRule
4
+ from ..node import UserNode, AgentNode, StartNode, EndNode, PromptNode, ToolNode, DecisionsNode
5
+
4
6
  from .flow import Flow, CompiledFlow, FlowRun, FlowEvent, FlowEventType, FlowFactory, MatchPolicy, WaitPolicy, ForeachPolicy, Branch, Foreach, Loop
5
7
  from .decorators import flow
6
8
  from ..data_map import Assignment, DataMap
@@ -14,7 +16,7 @@ __all__ = [
14
16
  "FlowContext",
15
17
  "TaskData",
16
18
  "TaskEventType",
17
- "DocumentContent",
19
+ "File",
18
20
 
19
21
  "DocProcNode",
20
22
  "UserNode",
@@ -23,6 +25,7 @@ __all__ = [
23
25
  "EndNode",
24
26
  "PromptNode",
25
27
  "ToolNode",
28
+ "DecisionsNode",
26
29
  "Assignment",
27
30
  "DataMap",
28
31
 
@@ -38,8 +41,8 @@ __all__ = [
38
41
  "Branch",
39
42
  "Foreach",
40
43
  "Loop",
44
+ "DecisionsCondition",
45
+ "DecisionsRule",
41
46
 
42
- "user",
43
- "flow_spec",
44
47
  "flow"
45
48
  ]
@@ -27,12 +27,11 @@ from ibm_watsonx_orchestrate.client.utils import instantiate_client
27
27
  from ..types import (
28
28
  EndNodeSpec, Expression, ForeachPolicy, ForeachSpec, LoopSpec, BranchNodeSpec, MatchPolicy, PromptLLMParameters, PromptNodeSpec,
29
29
  StartNodeSpec, ToolSpec, JsonSchemaObject, ToolRequestBody, ToolResponseBody, UserFieldKind, UserFieldOption, UserFlowSpec, UserNodeSpec, WaitPolicy,
30
- DocProcSpec, TextExtractionResponse, KVPInvoicesExtractionResponse, KVPUtilityBillsExtractionResponse,
31
- DocumentContent
30
+ DocProcSpec, TextExtractionResponse, File, DecisionsNodeSpec, DecisionsRule
32
31
  )
33
32
  from .constants import CURRENT_USER, START, END, ANY_USER
34
33
  from ..node import (
35
- EndNode, Node, PromptNode, StartNode, UserNode, AgentNode, DataMap, ToolNode, DocProcNode
34
+ EndNode, Node, PromptNode, StartNode, UserNode, AgentNode, DataMap, ToolNode, DocProcNode, DecisionsNode
36
35
  )
37
36
  from ..types import (
38
37
  AgentNodeSpec, extract_node_spec, FlowContext, FlowEventType, FlowEvent, FlowSpec,
@@ -434,6 +433,49 @@ class Flow(Node):
434
433
  node = self._add_node(node)
435
434
  return cast(PromptNode, node)
436
435
 
436
+ def decisions(self,
437
+ name: str,
438
+ display_name: str|None=None,
439
+ rules: list[DecisionsRule] | None = None,
440
+ default_actions: dict[str, Any] = None,
441
+ locale: str | None = None,
442
+ description: str | None = None,
443
+ input_schema: type[BaseModel]|None = None,
444
+ output_schema: type[BaseModel]|None=None,
445
+ input_map: DataMap = None) -> PromptNode:
446
+
447
+ if name is None:
448
+ raise ValueError("name must be provided.")
449
+
450
+ if rules is None:
451
+ raise ValueError("rules must be specified.")
452
+
453
+ # create input spec
454
+ input_schema_obj = _get_json_schema_obj(parameter_name = "input", type_def = input_schema)
455
+ output_schema_obj = _get_json_schema_obj("output", output_schema)
456
+
457
+ # Create the tool spec
458
+ task_spec = DecisionsNodeSpec(
459
+ name=name,
460
+ display_name=display_name if display_name is not None else name,
461
+ description=description,
462
+ rules=rules,
463
+ default_actions=default_actions,
464
+ locale=locale,
465
+ input_schema=_get_tool_request_body(input_schema_obj),
466
+ output_schema=_get_tool_response_body(output_schema_obj),
467
+ output_schema_object = output_schema_obj
468
+ )
469
+
470
+ node = DecisionsNode(spec=task_spec)
471
+ # setup input map
472
+ if input_map:
473
+ node.input_map = self._get_data_map(input_map)
474
+
475
+ # add the node to the list of node
476
+ node = self._add_node(node)
477
+ return cast(DecisionsNode, node)
478
+
437
479
  def docproc(self,
438
480
  name: str,
439
481
  task: str,
@@ -448,12 +490,10 @@ class Flow(Node):
448
490
  raise ValueError("task must be provided.")
449
491
 
450
492
  output_schema_dict = {
451
- "text_extraction" : TextExtractionResponse,
452
- "kvp_invoices_extraction" : KVPInvoicesExtractionResponse,
453
- "kvp_utility_bills_extraction" : KVPUtilityBillsExtractionResponse
493
+ "text_extraction" : TextExtractionResponse
454
494
  }
455
495
  # create input spec
456
- input_schema_obj = _get_json_schema_obj(parameter_name = "input", type_def = DocumentContent)
496
+ input_schema_obj = _get_json_schema_obj(parameter_name = "input", type_def = File)
457
497
  output_schema_obj = _get_json_schema_obj("output", output_schema_dict[task])
458
498
  if "$defs" in output_schema_obj.model_extra:
459
499
  output_schema_obj.model_extra.pop("$defs")
@@ -5,7 +5,7 @@ import uuid
5
5
  import yaml
6
6
  from pydantic import BaseModel, Field, SerializeAsAny
7
7
 
8
- from .types import EndNodeSpec, NodeSpec, AgentNodeSpec, PromptNodeSpec, StartNodeSpec, ToolNodeSpec, UserFieldKind, UserFieldOption, UserNodeSpec, DocProcSpec
8
+ from .types import EndNodeSpec, NodeSpec, AgentNodeSpec, PromptNodeSpec, StartNodeSpec, ToolNodeSpec, UserFieldKind, UserFieldOption, UserNodeSpec, DocProcSpec, DecisionsNodeSpec
9
9
  from .data_map import DataMap
10
10
 
11
11
  class Node(BaseModel):
@@ -116,7 +116,13 @@ class DocProcNode(Node):
116
116
 
117
117
  def get_spec(self) -> DocProcSpec:
118
118
  return cast(DocProcSpec, self.spec)
119
+ class DecisionsNode(Node):
120
+ def __repr__(self):
121
+ return f"DecisionsNode(name='{self.spec.name}', description='{self.spec.description}')"
119
122
 
123
+ def get_spec(self) -> DecisionsNodeSpec:
124
+ return cast(DecisionsNodeSpec, self.spec)
125
+
120
126
  class NodeInstance(BaseModel):
121
127
  node: Node
122
128
  id: str # unique id of this task instance