ibm-watsonx-orchestrate 1.3.0__py3-none-any.whl → 1.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. ibm_watsonx_orchestrate/__init__.py +1 -1
  2. ibm_watsonx_orchestrate/agent_builder/agents/types.py +2 -0
  3. ibm_watsonx_orchestrate/agent_builder/knowledge_bases/types.py +9 -2
  4. ibm_watsonx_orchestrate/agent_builder/toolkits/base_toolkit.py +32 -0
  5. ibm_watsonx_orchestrate/agent_builder/toolkits/types.py +42 -0
  6. ibm_watsonx_orchestrate/agent_builder/tools/openapi_tool.py +10 -1
  7. ibm_watsonx_orchestrate/agent_builder/tools/python_tool.py +4 -2
  8. ibm_watsonx_orchestrate/agent_builder/tools/types.py +2 -1
  9. ibm_watsonx_orchestrate/cli/commands/agents/agents_command.py +29 -0
  10. ibm_watsonx_orchestrate/cli/commands/agents/agents_controller.py +271 -12
  11. ibm_watsonx_orchestrate/cli/commands/knowledge_bases/knowledge_bases_controller.py +17 -2
  12. ibm_watsonx_orchestrate/cli/commands/models/env_file_model_provider_mapper.py +180 -0
  13. ibm_watsonx_orchestrate/cli/commands/models/models_command.py +194 -8
  14. ibm_watsonx_orchestrate/cli/commands/server/server_command.py +117 -48
  15. ibm_watsonx_orchestrate/cli/commands/server/types.py +105 -0
  16. ibm_watsonx_orchestrate/cli/commands/toolkit/toolkit_command.py +55 -7
  17. ibm_watsonx_orchestrate/cli/commands/toolkit/toolkit_controller.py +123 -42
  18. ibm_watsonx_orchestrate/cli/commands/tools/tools_command.py +22 -1
  19. ibm_watsonx_orchestrate/cli/commands/tools/tools_controller.py +197 -12
  20. ibm_watsonx_orchestrate/client/agents/agent_client.py +4 -1
  21. ibm_watsonx_orchestrate/client/agents/assistant_agent_client.py +5 -1
  22. ibm_watsonx_orchestrate/client/agents/external_agent_client.py +5 -1
  23. ibm_watsonx_orchestrate/client/analytics/llm/analytics_llm_client.py +2 -6
  24. ibm_watsonx_orchestrate/client/base_api_client.py +5 -2
  25. ibm_watsonx_orchestrate/client/connections/connections_client.py +3 -9
  26. ibm_watsonx_orchestrate/client/model_policies/__init__.py +0 -0
  27. ibm_watsonx_orchestrate/client/model_policies/model_policies_client.py +47 -0
  28. ibm_watsonx_orchestrate/client/model_policies/types.py +36 -0
  29. ibm_watsonx_orchestrate/client/models/__init__.py +0 -0
  30. ibm_watsonx_orchestrate/client/models/models_client.py +46 -0
  31. ibm_watsonx_orchestrate/client/models/types.py +177 -0
  32. ibm_watsonx_orchestrate/client/toolkit/toolkit_client.py +15 -6
  33. ibm_watsonx_orchestrate/client/tools/tempus_client.py +40 -0
  34. ibm_watsonx_orchestrate/client/tools/tool_client.py +8 -0
  35. ibm_watsonx_orchestrate/docker/compose-lite.yml +68 -13
  36. ibm_watsonx_orchestrate/docker/default.env +22 -12
  37. ibm_watsonx_orchestrate/docker/tempus/common-config.yaml +1 -1
  38. ibm_watsonx_orchestrate/experimental/flow_builder/__init__.py +0 -0
  39. ibm_watsonx_orchestrate/experimental/flow_builder/flows/__init__.py +41 -0
  40. ibm_watsonx_orchestrate/experimental/flow_builder/flows/constants.py +17 -0
  41. ibm_watsonx_orchestrate/experimental/flow_builder/flows/data_map.py +91 -0
  42. ibm_watsonx_orchestrate/experimental/flow_builder/flows/decorators.py +143 -0
  43. ibm_watsonx_orchestrate/experimental/flow_builder/flows/events.py +72 -0
  44. ibm_watsonx_orchestrate/experimental/flow_builder/flows/flow.py +1288 -0
  45. ibm_watsonx_orchestrate/experimental/flow_builder/node.py +97 -0
  46. ibm_watsonx_orchestrate/experimental/flow_builder/resources/flow_status.openapi.yml +98 -0
  47. ibm_watsonx_orchestrate/experimental/flow_builder/types.py +492 -0
  48. ibm_watsonx_orchestrate/experimental/flow_builder/utils.py +113 -0
  49. ibm_watsonx_orchestrate/utils/utils.py +5 -2
  50. {ibm_watsonx_orchestrate-1.3.0.dist-info → ibm_watsonx_orchestrate-1.4.2.dist-info}/METADATA +4 -1
  51. {ibm_watsonx_orchestrate-1.3.0.dist-info → ibm_watsonx_orchestrate-1.4.2.dist-info}/RECORD +54 -32
  52. {ibm_watsonx_orchestrate-1.3.0.dist-info → ibm_watsonx_orchestrate-1.4.2.dist-info}/WHEEL +0 -0
  53. {ibm_watsonx_orchestrate-1.3.0.dist-info → ibm_watsonx_orchestrate-1.4.2.dist-info}/entry_points.txt +0 -0
  54. {ibm_watsonx_orchestrate-1.3.0.dist-info → ibm_watsonx_orchestrate-1.4.2.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1288 @@
1
+ """
2
+ The Flow model. There are multiple methods to allow creation and population of
3
+ the Flow model.
4
+ """
5
+
6
+ import asyncio
7
+ from datetime import datetime
8
+ from enum import Enum
9
+ from typing import (
10
+ Any, AsyncIterator, Callable, cast, List, Sequence, Union, Tuple
11
+ )
12
+ import json
13
+ import logging
14
+ import time
15
+ import copy
16
+ import uuid
17
+ import pytz
18
+
19
+ from typing_extensions import Self
20
+ from pydantic import BaseModel, Field, PrivateAttr, SerializeAsAny
21
+ import yaml
22
+ from munch import Munch
23
+ from ibm_watsonx_orchestrate.agent_builder.tools.python_tool import PythonTool
24
+ from ibm_watsonx_orchestrate.client.tools.tempus_client import TempusClient
25
+ from ibm_watsonx_orchestrate.client.utils import instantiate_client
26
+ from ..types import (
27
+ EndNodeSpec, Expression, ForeachPolicy, ForeachSpec, LoopSpec, BranchNodeSpec, MatchPolicy,
28
+ StartNodeSpec, ToolSpec, JsonSchemaObject, ToolRequestBody, ToolResponseBody, WaitPolicy
29
+ )
30
+ from .constants import START, END, ANY_USER
31
+ from ..node import (
32
+ EndNode, Node, StartNode, UserNode, AgentNode, DataMap, ToolNode
33
+ )
34
+ from ..types import (
35
+ AgentNodeSpec, extract_node_spec, FlowContext, FlowEventType, FlowEvent, FlowSpec,
36
+ NodeSpec, TaskEventType, ToolNodeSpec, SchemaRef, JsonSchemaObjectRef, _to_json_from_json_schema
37
+ )
38
+
39
+ from .data_map import Assignment, AssignmentDataMap, AssignmentDataMapSpec
40
+ from ..utils import _get_json_schema_obj, get_valid_name, import_flow_model
41
+
42
+ from .events import StreamConsumer
43
+
44
+ logger = logging.getLogger(__name__)
45
+
46
+ # Mapping each event to its type
47
+ EVENT_TYPE_MAP = {
48
+ FlowEventType.ON_FLOW_START: "informational",
49
+ FlowEventType.ON_FLOW_END: "informational",
50
+ FlowEventType.ON_FLOW_ERROR: "interrupting",
51
+ TaskEventType.ON_TASK_WAIT: "interrupting",
52
+ TaskEventType.ON_TASK_START: "informational",
53
+ TaskEventType.ON_TASK_END: "informational",
54
+ TaskEventType.ON_TASK_STREAM: "interrupting",
55
+ TaskEventType.ON_TASK_ERROR: "interrupting",
56
+ }
57
+
58
class FlowEdge(BaseModel):
    '''Used to designate the edge of a flow.

    A directed connection between two nodes, identified by their spec
    names (or the reserved START/END marker names).
    '''
    # name of the source node
    start: str
    # name of the target node
    end: str
62
+
63
+
64
+ class Flow(Node):
65
+ '''Flow represents a flow that will be run by wxO Flow engine.'''
66
+ output_map: DataMap | None = None
67
+ nodes: dict[str, SerializeAsAny[Node]] = {}
68
+ edges: List[FlowEdge] = []
69
+ schemas: dict[str, JsonSchemaObject] = {}
70
+ compiled: bool = False
71
+ validated: bool = False
72
+ metadata: dict[str, str] = {}
73
+ parent: Any = None
74
+
75
    def __init__(self, **kwargs):
        '''Initialize the flow and immediately hoist this flow's own inline
        schemas into shared schema references (see _refactor_node_to_schemaref).'''
        super().__init__(**kwargs)

        # extract data schemas
        self._refactor_node_to_schemaref(self)
80
+
81
+ def _find_topmost_flow(self) -> Self:
82
+ if self.parent:
83
+ return self.parent._find_topmost_flow()
84
+ return self
85
+
86
    def _add_schema(self, schema: JsonSchemaObject, title: str = None) -> JsonSchemaObject:
        '''
        Adds a schema to the dictionary of schemas. If a schema with the same name already exists, it returns the existing schema. Otherwise, it creates a deep copy of the schema, adds it to the dictionary, and returns the new schema.

        Parameters:
            schema (JsonSchemaObject): The schema to be added.
            title (str, optional): The title of the schema. If not provided, it will be generated based on the schema's title or aliasName.

        Returns:
            JsonSchemaObject: The added or existing schema. Non-complex
            schemas (neither object nor array) are returned unregistered;
            a falsy schema yields None.
        '''

        # find the top most flow and add the schema to that scope
        top_flow = self._find_topmost_flow()

        # if there is already a schema with the same name, return it
        if title:
            if title in top_flow.schemas:
                return top_flow.schemas[title]

        # otherwise, create a deep copy of the schema, add it to the dictionary and return it
        if schema:
            if isinstance(schema, dict):
                # recast schema to support direct access
                schema = Munch(schema)
            # we should only add schema when it is a complex object
            if schema.type != "object" and schema.type != "array":
                return schema

            new_schema = copy.deepcopy(schema)
            if not title:
                if schema.title:
                    title = get_valid_name(schema.title)
                elif schema.aliasName:
                    title = get_valid_name(schema.aliasName)
                else:
                    # anonymous schema: mint a unique "business object" name
                    title = "bo_" + uuid.uuid4().hex

            if new_schema.type == "object":
                # iterate the properties and add schema recursively, replacing
                # nested complex schemas with $ref-style JsonSchemaObjectRef entries
                if new_schema.properties is not None:
                    for key, value in new_schema.properties.items():
                        if isinstance(value, JsonSchemaObject):
                            if value.type == "object":
                                schema_ref = self._add_schema_ref(value, value.title)
                                new_schema.properties[key] = JsonSchemaObjectRef(title=value.title,
                                                                                 ref = f"{schema_ref.ref}")
                            elif value.type == "array" and (value.items.type == "object" or value.items.type == "array"):
                                schema_ref = self._add_schema_ref(value.items, value.items.title)
                                new_schema.properties[key].items = JsonSchemaObjectRef(title=value.title,
                                                                                       ref = f"{schema_ref.ref}")
                            elif value.model_extra and value.model_extra["$ref"]:
                                # NOTE(review): direct ["$ref"] access raises KeyError when
                                # model_extra is non-empty but has no "$ref" — confirm intended.
                                # there is already a reference, remove $/defs/ from the initial ref
                                # ("#/$defs/" is 8 characters, hence the [8:] slice)
                                ref_value = value.model_extra["$ref"]
                                schema_ref = f"#/schemas/{ref_value[8:]}"
                                new_schema.properties[key] = JsonSchemaObjectRef(ref = f"{schema_ref}")

            elif new_schema.type == "array":
                # arrays of complex items: register the item schema and reference it
                if new_schema.items.type == "object" or new_schema.items.type == "array":
                    schema_ref = self._add_schema_ref(new_schema.items, new_schema.items.title)
                    new_schema.items = JsonSchemaObjectRef(title=new_schema.items.title,
                                                           ref= f"{schema_ref.ref}")

            # we also need to unpack local references ($defs produced by pydantic)
            if hasattr(new_schema, "model_extra") and "$defs" in new_schema.model_extra:
                for schema_name, schema_def in new_schema.model_extra["$defs"].items():
                    self._add_schema(schema_def, schema_name)

            # set the title and register on the topmost flow
            new_schema.title = title
            top_flow.schemas[title] = new_schema

            return new_schema
        return None
160
+
161
+ def _add_schema_ref(self, schema: JsonSchemaObject, title: str = None) -> SchemaRef:
162
+ '''Create a schema reference'''
163
+ if schema and (schema.type == "object" or schema.type == "array"):
164
+ new_schema = self._add_schema(schema, title)
165
+ return SchemaRef(ref=f"#/schemas/{new_schema.title}")
166
+ raise AssertionError(f"schema is not a complex object: {schema}")
167
+
168
    def _refactor_node_to_schemaref(self, node: Node):
        # Rewrite the node's spec so inline input/output schemas become
        # shared schema references (delegates to _refactor_spec_to_schemaref).
        self._refactor_spec_to_schemaref(node.spec)
170
+
171
    def _refactor_spec_to_schemaref(self, spec: NodeSpec):
        '''Replace a spec's inline schemas with references into the flow's
        shared schema table.

        Input: a ToolRequestBody input_schema becomes a "<name>_input" ref.
        Output: if output_schema_object is a complex object it wins and is
        registered under its own title (and then cleared); otherwise an
        inline ToolResponseBody is registered as "<name>_output" (objects)
        or has its item schema registered (arrays of objects).
        Mutates *spec* in place.
        '''
        if spec.input_schema:
            if isinstance(spec.input_schema, ToolRequestBody):
                spec.input_schema = self._add_schema_ref(JsonSchemaObject(type = spec.input_schema.type,
                                                                          properties= spec.input_schema.properties,
                                                                          required= spec.input_schema.required),
                                                         f"{spec.name}_input")
        if spec.output_schema_object is not None and spec.output_schema_object.type == "object":
            spec.output_schema = self._add_schema_ref(spec.output_schema_object, spec.output_schema_object.title)
            # the object form is consumed once converted to a reference
            spec.output_schema_object = None
        elif spec.output_schema is not None:
            if isinstance(spec.output_schema, ToolResponseBody):
                if spec.output_schema.type == "object":
                    json_obj = JsonSchemaObject(type = spec.output_schema.type,
                                                description=spec.output_schema.description,
                                                properties= spec.output_schema.properties,
                                                items = spec.output_schema.items,
                                                uniqueItems=spec.output_schema.uniqueItems,
                                                anyOf=spec.output_schema.anyOf,
                                                required= spec.output_schema.required)
                    spec.output_schema = self._add_schema_ref(json_obj, f"{spec.name}_output")
                elif spec.output_schema.type == "array":
                    # only the item schema is hoisted; the array wrapper stays inline
                    if spec.output_schema.items.type == "object":
                        schema_ref = self._add_schema_ref(spec.output_schema.items)
                        spec.output_schema.items = JsonSchemaObjectRef(ref=f"{schema_ref.ref}")
196
+
197
+ # def refactor_datamap_spec_to_schemaref(self, spec: FnDataMapSpec):
198
+ # '''TODO'''
199
+ # if spec.input_schema:
200
+ # if isinstance(spec.input_schema, ToolRequestBody):
201
+ # spec.input_schema = self._add_schema_ref(JsonSchemaObject(type = spec.input_schema.type,
202
+ # properties= spec.input_schema.properties,
203
+ # required= spec.input_schema.required),
204
+ # f"{spec.name}_input")
205
+ # if spec.output_schema_object is not None and spec.output_schema_object.type == "object":
206
+ # spec.output_schema = self._add_schema_ref(spec.output_schema_object, spec.output_schema_object.title)
207
+ # spec.output_schema_object = None
208
+ # elif spec.output_schema is not None:
209
+ # if isinstance(spec.output_schema, ToolResponseBody):
210
+ # spec.output_schema = self._add_schema_ref(JsonSchemaObject(type = spec.output_schema.type,
211
+ # Sdescription=spec.output_schema.description,
212
+ # properties= spec.output_schema.properties,
213
+ # items = spec.output_schema.items,
214
+ # uniqueItems=spec.output_schema.uniqueItems,
215
+ # anyOf=spec.output_schema.anyOf,
216
+ # required= spec.output_schema.required),
217
+ # f"{spec.name}_output")
218
+
219
+ def _create_node_from_tool_fn(
220
+ self,
221
+ tool: Callable
222
+ ) -> ToolNode:
223
+ if not isinstance(tool, Callable):
224
+ raise ValueError("Only functions with @tool decorator can be added.")
225
+
226
+ spec = getattr(tool, "__tool_spec__", None)
227
+ if not spec:
228
+ raise ValueError("Only functions with @tool decorator can be added.")
229
+
230
+ self._check_compiled()
231
+
232
+ tool_spec = cast(ToolSpec, spec)
233
+
234
+ # we need more information from the function signature
235
+ spec = extract_node_spec(tool)
236
+
237
+ toolnode_spec = ToolNodeSpec(type = "tool",
238
+ name = tool_spec.name,
239
+ display_name = tool_spec.name,
240
+ description = tool_spec.description,
241
+ input_schema = tool_spec.input_schema,
242
+ output_schema = tool_spec.output_schema,
243
+ output_schema_object = spec.output_schema_object,
244
+ tool = tool_spec.name)
245
+
246
+ return ToolNode(spec=toolnode_spec)
247
+
248
+ def tool(
249
+ self,
250
+ tool: Callable | str | None = None,
251
+ name: str | None = None,
252
+ display_name: str | None = None,
253
+ description: str | None = None,
254
+
255
+ input_schema: type[BaseModel] | None = None,
256
+ output_schema: type[BaseModel] | None = None,
257
+ input_map: List[Assignment] = None
258
+ ) -> Node:
259
+ '''create a tool node in the flow'''
260
+ if tool is None:
261
+ raise ValueError("tool must be provided")
262
+
263
+ if isinstance(tool, str):
264
+ return self._node(
265
+ name=name if name is not None and name != "" else tool,
266
+ tool=tool,
267
+ display_name=display_name,
268
+ description=description,
269
+ input_schema=input_schema,
270
+ output_schema=output_schema,
271
+ input_map=input_map)
272
+ elif isinstance(tool, PythonTool):
273
+ return self._node(
274
+ node=tool,
275
+ name=name if name is not None and name != "" else tool.fn.__name__,
276
+ display_name=display_name,
277
+ description=description,
278
+ input_schema=input_schema,
279
+ output_schema=output_schema,
280
+ input_map=input_map)
281
+ else:
282
+ raise ValueError(f"tool is not a string or a callable: {tool}")
283
+
284
+ def agent(
285
+ self,
286
+ name: str | None = None,
287
+ display_name: str | None = None,
288
+ description: str | None = None,
289
+ agent: str | None = None,
290
+ message: str | None = None,
291
+ guidelines: str | None = None,
292
+ input_schema: type[BaseModel] | None = None,
293
+ output_schema: type[BaseModel] | None = None,
294
+ input_map: List[Assignment] = None
295
+ ) -> Node:
296
+ '''create an agent node in the flow'''
297
+ return self._node(
298
+ name=name,
299
+ display_name=display_name,
300
+ description=description,
301
+ agent=agent,
302
+ message=message,
303
+ guidelines=guidelines,
304
+ input_schema=input_schema,
305
+ output_schema=output_schema,
306
+ input_map=input_map
307
+ )
308
+
309
+ def _node(
310
+ self,
311
+ node: Union[Node, Callable] = None,
312
+ name: str = None,
313
+ display_name: str | None = None,
314
+ description: str | None = None,
315
+ owners: Sequence[str] | None = None,
316
+ input_schema: type[BaseModel] | None = None,
317
+ output_schema: type[BaseModel] | None = None,
318
+ agent: str | None = None,
319
+ tool: str | None = None,
320
+ message: str | None = None,
321
+ guidelines: str | None = None,
322
+ input_map: Callable | List[Assignment] = None,
323
+ output_map: Callable | List[Assignment] = None,
324
+ ) -> Node:
325
+
326
+ self._check_compiled()
327
+
328
+ if owners is None:
329
+ owners = []
330
+
331
+ if node is not None:
332
+ if not isinstance(node, Node):
333
+ if callable(node):
334
+ user_spec = getattr(node, "__user_spec__", None)
335
+ # script_spec = getattr(node, "__script_spec__", None)
336
+ tool_spec = getattr(node, "__tool_spec__", None)
337
+ if user_spec:
338
+ node = UserNode(spec = user_spec)
339
+ # elif script_spec:
340
+ # node = ScriptNode(spec = script_spec)
341
+ elif tool_spec:
342
+ node = self._create_node_from_tool_fn(node)
343
+ else:
344
+ raise ValueError(
345
+ "Only functions with @user, @tool or @script decorator can be added.")
346
+ elif isinstance(node, Node):
347
+ if node.spec.name in self.nodes:
348
+ raise ValueError(f"Node `{id}` already present.")
349
+
350
+ if node.spec.name == END or node.spec.name == START:
351
+ raise ValueError(f"Node `{id}` is reserved.")
352
+ else:
353
+ raise ValueError(
354
+ "A valid node or function must be specified for the node parameter.")
355
+
356
+ # setup input and output map
357
+ if input_map:
358
+ node.input_map = self._get_data_map(input_map)
359
+ if output_map:
360
+ node.output_map = self._get_data_map(output_map)
361
+
362
+ # add the node to the list of node
363
+ node = self._add_node(node)
364
+ return node
365
+
366
+ if name is not None:
367
+ if agent is not None:
368
+ node = self._create_agent_node(
369
+ name, agent, display_name, message, description, input_schema, output_schema, guidelines)
370
+ elif tool is not None:
371
+ node = self._create_tool_node(
372
+ name, tool, display_name, description, input_schema, output_schema)
373
+ else:
374
+ node = self._create_user_node(
375
+ name, display_name, description, owners, input_schema, output_schema)
376
+
377
+ # setup input and output map
378
+ if input_map:
379
+ node.input_map = self._get_data_map(input_map)
380
+ if output_map:
381
+ node.output_map = self._get_data_map(output_map)
382
+
383
+ # add the node to the list of node
384
+ node = self._add_node(node)
385
+ return node
386
+
387
+ raise ValueError("Either a node or a name must be specified.")
388
+
389
+ def _add_node(self, node: Node) -> Node:
390
+ # make a copy
391
+ new_node = copy.copy(node)
392
+
393
+ self._refactor_node_to_schemaref(new_node)
394
+
395
+ self.nodes[node.spec.name] = new_node
396
+ return new_node
397
+
398
+
399
    def _create_tool_node(self, name: str, tool: str,
                          display_name: str|None=None,
                          description: str|None=None,
                          input_schema: type[BaseModel]|None=None,
                          output_schema: type[BaseModel]|None=None) -> Node:
        '''Build a ToolNode whose spec references a tool by name.

        The optional pydantic models are converted to JSON-schema objects;
        request/response bodies are only attached when the corresponding
        model was supplied (the None guards below).
        '''
        # create input spec
        input_schema_obj = _get_json_schema_obj(parameter_name = "input", type_def = input_schema)
        output_schema_obj = _get_json_schema_obj("output", output_schema)

        toolnode_spec = ToolNodeSpec(type = "tool",
                                     name = name,
                                     display_name = display_name,
                                     description = description,
                                     input_schema=ToolRequestBody(
                                         type=input_schema_obj.type,
                                         properties=input_schema_obj.properties,
                                         required=input_schema_obj.required,
                                     ) if input_schema is not None else None,
                                     output_schema=ToolResponseBody(
                                         type=output_schema_obj.type,
                                         properties=output_schema_obj.properties,
                                         required=output_schema_obj.required
                                     ) if output_schema is not None else None,
                                     output_schema_object = output_schema_obj,
                                     tool = tool)

        return ToolNode(spec=toolnode_spec)
427
+
428
+ def _create_user_node(self, name: str,
429
+ display_name: str|None=None,
430
+ description: str|None=None,
431
+ owners: Sequence[str]|None=[ANY_USER],
432
+ input_schema: type[BaseModel]|None=None,
433
+ output_schema: type[BaseModel]|None=None) -> Node:
434
+ # create input spec
435
+ input_schema_obj = _get_json_schema_obj(parameter_name = "input", type_def = input_schema)
436
+ output_schema_obj = _get_json_schema_obj("output", output_schema)
437
+
438
+ # identify owner
439
+ if not owners:
440
+ owners = [ANY_USER]
441
+
442
+ # Create the tool spec
443
+ task_spec = NodeSpec(
444
+ name=name,
445
+ display_name=display_name,
446
+ description=description,
447
+ owners=owners,
448
+ input_schema=ToolRequestBody(
449
+ type=input_schema_obj.type,
450
+ properties=input_schema_obj.properties,
451
+ required=input_schema_obj.required,
452
+ ),
453
+ output_schema=ToolResponseBody(
454
+ type=output_schema_obj.type,
455
+ properties=output_schema_obj.properties,
456
+ required=output_schema_obj.required
457
+ ),
458
+ tool=[],
459
+ output_schema_object = output_schema_obj
460
+ )
461
+
462
+ return UserNode(spec = task_spec)
463
+
464
    def _create_agent_node(self, name: str, agent: str, display_name: str|None=None,
                           message: str | None = "Follow the agent instructions.",
                           description: str | None = None,
                           input_schema: type[BaseModel]|None = None,
                           output_schema: type[BaseModel]|None=None,
                           guidelines: str|None=None) -> Node:
        '''Build an AgentNode spec that delegates to the named agent.

        NOTE(review): unlike _create_tool_node there is no None guard on the
        schema objects, so omitting input_schema/output_schema raises an
        AttributeError below — confirm callers always supply both.
        '''
        # create input/output schema objects from the pydantic models
        input_schema_obj = _get_json_schema_obj(parameter_name = "input", type_def = input_schema)
        output_schema_obj = _get_json_schema_obj("output", output_schema)

        # Create the agent task spec
        task_spec = AgentNodeSpec(
            name=name,
            display_name=display_name,
            description=description,
            agent=agent,
            message=message,
            guidelines=guidelines,
            input_schema=ToolRequestBody(
                type=input_schema_obj.type,
                properties=input_schema_obj.properties,
                required=input_schema_obj.required,
            ),
            output_schema=ToolResponseBody(
                type=output_schema_obj.type,
                properties=output_schema_obj.properties,
                required=output_schema_obj.required
            ),
            output_schema_object = output_schema_obj
        )

        return AgentNode(spec=task_spec)
497
+
498
+ def node_exists(self, node: Union[str, Node]):
499
+
500
+ if isinstance(node, Node):
501
+ node_id = node.spec.name
502
+ else:
503
+ node_id = node
504
+
505
+ if (node_id == END or node_id == START):
506
+ return True
507
+ if node_id in self.nodes:
508
+ return True
509
+ return False
510
+
511
+ def edge(self,
512
+ start_task: Union[str, Node],
513
+ end_task: Union[str, Node]) -> Self:
514
+
515
+ self._check_compiled()
516
+
517
+ start_id = self._get_node_id(start_task)
518
+ end_id = self._get_node_id(end_task)
519
+
520
+ if not self.node_exists(start_id):
521
+ raise ValueError(f"Node {start_id} has not been added to the flow yet.")
522
+ if not self.node_exists(end_id):
523
+ raise ValueError(f"Node {end_id} has not been added to the flow yet.")
524
+ if start_id == END:
525
+ raise ValueError("END cannot be used as a start Node")
526
+ if end_id == START:
527
+ raise ValueError("START cannot be used as an end Node")
528
+
529
+ # Run this validation only for non-StateGraph graphs
530
+ self.edges.append(FlowEdge(start = start_id, end = end_id))
531
+ return self
532
+
533
    def sequence(self, *elements: Union[str, Node] | None) -> Self:
        '''Chain *elements* into consecutive edges (e1→e2, e2→e3, …).

        A None element maps to START when in start position and END when in
        end position. Returns self for chaining.
        '''
        start_element: Union[str, Node] | None = None
        for element in elements:
            if not start_element:
                # first (or reset) iteration: remember the edge start
                start_element = element
            else:
                end_element = element

                # str and Node both pass through unchanged; only a falsy
                # start defaults to START (symmetric for END below)
                if isinstance(start_element, str):
                    start_node = start_element
                elif isinstance(start_element, Node):
                    start_node = start_element
                else:
                    start_node = START

                if isinstance(end_element, str):
                    end_node = end_element
                elif isinstance(end_element, Node):
                    end_node = end_element
                else:
                    end_node = END

                self.edge(start_node, end_node)

                # set start as the current end element
                start_element = end_element

        return self
562
+
563
+ def starts_with(self, node: Union[str, Node]) -> Self:
564
+ '''Create an edge with an automatic START node.'''
565
+ return self.edge(START, node)
566
+
567
+ def ends_with(self, node: Union[str, Node]) -> Self:
568
+ '''Create an edge with an automatic END node.'''
569
+ return self.edge(node, END)
570
+
571
+ def starts_and_ends_with(self, node: Union[str, Node]) -> Self:
572
+ '''Create a single node flow with an automatic START and END node.'''
573
+ return self.sequence(START, node, END)
574
+
575
    def branch(self, evaluator: Union[Callable, Expression]) -> "Branch":
        '''Create a BRANCH node.

        NOTE(review): despite the Callable in the signature, callable
        evaluators are rejected below; a plain str is also accepted and
        wrapped in an Expression — the annotation understates the contract.
        '''
        e = evaluator
        if isinstance(evaluator, Callable):
            # We need to get the python tool representation of it
            raise ValueError("Branch with function as an evaluator is not supported yet.")
            # script_spec = getattr(evaluator, "__script_spec__", None)
            # if not script_spec:
            #     raise ValueError("Only functions with @script can be used as an evaluator.")
            # new_script_spec = copy.deepcopy(script_spec)
            # self._refactor_spec_to_schemaref(new_script_spec)
            # e = new_script_spec
        elif isinstance(evaluator, str):
            e = Expression(expression=evaluator)

        # branch nodes get a unique generated name
        spec = BranchNodeSpec(name = "branch_" + uuid.uuid4().hex, evaluator=e)
        branch_node = Branch(spec = spec, containing_flow=self)
        return cast(Branch, self._node(branch_node))
593
+
594
    def wait_for(self, *args) -> "Wait":
        '''Wait for all incoming nodes to complete.

        Placeholder: not implemented yet; always raises ValueError. The
        commented-out code below sketches the intended implementation.
        '''
        raise ValueError("Not implemented yet.")
        # spec = NodeSpec(name = "wait_" + uuid.uuid4().hex)
        # wait_node = Wait(spec = spec)

        # for arg in args:
        #     if isinstance(arg, Node):
        #         wait_node.node(arg)
        #     else:
        #         raise ValueError("Only nodes can be added to a wait node.")

        # return cast(Wait, self.node(wait_node))
607
+
608
+
609
    def foreach(self, item_schema: type[BaseModel],
                input_schema: type[BaseModel] |None=None,
                output_schema: type[BaseModel] |None=None) -> "Flow":  # return an Foreach object
        '''Create a foreach sub-flow that iterates over items of *item_schema*.

        When no input schema is given, a default ``{"items": [item_schema]}``
        object schema is synthesized. The item schema is registered with the
        flow's shared schemas. Returns the registered Foreach node (typed as
        Flow so nodes/edges can be added to it).
        '''
        output_schema_obj = _get_json_schema_obj("output", output_schema)
        input_schema_obj = _get_json_schema_obj("input", input_schema)
        foreach_item_schema = _get_json_schema_obj("item_schema", item_schema)

        # default input: a required array property named "items"
        if input_schema_obj is None:
            input_schema_obj = JsonSchemaObject(
                type = 'object',
                properties = {
                    "items": JsonSchemaObject(
                        type = "array",
                        items = foreach_item_schema)
                },
                required = ["items"])

        spec = ForeachSpec(name = "foreach_" + uuid.uuid4().hex,
                           input_schema=ToolRequestBody(
                               type=input_schema_obj.type,
                               properties=input_schema_obj.properties,
                               required=input_schema_obj.required,
                           ),
                           output_schema=ToolResponseBody(
                               type=output_schema_obj.type,
                               properties=output_schema_obj.properties,
                               required=output_schema_obj.required
                           ) if output_schema_obj is not None else None,
                           item_schema = foreach_item_schema)
        foreach_obj = Foreach(spec = spec, parent = self)
        foreach_node = self._node(foreach_obj)
        self._add_schema(foreach_item_schema)

        return cast(Flow, foreach_node)
645
+
646
    def loop(self, evaluator: Union[Callable, Expression],
             input_schema: type[BaseModel]|None=None,
             output_schema: type[BaseModel]|None=None) -> "Flow":  # return a WhileLoop object
        '''Create a while-loop sub-flow controlled by *evaluator*.

        The evaluator may be an Expression, a str (wrapped in an Expression),
        or a function carrying a ``__script_spec__`` from the @script
        decorator. Returns the registered Loop node.
        '''
        e = evaluator
        input_schema_obj = _get_json_schema_obj("input", input_schema)
        output_schema_obj = _get_json_schema_obj("output", output_schema)

        if isinstance(evaluator, Callable):
            # we need to get the python tool representation of it
            script_spec = getattr(evaluator, "__script_spec__", None)
            if not script_spec:
                raise ValueError("Only function with @script can be used as evaluator")
            # deep copy so later mutation does not affect the decorated function
            new_script_spec = copy.deepcopy(script_spec)
            e = new_script_spec
        elif isinstance(evaluator, str):
            e = Expression(expression=evaluator)

        loop_spec = LoopSpec(name = "loop_" + uuid.uuid4().hex,
                             evaluator = e,
                             input_schema=ToolRequestBody(
                                 type=input_schema_obj.type,
                                 properties=input_schema_obj.properties,
                                 required=input_schema_obj.required,
                             ) if input_schema_obj is not None else None,
                             output_schema=ToolResponseBody(
                                 type=output_schema_obj.type,
                                 properties=output_schema_obj.properties,
                                 required=output_schema_obj.required
                             ) if output_schema_obj is not None else None)
        while_loop = Loop(spec = loop_spec, parent = self)
        while_node = self._node(while_loop)
        return while_node
679
+
680
+ def validate_model(self) -> bool:
681
+ ''' Validate the model. '''
682
+ validator = FlowValidator(flow=self)
683
+ messages = validator.validate_model()
684
+ if validator.no_error(messages):
685
+ return True
686
+ raise ValueError(f"Invalid flow: {messages}")
687
+
688
+ def _check_compiled(self) -> None:
689
+ if self.compiled:
690
+ raise ValueError("Flow has already been compiled.")
691
+
692
+ def compile(self, **kwargs) -> "CompiledFlow":
693
+ """
694
+ Compile the current Flow model into a CompiledFlow object.
695
+
696
+ This method validates the flow model (if not already validated).
697
+
698
+ To also deploy the model to the engine and test it use the compile_deploy() function.
699
+
700
+ Returns:
701
+ CompiledFlow: An instance of the CompiledFlow class representing
702
+ the compiled flow.
703
+
704
+ Raises:
705
+ ValidationError: If the flow model is invalid and fails validation.
706
+ """
707
+
708
+ if not self.validated:
709
+ # we need to validate the flow first
710
+ self.validate_model()
711
+
712
+ self.compiled = True
713
+ self.metadata["source_kind"] = "adk/python"
714
+ self.metadata["compiled_on"] = datetime.now(pytz.utc).isoformat()
715
+ return CompiledFlow(flow=self, **kwargs)
716
+
717
+ async def compile_deploy(self, **kwargs) -> "CompiledFlow":
718
+ """
719
+ Compile the current Flow model into a CompiledFlow object.
720
+
721
+ This method validates the flow model (if not already validated),
722
+ deploys it to the engine, and marks it as compiled.
723
+
724
+ You can use the compiled flow to start a flow run.
725
+
726
+ Returns:
727
+ CompiledFlow: An instance of the CompiledFlow class representing
728
+ the compiled flow.
729
+
730
+ Raises:
731
+ ValidationError: If the flow model is invalid and fails validation.
732
+ """
733
+
734
+ compiled_flow = self.compile(**kwargs)
735
+
736
+ # Deploy flow to the engine
737
+ model = self.to_json()
738
+ await import_flow_model(model)
739
+
740
+ compiled_flow.deployed = True
741
+
742
+ return compiled_flow
743
+
744
+ def to_json(self) -> dict[str, Any]:
745
+ flow_dict = super().to_json()
746
+
747
+ # serialize nodes
748
+ nodes_dict = {}
749
+ for key, value in self.nodes.items():
750
+ nodes_dict[key] = value.to_json()
751
+ flow_dict["nodes"] = nodes_dict
752
+
753
+ # serialize edges
754
+ flow_dict["edges"] = []
755
+ for edge in self.edges:
756
+ flow_dict["edges"].append(
757
+ edge.model_dump(mode="json", exclude_unset=True, exclude_none=True, by_alias=True))
758
+
759
+ schema_dict = {}
760
+ for key, value in self.schemas.items():
761
+ schema_dict[key] = _to_json_from_json_schema(value)
762
+ flow_dict["schemas"] = schema_dict
763
+
764
+ metadata_dict = {}
765
+ for key, value in self.metadata.items():
766
+ metadata_dict[key] = value
767
+ flow_dict["metadata"] = metadata_dict
768
+ return flow_dict
769
+
770
+ def _get_node_id(self, node: Union[str, Node]) -> str:
771
+ if isinstance(node, Node):
772
+ node_id = node.spec.name
773
+ elif isinstance(node, FlowControl):
774
+ node_id = node.spec.name
775
+ else:
776
+ if (node == START):
777
+ # need to create a start node if one does not yet exist
778
+ if (START not in self.nodes):
779
+ start_node = StartNode(spec=StartNodeSpec(name=START))
780
+ self._add_node(start_node)
781
+ return START
782
+ if (node == END):
783
+ if (END not in self.nodes):
784
+ end_node = EndNode(spec=EndNodeSpec(name=END))
785
+ self._add_node(end_node)
786
+ return END
787
+ node_id = node
788
+ return node_id
789
+
790
+ def _get_data_map(self, map_fn: Callable | List[Assignment]) -> DataMap:
791
+ if map_fn:
792
+ if isinstance(map_fn, Callable):
793
+ raise ValueError("Datamap with function is not supported yet.")
794
+ # map_spec = getattr(map_fn, "__map_spec__", None)
795
+ # if not map_spec:
796
+ # raise ValueError(
797
+ # "Only functions with @map decorator can be used to map between nodes.")
798
+ # map_spec_copy = copy.deepcopy(map_spec)
799
+ # self.refactor_datamap_spec_to_schemaref(map_spec_copy)
800
+ # data_map = FnDataMap(spec=map_spec_copy)
801
+ # return data_map
802
+ elif isinstance(map_fn, list):
803
+ data_map = AssignmentDataMap(spec=AssignmentDataMapSpec(
804
+ name="assignment",
805
+ maps=map_fn))
806
+ return data_map
807
+ return None
808
+
809
+
810
class FlowRunStatus(str, Enum):
    """Lifecycle states of a FlowRun.

    Transitions are driven by engine events (see FlowRun._update_status):
    NOT_STARTED -> IN_PROGRESS on start, then COMPLETED on ON_FLOW_END,
    FAILED on ON_FLOW_ERROR, or INTERRUPTED while an interrupting event
    is pending.
    """
    NOT_STARTED = "not_started"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    INTERRUPTED = "interrupted"
    FAILED = "failed"
816
+
817
class FlowRun(BaseModel):
    '''Instance of a flow that is running.

    Tracks the engine-side execution of a Flow: its instance id, status,
    final output or error, and optional completion/error callbacks.
    '''
    # Display name, set to "<flow name>:<instance id>" once started.
    name: str | None = None
    # Engine instance id, assigned from the start acknowledgement.
    id: str | None = None
    flow: Flow
    status: FlowRunStatus = FlowRunStatus.NOT_STARTED
    # Final output of the flow, populated from the ON_FLOW_END event.
    output: Any = None
    # Error details, populated from the ON_FLOW_ERROR event.
    error: Any = None

    debug: bool = False
    # Called with the flow output when the run completes successfully.
    on_flow_end_handler: Callable | None = None
    # Called with the error when the run fails.
    on_flow_error_handler: Callable | None = None

    model_config = {
        "arbitrary_types_allowed": True
    }


    async def _arun_events(self, input_data: dict | None = None, filters: Sequence[Union[FlowEventType, TaskEventType]] | None = None) -> AsyncIterator[FlowEvent]:
        '''Start the flow on the engine and yield its events.

        Args:
            input_data: Input passed to the flow on start.
            filters: When given, only events of these kinds are yielded.

        Yields:
            FlowEvent: Events received from the engine event stream.

        Raises:
            ValueError: If this run was already started.
        '''
        if self.status is not FlowRunStatus.NOT_STARTED:
            raise ValueError("Flow has already been started")

        # Start the flow
        client:TempusClient = instantiate_client(client=TempusClient)
        ack = client.arun_flow(self.flow.spec.name,input_data)
        self.id=ack["instance_id"]
        self.name = f"{self.flow.spec.name}:{self.id}"
        self.status = FlowRunStatus.IN_PROGRESS

        # Listen for events
        consumer = StreamConsumer(self.id)

        async for event in consumer.consume():
            # Skip empty events and anything the caller filtered out.
            if not event or (filters and event.kind not in filters):
                continue
            if self.debug:
                logger.debug(f"Flow instance `{self.name}` event: `{event.kind}`")

            self._update_status(event)

            yield event

    def _update_status(self, event: FlowEvent) -> None:
        '''Update self.status from a single engine event.'''
        if event.kind == FlowEventType.ON_FLOW_END:
            self.status = FlowRunStatus.COMPLETED
        elif event.kind == FlowEventType.ON_FLOW_ERROR:
            self.status = FlowRunStatus.FAILED
        else:
            # Events classified as "interrupting" in EVENT_TYPE_MAP pause the
            # run; anything else keeps it in progress.
            self.status = FlowRunStatus.INTERRUPTED if EVENT_TYPE_MAP.get(event.kind, "unknown") == "interrupting" else FlowRunStatus.IN_PROGRESS


        if self.debug:
            logger.debug(f"Flow instance `{self.name}` status change: `{self.status}`")

    async def _arun(self, input_data: dict | None = None, **kwargs):
        '''Run the flow to completion, dispatching the end/error handlers.

        Args:
            input_data: Input passed to the flow on start.

        Raises:
            ValueError: If this run was already started.
        '''
        if self.status is not FlowRunStatus.NOT_STARTED:
            raise ValueError("Flow has already been started")

        async for event in self._arun_events(input_data):
            if not event:
                continue

            if event.kind == FlowEventType.ON_FLOW_END:
                # result should come back on the event
                self._on_flow_end(event)
                break
            elif event.kind == FlowEventType.ON_FLOW_ERROR:
                # error should come back on the event
                self._on_flow_error(event)
                break

    def update_state(self, task_id: str, data: dict) -> Self:
        '''Not Implemented Yet'''
        # update task and continue
        return self

    def _on_flow_end(self, event: FlowEvent) -> None:
        '''Record completion state and invoke the end handler, if any.'''
        self.status = FlowRunStatus.COMPLETED
        self.output = event.context.data["output"] if "output" in event.context.data else None

        if self.debug:
            logger.debug(f"Flow run `{self.name}`: on_complete handler called. Output: {self.output}")

        if self.on_flow_end_handler:
            self.on_flow_end_handler(self.output)


    def _on_flow_error(self, event: FlowEvent) -> None:
        '''Record failure state and invoke the error handler, if any.'''
        self.status = FlowRunStatus.FAILED
        self.error = event.error

        if self.debug:
            logger.debug(f"Flow run `{self.name}`: on_error handler called. Error: {self.error}")

        if self.on_flow_error_handler:
            self.on_flow_error_handler(self.error)
918
+
919
+
920
class CompiledFlow(BaseModel):
    '''A compiled version of the flow'''
    flow: Flow
    # Set to True by Flow.compile_deploy() once the model has been imported
    # into the engine; invoke()/invoke_events() refuse to run before that.
    deployed: bool = False

    async def invoke(self, input_data: dict | None = None, on_flow_end_handler: Callable | None = None, on_flow_error_handler: Callable | None = None, debug: bool = False, **kwargs) -> FlowRun:
        """
        Sets up and initializes a FlowInstance for the current flow. This only works for CompiledFlow instances that have been deployed.

        Args:
            input_data (dict, optional): Input data to be passed to the flow. Defaults to None.
            on_flow_end_handler (callable, optional): A callback function to be executed
                when the flow completes successfully. Defaults to None. Takes the flow output as an argument.
            on_flow_error_handler (callable, optional): A callback function to be executed
                when an error occurs during the flow execution. Defaults to None.
            debug (bool, optional): If True, enables debug mode for the flow run. Defaults to False.

        Returns:
            FlowInstance: An instance of the flow initialized with the provided handlers
                and additional parameters.

        Raises:
            ValueError: If this compiled flow has not been deployed.
        """

        if self.deployed is False:
            raise ValueError("Flow has not been deployed yet. Please deploy the flow before invoking it by using the Flow.compile_deploy() function.")

        flow_run = FlowRun(flow=self.flow, on_flow_end_handler=on_flow_end_handler, on_flow_error_handler=on_flow_error_handler, debug=debug, **kwargs)
        # NOTE(review): the task handle is not retained; asyncio keeps only a
        # weak reference to running tasks, so the run could in principle be
        # garbage-collected mid-flight — confirm a reference is held elsewhere.
        asyncio.create_task(flow_run._arun(input_data=input_data, **kwargs))
        return flow_run

    async def invoke_events(self, input_data: dict | None = None, filters: Sequence[Union[FlowEventType, TaskEventType]] | None = None, debug: bool = False) -> AsyncIterator[Tuple[FlowEvent, FlowRun]]:
        """
        Asynchronously runs the flow and yields events received from the flow for the client to handle. This only works for CompiledFlow instances that have been deployed.

        Args:
            input_data (dict, optional): Input data to be passed to the flow. Defaults to None.
            filters (Sequence[Union[FlowEventType, TaskEventType]], optional):
                A sequence of event types to filter the events. Only events matching these types
                will be yielded. Defaults to None.
            debug (bool, optional): If True, enables debug mode for the flow run. Defaults to False.

        Yields:
            Tuple[FlowEvent, FlowRun]: Each matching event received from the flow,
                together with the FlowRun it belongs to.

        Raises:
            ValueError: If this compiled flow has not been deployed.
        """

        if self.deployed is False:
            raise ValueError("Flow has not been deployed yet. Please deploy the flow before invoking it by using the Flow.compile_deploy() function.")

        flow_run = FlowRun(flow=self.flow, debug=debug)
        async for event in flow_run._arun_events(input_data=input_data, filters=filters):
            yield (event, flow_run)

    def dump_spec(self, file: str) -> None:
        '''Write the flow spec to *file*; the format is chosen by the file
        extension (.yaml/.yml for YAML, .json for JSON).

        Raises:
            ValueError: If the file extension is not .json, .yaml, or .yml.
        '''
        dumped = self.flow.to_json()
        with open(file, 'w') as f:
            if file.endswith(".yaml") or file.endswith(".yml"):
                yaml.dump(dumped, f)
            elif file.endswith(".json"):
                json.dump(dumped, f, indent=2)
            else:
                raise ValueError('file must end in .json, .yaml, or .yml')

    def dumps_spec(self) -> str:
        '''Return the flow spec as a pretty-printed JSON string.'''
        dumped = self.flow.to_json()
        return json.dumps(dumped, indent=2)
984
+
985
+
986
+
987
class FlowFactory(BaseModel):
    '''Factory helpers for building Flow models.'''

    @staticmethod
    def create_flow(name: str|Callable,
                    display_name: str|None=None,
                    description: str|None=None,
                    initiators: Sequence[str]|None=None,
                    input_schema: type[BaseModel]|None=None,
                    output_schema: type[BaseModel]|None=None) -> Flow:
        '''Create a Flow either from a @flow_spec-decorated callable or from
        the given name and schema arguments.'''
        # A callable carries its own spec via the @flow_spec decorator.
        if isinstance(name, Callable):
            spec_from_fn = getattr(name, "__flow_spec__", None)
            if not spec_from_fn:
                raise ValueError("Only functions with @flow_spec can be used to create a Flow specification.")
            return Flow(spec=spec_from_fn)

        # Derive JSON-schema objects for the declared input/output models.
        input_obj = _get_json_schema_obj(parameter_name="input", type_def=input_schema)
        output_obj = _get_json_schema_obj("output", output_schema)

        request_body = None
        if input_obj:
            request_body = ToolRequestBody(
                type=input_obj.type,
                properties=input_obj.properties,
                required=input_obj.required,
            )

        response_body = None
        if output_obj:
            response_body = ToolResponseBody(
                type=output_obj.type,
                properties=output_obj.properties,
                required=output_obj.required,
            )

        spec = FlowSpec(
            type="flow",
            name=name,
            display_name=display_name,
            description=description,
            initiators=initiators if initiators is not None else [],
            input_schema=request_body,
            output_schema=response_body,
            output_schema_object=output_obj,
        )

        return Flow(spec=spec)
1029
+
1030
+
1031
class FlowControl(Node):
    '''Base class for flow-control nodes (e.g. Branch, Wait).

    Carries no behaviour of its own; it exists so control nodes can be
    distinguished from regular task nodes.
    '''
    ...
1034
+
1035
class Branch(FlowControl):
    '''A flow-control node that routes execution to one of several labelled
    cases according to its match policy.'''

    containing_flow: Flow = Field(description="The containing flow.")

    def __repr__(self):
        # Report the actual class name (the previous message said "MatchNode",
        # a stale name from before the class was renamed).
        return f"Branch(name='{self.spec.name}', description='{self.spec.description}')"

    def policy(self, kind: MatchPolicy) -> Self:
        '''
        Set the match policy for this node.

        Parameters:
            kind (MatchPolicy): The match policy to set.

        Returns:
            Self: The current node.
        '''
        self.spec.match_policy = kind
        return self

    def _add_case(self, label: str | bool, node: Node) -> Self:
        '''
        Register *node* under *label* and wire an edge from this branch to it.

        Parameters:
            label (str | bool): The label for this case.
            node (Node): The node to add as a case.

        Returns:
            Self: The current node.
        '''
        node_id = self.containing_flow._get_node_id(node)
        self.spec.cases[label] = {
            "display_name": node_id,
            "node": node_id
        }
        # A case is only reachable through an edge from this branch node.
        self.containing_flow.edge(self, node)

        return self

    def case(self, label: str | bool, node: Node) -> Self:
        '''
        Add a case to this node.

        Parameters:
            label (str | bool): The label for this case. Must not be the
                reserved "__default__" label.
            node (Node): The node to add as a case.

        Returns:
            Self: The current node.

        Raises:
            ValueError: If label is the reserved "__default__" label.
        '''
        if label == "__default__":
            raise ValueError("Cannot have custom label __default__. Use default() instead.")

        return self._add_case(label, node)

    def default(self, node: Node) -> Self:
        '''
        Add a default case to this node.

        Parameters:
            node (Node): The node to add as a default case.

        Returns:
            Self: The current node.
        '''
        return self._add_case("__default__", node)

    def to_json(self) -> dict[str, Any]:
        '''Serialize this node to a plain dictionary.'''
        return super().to_json()
1106
+
1107
+
1108
class Wait(FlowControl):
    '''
    A node that waits for other nodes in a pipeline.

    Attributes:
        spec (WaitSpec): The specification of the wait node.

    Methods:
        policy(kind: WaitPolicy) -> Self: Sets the wait policy for the wait node.
        node(node: Node) -> Self: Adds a node to the list of nodes to wait for.
        nodes(nodes: List[Node]) -> Self: Adds a list of nodes to the list of nodes to wait for.
        to_json() -> dict[str, Any]: Converts the wait node to a JSON dictionary.
    '''

    def policy(self, kind: WaitPolicy) -> Self:
        '''
        Sets the wait policy for the wait node.

        Args:
            kind (WaitPolicy): The wait policy to set.

        Returns:
            Self: The wait node object.
        '''
        self.spec.wait_policy = kind
        return self

    def node(self, node: Node) -> Self:
        '''
        Adds a node to the list of nodes to wait for.

        Args:
            node (Node): The node to add.

        Returns:
            Self: The wait node object.
        '''
        self.spec.nodes.append(node.spec.name)
        # Return self so calls can be chained, matching the documented contract
        # (the previous implementation fell through and returned None).
        return self

    def nodes(self, nodes: List[Node]) -> Self:
        '''
        Adds a list of nodes to the list of nodes to wait for.

        Args:
            nodes (List[Node]): The list of nodes to add.

        Returns:
            Self: The wait node object.
        '''
        for node in nodes:
            self.spec.nodes.append(node.spec.name)
        # Return self so calls can be chained, matching the documented contract
        # (the previous implementation fell through and returned None).
        return self

    def to_json(self) -> dict[str, Any]:
        '''Serialize this node to a plain dictionary.'''
        return super().to_json()
1164
+
1165
class Loop(Flow):
    '''
    A Loop is a Flow that executes a set of steps repeatedly.

    Args:
        **kwargs (dict): Arbitrary keyword arguments.
    '''

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Rewrite the evaluator's inline schemas into schema references.
        evaluator = self.spec.evaluator
        if isinstance(evaluator, ScriptNodeSpec):
            self._refactor_spec_to_schemaref(evaluator)

    def to_json(self) -> dict[str, Any]:
        '''Serialize this loop to a plain dictionary.'''
        return super().to_json()
1187
+
1188
+
1189
+
1190
class Foreach(Flow):
    '''
    A flow that iterates over a list of items.

    Args:
        **kwargs: Arbitrary keyword arguments.
    '''

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Object item schemas are hoisted into the shared schema table and
        # replaced by a reference.
        item_schema = self.spec.item_schema
        if item_schema.type == "object":
            self.spec.item_schema = self._add_schema_ref(item_schema, item_schema.title)

    def policy(self, kind: ForeachPolicy) -> Self:
        '''
        Sets the policy for the foreach flow.

        Args:
            kind (ForeachPolicy): The policy to set.

        Returns:
            Self: The current instance of the flow.
        '''
        self.spec.foreach_policy = kind
        return self

    def to_json(self) -> dict[str, Any]:
        '''Serialize this foreach flow to a plain dictionary.'''
        return super().to_json()
1224
+
1225
class FlowValidationKind(str, Enum):
    '''
    Severity of a flow validation message.

    Attributes:
        ERROR (str): Indicates an error in the flow.
        WARNING (str): Indicates a warning in the flow.
        INFO (str): Indicates informational messages related to the flow.
    '''
    # NOTE: the previous definitions had trailing commas ("ERROR",) which
    # made the raw values tuples and only worked because Enum unpacks tuple
    # values into the str mixin constructor; the commas are removed to
    # eliminate that footgun. The member values are unchanged.
    ERROR = "ERROR"
    WARNING = "WARNING"
    INFO = "INFO"
1237
+
1238
class FlowValidationMessage(BaseModel):
    '''
    A single validation finding for a flow.

    Attributes:
        kind (FlowValidationKind): The severity of the finding.
        message (str): Human-readable description of the finding.
        node (Node): The node the finding relates to.
    '''
    kind: FlowValidationKind
    message: str
    node: Node
1254
+
1255
class FlowValidator(BaseModel):
    '''Validate the flow to ensure it is valid and runnable.'''
    flow: Flow

    def validate_model(self) -> List[FlowValidationMessage]:
        '''Check the model for possible errors.

        Returns:
            List[FlowValidationMessage]: A list of validation messages.
        '''
        # TODO: real validation is not implemented yet; always reports clean.
        return []

    def any_errors(self, messages: List[FlowValidationMessage]) -> bool:
        '''
        Check if any of the messages have a kind of ERROR.

        Args:
            messages (List[FlowValidationMessage]): A list of validation messages.

        Returns:
            bool: True if there are any errors, False otherwise.
        '''
        return any(m.kind == FlowValidationKind.ERROR for m in messages)

    def no_error(self, messages: List[FlowValidationMessage]) -> bool:
        '''Check if there are no errors in the messages.

        Args:
            messages (List[FlowValidationMessage]): A list of validation messages.

        Returns:
            bool: True if there are no errors, False otherwise.
        '''
        # Defined as the exact negation of any_errors so the two predicates
        # can never drift apart.
        return not self.any_errors(messages)