ibm-watsonx-orchestrate 1.3.0__py3-none-any.whl → 1.5.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. ibm_watsonx_orchestrate/__init__.py +2 -1
  2. ibm_watsonx_orchestrate/agent_builder/agents/types.py +2 -0
  3. ibm_watsonx_orchestrate/agent_builder/knowledge_bases/types.py +10 -2
  4. ibm_watsonx_orchestrate/agent_builder/toolkits/base_toolkit.py +32 -0
  5. ibm_watsonx_orchestrate/agent_builder/toolkits/types.py +42 -0
  6. ibm_watsonx_orchestrate/agent_builder/tools/openapi_tool.py +10 -1
  7. ibm_watsonx_orchestrate/agent_builder/tools/python_tool.py +4 -2
  8. ibm_watsonx_orchestrate/agent_builder/tools/types.py +2 -1
  9. ibm_watsonx_orchestrate/cli/commands/agents/agents_command.py +29 -0
  10. ibm_watsonx_orchestrate/cli/commands/agents/agents_controller.py +271 -12
  11. ibm_watsonx_orchestrate/cli/commands/knowledge_bases/knowledge_bases_controller.py +17 -2
  12. ibm_watsonx_orchestrate/cli/commands/models/env_file_model_provider_mapper.py +180 -0
  13. ibm_watsonx_orchestrate/cli/commands/models/models_command.py +199 -8
  14. ibm_watsonx_orchestrate/cli/commands/server/server_command.py +117 -48
  15. ibm_watsonx_orchestrate/cli/commands/server/types.py +105 -0
  16. ibm_watsonx_orchestrate/cli/commands/toolkit/toolkit_command.py +55 -7
  17. ibm_watsonx_orchestrate/cli/commands/toolkit/toolkit_controller.py +123 -42
  18. ibm_watsonx_orchestrate/cli/commands/tools/tools_command.py +22 -1
  19. ibm_watsonx_orchestrate/cli/commands/tools/tools_controller.py +197 -12
  20. ibm_watsonx_orchestrate/client/agents/agent_client.py +4 -1
  21. ibm_watsonx_orchestrate/client/agents/assistant_agent_client.py +5 -1
  22. ibm_watsonx_orchestrate/client/agents/external_agent_client.py +5 -1
  23. ibm_watsonx_orchestrate/client/analytics/llm/analytics_llm_client.py +2 -6
  24. ibm_watsonx_orchestrate/client/base_api_client.py +5 -2
  25. ibm_watsonx_orchestrate/client/connections/connections_client.py +3 -9
  26. ibm_watsonx_orchestrate/client/model_policies/__init__.py +0 -0
  27. ibm_watsonx_orchestrate/client/model_policies/model_policies_client.py +47 -0
  28. ibm_watsonx_orchestrate/client/model_policies/types.py +36 -0
  29. ibm_watsonx_orchestrate/client/models/__init__.py +0 -0
  30. ibm_watsonx_orchestrate/client/models/models_client.py +46 -0
  31. ibm_watsonx_orchestrate/client/models/types.py +189 -0
  32. ibm_watsonx_orchestrate/client/toolkit/toolkit_client.py +20 -6
  33. ibm_watsonx_orchestrate/client/tools/tempus_client.py +40 -0
  34. ibm_watsonx_orchestrate/client/tools/tool_client.py +8 -0
  35. ibm_watsonx_orchestrate/docker/compose-lite.yml +68 -13
  36. ibm_watsonx_orchestrate/docker/default.env +22 -12
  37. ibm_watsonx_orchestrate/docker/tempus/common-config.yaml +1 -1
  38. ibm_watsonx_orchestrate/experimental/flow_builder/__init__.py +0 -0
  39. ibm_watsonx_orchestrate/experimental/flow_builder/data_map.py +19 -0
  40. ibm_watsonx_orchestrate/experimental/flow_builder/flows/__init__.py +42 -0
  41. ibm_watsonx_orchestrate/experimental/flow_builder/flows/constants.py +19 -0
  42. ibm_watsonx_orchestrate/experimental/flow_builder/flows/decorators.py +144 -0
  43. ibm_watsonx_orchestrate/experimental/flow_builder/flows/events.py +72 -0
  44. ibm_watsonx_orchestrate/experimental/flow_builder/flows/flow.py +1310 -0
  45. ibm_watsonx_orchestrate/experimental/flow_builder/node.py +116 -0
  46. ibm_watsonx_orchestrate/experimental/flow_builder/resources/flow_status.openapi.yml +66 -0
  47. ibm_watsonx_orchestrate/experimental/flow_builder/types.py +765 -0
  48. ibm_watsonx_orchestrate/experimental/flow_builder/utils.py +115 -0
  49. ibm_watsonx_orchestrate/utils/utils.py +5 -2
  50. {ibm_watsonx_orchestrate-1.3.0.dist-info → ibm_watsonx_orchestrate-1.5.0b0.dist-info}/METADATA +4 -1
  51. {ibm_watsonx_orchestrate-1.3.0.dist-info → ibm_watsonx_orchestrate-1.5.0b0.dist-info}/RECORD +54 -32
  52. {ibm_watsonx_orchestrate-1.3.0.dist-info → ibm_watsonx_orchestrate-1.5.0b0.dist-info}/WHEEL +0 -0
  53. {ibm_watsonx_orchestrate-1.3.0.dist-info → ibm_watsonx_orchestrate-1.5.0b0.dist-info}/entry_points.txt +0 -0
  54. {ibm_watsonx_orchestrate-1.3.0.dist-info → ibm_watsonx_orchestrate-1.5.0b0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1310 @@
1
+ """
2
+ The Flow model. There are multiple methods to allow creation and population of
3
+ the Flow model.
4
+ """
5
+
6
+ import asyncio
7
+ from datetime import datetime
8
+ from enum import Enum
9
+ import inspect
10
+ from typing import (
11
+ Any, AsyncIterator, Callable, cast, List, Sequence, Union, Tuple
12
+ )
13
+ import json
14
+ import logging
15
+ import copy
16
+ import uuid
17
+ import pytz
18
+
19
+ from typing_extensions import Self
20
+ from pydantic import BaseModel, Field, SerializeAsAny
21
+ import yaml
22
+ from ibm_watsonx_orchestrate.agent_builder.tools.python_tool import PythonTool
23
+ from ibm_watsonx_orchestrate.client.tools.tempus_client import TempusClient
24
+ from ibm_watsonx_orchestrate.client.utils import instantiate_client
25
+ from ..types import (
26
+ EndNodeSpec, Expression, ForeachPolicy, ForeachSpec, LoopSpec, BranchNodeSpec, MatchPolicy, PromptLLMParameters, PromptNodeSpec,
27
+ StartNodeSpec, ToolSpec, JsonSchemaObject, ToolRequestBody, ToolResponseBody, UserFieldKind, UserFieldOption, UserFlowSpec, UserNodeSpec, WaitPolicy
28
+ )
29
+ from .constants import CURRENT_USER, START, END, ANY_USER
30
+ from ..node import (
31
+ EndNode, Node, PromptNode, StartNode, UserNode, AgentNode, DataMap, ToolNode
32
+ )
33
+ from ..types import (
34
+ AgentNodeSpec, extract_node_spec, FlowContext, FlowEventType, FlowEvent, FlowSpec,
35
+ NodeSpec, TaskEventType, ToolNodeSpec, SchemaRef, JsonSchemaObjectRef, _to_json_from_json_schema
36
+ )
37
+
38
+ from ..data_map import DataMap
39
+ from ..utils import _get_json_schema_obj, get_valid_name, import_flow_model
40
+
41
+ from .events import StreamConsumer
42
+
43
+ logger = logging.getLogger(__name__)
44
+
45
# Classification of every flow/task event kind. FlowRun._update_status
# treats an "interrupting" kind as moving the run to INTERRUPTED and any
# other kind as IN_PROGRESS (flow end/error are handled separately there).
EVENT_TYPE_MAP = {
    # flow-level events
    FlowEventType.ON_FLOW_START: "informational",
    FlowEventType.ON_FLOW_END: "informational",
    FlowEventType.ON_FLOW_ERROR: "interrupting",
    # task-level events
    TaskEventType.ON_TASK_WAIT: "interrupting",
    TaskEventType.ON_TASK_START: "informational",
    TaskEventType.ON_TASK_END: "informational",
    TaskEventType.ON_TASK_STREAM: "interrupting",
    TaskEventType.ON_TASK_ERROR: "interrupting",
}
56
+
57
class FlowEdge(BaseModel):
    """A directed edge of a flow, expressed as a pair of node names."""

    # name of the node the edge leaves from
    start: str
    # name of the node the edge arrives at
    end: str
61
+
62
class Flow(Node):
    """A flow that will be run by the wxO Flow engine."""

    output_map: DataMap | None = None
    # nodes of this flow, keyed by their spec name
    nodes: dict[str, SerializeAsAny[Node]] = {}
    # edges connecting the nodes, in insertion order
    edges: List[FlowEdge] = []
    # schema registry keyed by title; entries are written on the topmost flow
    schemas: dict[str, JsonSchemaObject] = {}
    # set by compile(); most mutating methods refuse to run once it is True
    compiled: bool = False
    # when True, compile() skips validate_model()
    validated: bool = False
    # free-form string metadata; compile() records source kind and timestamp
    metadata: dict[str, str] = {}
    # enclosing flow for nested flows, None for a top-level flow
    parent: Any = None
    # counter behind the generated names of branch/foreach/loop/userflow nodes
    _sequence_id: int = 0  # internal-id
73
+
74
def __init__(self, **kwargs):
    """Initialise through the parent class, then register this flow's own schemas.

    Parameters:
        **kwargs: forwarded unchanged to the parent initializer.
    """
    super().__init__(**kwargs)

    # the flow is itself a node: turn its inline input/output schemas into
    # references into the schema registry
    self._refactor_node_to_schemaref(self)
79
+
80
def _find_topmost_flow(self) -> Self:
    """Return the outermost flow by following the ``parent`` chain upwards."""
    if not self.parent:
        # no (truthy) parent: this flow is the root
        return self
    return self.parent._find_topmost_flow()
84
+
85
def _next_sequence_id(self) -> int:
    """Advance the flow's internal counter by one and return the new value."""
    next_id = self._sequence_id + 1
    self._sequence_id = next_id
    return next_id
88
+
89
def _add_schema(self, schema: JsonSchemaObject, title: str = None) -> JsonSchemaObject:
    """
    Register a schema in the schema dictionary of the topmost flow.

    If ``title`` is given and a schema with that title is already registered,
    the registered schema is returned. Otherwise a deep copy of ``schema`` is
    made, nested complex schemas are registered recursively and replaced by
    references, and the copy is stored under its title.

    Parameters:
        schema (JsonSchemaObject | dict): The schema to add. A ``dict`` is
            validated into a ``JsonSchemaObject`` first.
        title (str, optional): Registry key. When omitted it is derived from
            the schema's ``title``, then ``aliasName``, else a generated
            ``bo_<uuid>`` name.

    Returns:
        JsonSchemaObject: the registered (or pre-existing) schema; the input
        schema itself when it is neither an object nor an array; ``None``
        when ``schema`` is falsy.
    """

    def _is_complex(candidate) -> bool:
        # object/array schemas are the only ones stored in the registry
        return getattr(candidate, "type", None) in ("object", "array")

    # schemas are shared by nested flows, so they live on the outermost flow
    top_flow = self._find_topmost_flow()

    # an already registered title wins over the supplied schema
    if title and title in top_flow.schemas:
        return top_flow.schemas[title]

    if not schema:
        return None

    if isinstance(schema, dict):
        # recast schema to support direct attribute access
        schema = JsonSchemaObject.model_validate(schema)

    # only complex schemas are registered; simple ones are returned untouched
    if schema.type != "object" and schema.type != "array":
        return schema

    new_schema = copy.deepcopy(schema)
    if not title:
        if schema.title:
            title = get_valid_name(schema.title)
        elif schema.aliasName:
            title = get_valid_name(schema.aliasName)
        else:
            title = "bo_" + uuid.uuid4().hex

    if new_schema.type == "object":
        # register nested schemas and replace them with references
        if new_schema.properties is not None:
            for key, value in new_schema.properties.items():
                if not isinstance(value, JsonSchemaObject):
                    continue
                if value.type == "object":
                    schema_ref = self._add_schema_ref(value, value.title)
                    new_schema.properties[key] = JsonSchemaObjectRef(
                        title=value.title, ref=f"{schema_ref.ref}")
                elif value.type == "array" and _is_complex(value.items):
                    schema_ref = self._add_schema_ref(value.items, value.items.title)
                    new_schema.properties[key].items = JsonSchemaObjectRef(
                        title=value.title, ref=f"{schema_ref.ref}")
                elif value.model_extra and "$ref" in value.model_extra:
                    # model_extra is a dict, so membership (not hasattr) is
                    # the correct test. The property already is a local
                    # reference: re-point it from "#/$defs/<name>" to the
                    # flow-level "#/schemas/<name>".
                    ref_value = value.model_extra["$ref"]
                    schema_ref = f"#/schemas/{ref_value.removeprefix('#/$defs/')}"
                    new_schema.properties[key] = JsonSchemaObjectRef(ref=f"{schema_ref}")

    elif new_schema.type == "array":
        # items may be absent on an array schema; only complex items are lifted
        if _is_complex(new_schema.items):
            schema_ref = self._add_schema_ref(new_schema.items, new_schema.items.title)
            new_schema.items = JsonSchemaObjectRef(title=new_schema.items.title,
                                                   ref=f"{schema_ref.ref}")

    # unpack local definitions; model_extra can be None, so guard before lookup
    extra_fields = new_schema.model_extra or {}
    for schema_name, schema_def in (extra_fields.get("$defs") or {}).items():
        self._add_schema(schema_def, schema_name)

    # store the copy under its final title
    new_schema.title = title
    top_flow.schemas[title] = new_schema

    return new_schema
163
+
164
def _add_schema_ref(self, schema: JsonSchemaObject, title: str = None) -> SchemaRef:
    """Register a complex schema and return a ``#/schemas/<title>`` reference to it.

    Raises:
        AssertionError: if ``schema`` is falsy or is neither an object nor
            an array schema.
    """
    is_complex = bool(schema) and (schema.type == "object" or schema.type == "array")
    if not is_complex:
        raise AssertionError(f"schema is not a complex object: {schema}")

    registered = self._add_schema(schema, title)
    return SchemaRef(ref=f"#/schemas/{registered.title}")
170
+
171
def _refactor_node_to_schemaref(self, node: Node):
    """Replace the inline schemas on ``node``'s spec with schema references."""
    node_spec = node.spec
    self._refactor_spec_to_schemaref(node_spec)
173
+
174
def _refactor_spec_to_schemaref(self, spec: NodeSpec):
    """Move the inline input/output schemas of ``spec`` into the schema registry.

    The spec is modified in place: an inline ``ToolRequestBody`` input is
    replaced by a reference named ``<spec.name>_input``; the output is
    replaced by a reference built from ``output_schema_object`` when that is
    an object schema, otherwise from an inline ``ToolResponseBody``.
    """
    # --- input -----------------------------------------------------------
    request_body = spec.input_schema
    if request_body and isinstance(request_body, ToolRequestBody):
        input_obj = JsonSchemaObject(type=request_body.type,
                                     properties=request_body.properties,
                                     required=request_body.required)
        spec.input_schema = self._add_schema_ref(input_obj, f"{spec.name}_input")

    # --- output: the typed schema object takes precedence -----------------
    output_obj = spec.output_schema_object
    if output_obj is not None and output_obj.type == "object":
        spec.output_schema = self._add_schema_ref(output_obj, output_obj.title)
        spec.output_schema_object = None
        return

    # --- output: fall back to the inline response body --------------------
    response_body = spec.output_schema
    if response_body is None or not isinstance(response_body, ToolResponseBody):
        return

    if response_body.type == "object":
        json_obj = JsonSchemaObject(type=response_body.type,
                                    description=response_body.description,
                                    properties=response_body.properties,
                                    items=response_body.items,
                                    uniqueItems=response_body.uniqueItems,
                                    anyOf=response_body.anyOf,
                                    required=response_body.required)
        spec.output_schema = self._add_schema_ref(json_obj, f"{spec.name}_output")
    elif response_body.type == "array":
        # for arrays only the item schema is lifted; the array body stays inline
        if response_body.items.type == "object":
            item_ref = self._add_schema_ref(response_body.items)
            response_body.items = JsonSchemaObjectRef(ref=f"{item_ref.ref}")
199
+
200
+ # def refactor_datamap_spec_to_schemaref(self, spec: FnDataMapSpec):
201
+ # '''TODO'''
202
+ # if spec.input_schema:
203
+ # if isinstance(spec.input_schema, ToolRequestBody):
204
+ # spec.input_schema = self._add_schema_ref(JsonSchemaObject(type = spec.input_schema.type,
205
+ # properties= spec.input_schema.properties,
206
+ # required= spec.input_schema.required),
207
+ # f"{spec.name}_input")
208
+ # if spec.output_schema_object is not None and spec.output_schema_object.type == "object":
209
+ # spec.output_schema = self._add_schema_ref(spec.output_schema_object, spec.output_schema_object.title)
210
+ # spec.output_schema_object = None
211
+ # elif spec.output_schema is not None:
212
+ # if isinstance(spec.output_schema, ToolResponseBody):
213
+ # spec.output_schema = self._add_schema_ref(JsonSchemaObject(type = spec.output_schema.type,
214
+ # Sdescription=spec.output_schema.description,
215
+ # properties= spec.output_schema.properties,
216
+ # items = spec.output_schema.items,
217
+ # uniqueItems=spec.output_schema.uniqueItems,
218
+ # anyOf=spec.output_schema.anyOf,
219
+ # required= spec.output_schema.required),
220
+ # f"{spec.name}_output")
221
+
222
def _create_node_from_tool_fn(
    self,
    tool: Callable
) -> ToolNode:
    """Build a ``ToolNode`` from a callable that carries a ``__tool_spec__``.

    The node is only constructed here; it is not added to the flow.

    Raises:
        ValueError: if ``tool`` is not callable, has no ``__tool_spec__``,
            or the flow has already been compiled.
    """
    if not isinstance(tool, Callable):
        raise ValueError("Only functions with @tool decorator can be added.")

    declared_spec = getattr(tool, "__tool_spec__", None)
    if not declared_spec:
        raise ValueError("Only functions with @tool decorator can be added.")

    self._check_compiled()

    tool_spec = cast(ToolSpec, declared_spec)

    # the function signature supplies the typed output schema object
    signature_spec = extract_node_spec(tool)

    return ToolNode(
        spec=ToolNodeSpec(
            type="tool",
            name=tool_spec.name,
            display_name=tool_spec.name,
            description=tool_spec.description,
            input_schema=tool_spec.input_schema,
            output_schema=tool_spec.output_schema,
            output_schema_object=signature_spec.output_schema_object,
            tool=tool_spec.name,
        )
    )
250
+
251
def tool(
    self,
    tool: Callable | str | None = None,
    name: str | None = None,
    display_name: str | None = None,
    description: str | None = None,

    input_schema: type[BaseModel] | None = None,
    output_schema: type[BaseModel] | None = None,
    input_map: DataMap = None
) -> ToolNode:
    """Create a tool node and add it to the flow.

    Parameters:
        tool: either the name of a tool (``str``) or a ``PythonTool``
            carrying a ``__tool_spec__``.
        name: node name; used only when ``tool`` is a string, where it
            defaults to the tool name if omitted or empty.
        display_name: display name; used only when ``tool`` is a string.
        description: description; used only when ``tool`` is a string.
        input_schema: model describing the input; used only when ``tool``
            is a string.
        output_schema: model describing the output; used only when ``tool``
            is a string.
        input_map: optional data map attached to the node.

    Returns:
        ToolNode: the node as stored in the flow.

    Raises:
        ValueError: if ``tool`` is missing, is neither a string nor a
            ``PythonTool``, is a ``PythonTool`` that is not callable or has
            no ``__tool_spec__``, or if the flow rejects the node.
    """
    if tool is None:
        raise ValueError("tool must be provided")

    if isinstance(tool, str):
        name = name if name is not None and name != "" else tool
        input_schema_obj = _get_json_schema_obj(parameter_name="input", type_def=input_schema)
        output_schema_obj = _get_json_schema_obj("output", output_schema)

        # request/response bodies are only built when a schema was resolved
        toolnode_spec = ToolNodeSpec(
            type="tool",
            name=name,
            display_name=display_name,
            description=description,
            input_schema=ToolRequestBody(
                type=input_schema_obj.type,
                properties=input_schema_obj.properties,
                required=input_schema_obj.required,
            ) if input_schema_obj is not None else None,
            output_schema=ToolResponseBody(
                type=output_schema_obj.type,
                properties=output_schema_obj.properties,
                required=output_schema_obj.required
            ) if output_schema_obj is not None else None,
            output_schema_object=output_schema_obj,
            tool=tool)

        node = ToolNode(spec=toolnode_spec)
    elif isinstance(tool, PythonTool):
        # The helper raises ValueError for a non-callable tool or a missing
        # __tool_spec__, so `node` is always bound past this point (a
        # non-callable PythonTool previously fell through and surfaced as an
        # UnboundLocalError further down).
        node = self._create_node_from_tool_fn(tool)
    else:
        raise ValueError(f"tool is not a string or Callable: {tool}")

    # setup input map
    if input_map:
        node.input_map = self._get_data_map(input_map)

    node = self._add_node(node)
    return cast(ToolNode, node)
305
+
306
+
307
def _add_node(self, node: Node) -> Node:
    """Store a shallow copy of ``node`` in the flow under its spec name.

    The copy's inline schemas are converted to schema references before it
    is stored. Because the copy is shallow, the spec object is shared with
    the node passed in.

    Returns:
        Node: the stored copy.

    Raises:
        ValueError: if the flow is already compiled or a node with the same
            name is already present.
    """
    self._check_compiled()

    node_name = node.spec.name
    if node_name in self.nodes:
        # report the clashing node name (the message used to interpolate the
        # builtin `id` function instead)
        raise ValueError(f"Node `{node_name}` already present.")

    # make a copy
    new_node = copy.copy(node)

    self._refactor_node_to_schemaref(new_node)

    self.nodes[node_name] = new_node
    return new_node
320
+
321
def agent(self,
          name: str,
          agent: str,
          display_name: str|None=None,
          message: str | None = "Follow the agent instructions.",
          description: str | None = None,
          input_schema: type[BaseModel]|None = None,
          output_schema: type[BaseModel]|None=None,
          guidelines: str|None=None,
          input_map: DataMap = None) -> AgentNode:
    """Create an agent node and add it to the flow.

    Parameters:
        name: node name.
        agent: the agent the node delegates to.
        display_name: optional display name.
        message: message passed to the agent spec.
        description: optional description.
        input_schema: optional model describing the node input.
        output_schema: optional model describing the node output.
        guidelines: optional guidelines passed to the agent spec.
        input_map: optional data map attached to the node.

    Returns:
        AgentNode: the node as stored in the flow.
    """
    # create input spec
    input_schema_obj = _get_json_schema_obj(parameter_name="input", type_def=input_schema)
    output_schema_obj = _get_json_schema_obj("output", output_schema)

    # Both schemas are optional: build a request/response body only when a
    # schema object was resolved, as foreach/loop/userflow do, instead of
    # dereferencing None.
    task_spec = AgentNodeSpec(
        name=name,
        display_name=display_name,
        description=description,
        agent=agent,
        message=message,
        guidelines=guidelines,
        input_schema=ToolRequestBody(
            type=input_schema_obj.type,
            properties=input_schema_obj.properties,
            required=input_schema_obj.required,
        ) if input_schema_obj is not None else None,
        output_schema=ToolResponseBody(
            type=output_schema_obj.type,
            properties=output_schema_obj.properties,
            required=output_schema_obj.required
        ) if output_schema_obj is not None else None,
        output_schema_object=output_schema_obj
    )

    node = AgentNode(spec=task_spec)
    # setup input map
    if input_map:
        node.input_map = self._get_data_map(input_map)

    # add the node to the list of node
    node = self._add_node(node)
    return cast(AgentNode, node)
365
+
366
def prompt(self,
           name: str,
           display_name: str|None=None,
           system_prompt: str | list[str] | None = None,
           user_prompt: str | list[str] | None = None,
           llm: str | None = None,
           llm_parameters: PromptLLMParameters | None = None,
           description: str | None = None,
           input_schema: type[BaseModel]|None = None,
           output_schema: type[BaseModel]|None=None,
           input_map: DataMap = None) -> PromptNode:
    """Create a prompt node and add it to the flow.

    Parameters:
        name: node name (required).
        display_name: display name; defaults to ``name``.
        system_prompt: system prompt text or list of lines.
        user_prompt: user prompt text or list of lines.
        llm: identifier of the model to use.
        llm_parameters: optional model parameters.
        description: optional description.
        input_schema: optional model describing the node input.
        output_schema: optional model describing the node output.
        input_map: optional data map attached to the node.

    Returns:
        PromptNode: the node as stored in the flow.

    Raises:
        ValueError: if ``name`` is None.
    """
    if name is None:
        raise ValueError("name must be provided.")

    # create input spec
    input_schema_obj = _get_json_schema_obj(parameter_name="input", type_def=input_schema)
    output_schema_obj = _get_json_schema_obj("output", output_schema)

    # Both schemas are optional: build a request/response body only when a
    # schema object was resolved, as foreach/loop/userflow do, instead of
    # dereferencing None.
    task_spec = PromptNodeSpec(
        name=name,
        display_name=display_name if display_name is not None else name,
        description=description,
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        llm=llm,
        llm_parameters=llm_parameters,
        input_schema=ToolRequestBody(
            type=input_schema_obj.type,
            properties=input_schema_obj.properties,
            required=input_schema_obj.required,
        ) if input_schema_obj is not None else None,
        output_schema=ToolResponseBody(
            type=output_schema_obj.type,
            properties=output_schema_obj.properties,
            required=output_schema_obj.required
        ) if output_schema_obj is not None else None,
        output_schema_object=output_schema_obj
    )

    node = PromptNode(spec=task_spec)
    # setup input map
    if input_map:
        node.input_map = self._get_data_map(input_map)

    # add the node to the list of node
    node = self._add_node(node)
    return cast(PromptNode, node)
415
+
416
def node_exists(self, node: Union[str, Node]):
    """Tell whether ``node`` (a Node or a node name) is known to this flow.

    The START and END markers always count as existing.
    """
    node_id = node.spec.name if isinstance(node, Node) else node

    if node_id == END or node_id == START:
        return True
    return node_id in self.nodes
428
+
429
def edge(self,
         start_task: Union[str, Node],
         end_task: Union[str, Node]) -> Self:
    """Append an edge from ``start_task`` to ``end_task`` and return the flow.

    Resolving START or END through ``_get_node_id`` creates the matching
    start/end node when the flow does not have one yet.

    Raises:
        ValueError: if the flow is compiled, either node is unknown, END is
            used as the source, or START is used as the target.
    """
    self._check_compiled()

    start_id = self._get_node_id(start_task)
    end_id = self._get_node_id(end_task)

    # both endpoints must already be part of the flow (source checked first)
    for node_id in (start_id, end_id):
        if not self.node_exists(node_id):
            raise ValueError(f"Node {node_id} has not been added to the flow yet.")

    # direction constraints on the terminal markers
    if start_id == END:
        raise ValueError("END cannot be used as a start Node")
    if end_id == START:
        raise ValueError("START cannot be used as an end Node")

    self.edges.append(FlowEdge(start=start_id, end=end_id))
    return self
450
+
451
def sequence(self, *elements: Union[str, Node] | None) -> Self:
    """Connect the given elements one after another with edges.

    Each element (node name or Node) is linked to the element that follows
    it. A target that is neither a string nor a Node (e.g. ``None``) is
    replaced by END.

    NOTE(review): while the current source element is falsy (for example a
    leading ``None``) the next element simply takes its place and no edge is
    created, so the START fallback on the source side is only reached for a
    truthy value that is neither a string nor a Node. Confirm whether a
    leading ``None`` was meant to stand for START.

    Returns:
        Self: the flow, to allow chaining.
    """
    def _resolve(element, fallback):
        # strings and Nodes pass through; anything else becomes the fallback
        return element if isinstance(element, (str, Node)) else fallback

    current: Union[str, Node] | None = None
    for element in elements:
        if not current:
            # nothing to link from yet: this element becomes the source
            current = element
            continue

        self.edge(_resolve(current, START), _resolve(element, END))

        # the target of this edge is the source of the next one
        current = element

    return self
480
+
481
def starts_with(self, node: Union[str, Node]) -> Self:
    """Add an edge from START to ``node`` and return the flow."""
    first_node = node
    return self.edge(START, first_node)
484
+
485
def ends_with(self, node: Union[str, Node]) -> Self:
    """Add an edge from ``node`` to END and return the flow."""
    last_node = node
    return self.edge(last_node, END)
488
+
489
def starts_and_ends_with(self, node: Union[str, Node]) -> Self:
    """Build the single-node flow START -> ``node`` -> END and return the flow."""
    chain = (START, node, END)
    return self.sequence(*chain)
492
+
493
def branch(self, evaluator: Union[Callable, Expression]) -> "Branch":
    """Create a branch node named ``branch_<n>`` and add it to the flow.

    Parameters:
        evaluator: an ``Expression``, or a string that is wrapped into one.
            Callables are rejected for now.

    Raises:
        ValueError: if ``evaluator`` is callable.
    """
    if isinstance(evaluator, Callable):
        # function evaluators (@script) are not wired up yet
        raise ValueError("Branch with function as an evaluator is not supported yet.")

    # plain strings are wrapped; anything else is passed through as given
    condition = Expression(expression=evaluator) if isinstance(evaluator, str) else evaluator

    node_name = f"branch_{self._next_sequence_id()}"
    branch_spec = BranchNodeSpec(name=node_name, evaluator=condition)
    new_branch = Branch(spec=branch_spec, containing_flow=self)
    return cast(Branch, self._add_node(new_branch))
511
+
512
def wait_for(self, *args) -> "Wait":
    """Placeholder for a node that waits on all incoming nodes.

    Raises:
        ValueError: always; the feature is not implemented yet.
    """
    raise ValueError("Not implemented yet.")
515
+ # spec = NodeSpec(name = "wait_" + uuid.uuid4().hex)
516
+ # wait_node = Wait(spec = spec)
517
+
518
+ # for arg in args:
519
+ # if isinstance(arg, Node):
520
+ # wait_node.node(arg)
521
+ # else:
522
+ # raise ValueError("Only nodes can be added to a wait node.")
523
+
524
+ # return cast(Wait, self.node(wait_node))
525
+
526
+
527
def foreach(self, item_schema: type[BaseModel],
            input_schema: type[BaseModel] |None=None,
            output_schema: type[BaseModel] |None=None) -> "Foreach": # return an Foreach object
    """Create a foreach node named ``foreach_<n>`` and add it to the flow.

    Parameters:
        item_schema: model describing one item of the iteration.
        input_schema: optional model for the node input. When no input
            schema object is resolved, an object schema with a single
            required ``items`` array of ``item_schema`` is used.
        output_schema: optional model for the node output.

    Returns:
        The foreach node as stored in the flow; the item schema is also
        registered in the schema registry.
    """
    output_schema_obj = _get_json_schema_obj("output", output_schema)
    input_schema_obj = _get_json_schema_obj("input", input_schema)
    foreach_item_schema = _get_json_schema_obj("item_schema", item_schema)

    if input_schema_obj is None:
        # default input: {"items": [<item>, ...]}
        items_property = JsonSchemaObject(type="array", items=foreach_item_schema)
        input_schema_obj = JsonSchemaObject(
            type='object',
            properties={"items": items_property},
            required=["items"])

    node_name = "foreach_" + str(self._next_sequence_id())
    request_body = ToolRequestBody(
        type=input_schema_obj.type,
        properties=input_schema_obj.properties,
        required=input_schema_obj.required,
    )
    response_body = None
    if output_schema_obj is not None:
        response_body = ToolResponseBody(
            type=output_schema_obj.type,
            properties=output_schema_obj.properties,
            required=output_schema_obj.required
        )

    spec = ForeachSpec(name=node_name,
                       input_schema=request_body,
                       output_schema=response_body,
                       item_schema=foreach_item_schema)
    foreach_node = self._add_node(Foreach(spec=spec, parent=self))
    self._add_schema(foreach_item_schema)

    return cast(Flow, foreach_node)
563
+
564
def loop(self, evaluator: Union[Callable, Expression],
         input_schema: type[BaseModel]|None=None,
         output_schema: type[BaseModel]|None=None) -> "Loop": # return a WhileLoop object
    """Create a loop node named ``loop_<n>`` and add it to the flow.

    Parameters:
        evaluator: the loop condition. A callable must carry a
            ``__script_spec__`` (a deep copy of it is used); a string is
            wrapped into an ``Expression``; anything else is used as given.
        input_schema: optional model for the node input.
        output_schema: optional model for the node output.

    Raises:
        ValueError: if ``evaluator`` is callable but has no
            ``__script_spec__``.
    """
    condition = evaluator
    input_schema_obj = _get_json_schema_obj("input", input_schema)
    output_schema_obj = _get_json_schema_obj("output", output_schema)

    if isinstance(evaluator, Callable):
        # a function evaluator is represented by its script spec
        script_spec = getattr(evaluator, "__script_spec__", None)
        if not script_spec:
            raise ValueError("Only function with @script can be used as evaluator")
        condition = copy.deepcopy(script_spec)
    elif isinstance(evaluator, str):
        condition = Expression(expression=evaluator)

    node_name = "loop_" + str(self._next_sequence_id())

    request_body = None
    if input_schema_obj is not None:
        request_body = ToolRequestBody(
            type=input_schema_obj.type,
            properties=input_schema_obj.properties,
            required=input_schema_obj.required,
        )

    response_body = None
    if output_schema_obj is not None:
        response_body = ToolResponseBody(
            type=output_schema_obj.type,
            properties=output_schema_obj.properties,
            required=output_schema_obj.required
        )

    loop_spec = LoopSpec(name=node_name,
                         evaluator=condition,
                         input_schema=request_body,
                         output_schema=response_body)
    stored_loop = self._add_node(Loop(spec=loop_spec, parent=self))
    return cast(Loop, stored_loop)
596
+
597
def userflow(self,
             owners: Sequence[str] = [],
             input_schema: type[BaseModel] |None=None,
             output_schema: type[BaseModel] |None=None) -> "UserFlow": # return a UserFlow object
    """Create a user-flow node named ``userflow_<n>`` and add it to the flow.

    A warning is logged on every call because the feature is flagged as not
    working yet.

    Parameters:
        owners: owners passed to the user-flow spec.
        input_schema: optional model for the node input.
        output_schema: optional model for the node output.
    """
    logger.warning("userflow is NOT working yet.")

    output_schema_obj = _get_json_schema_obj("output", output_schema)
    input_schema_obj = _get_json_schema_obj("input", input_schema)

    node_name = "userflow_" + str(self._next_sequence_id())

    request_body = None
    if input_schema_obj is not None:
        request_body = ToolRequestBody(
            type=input_schema_obj.type,
            properties=input_schema_obj.properties,
            required=input_schema_obj.required,
        )

    response_body = None
    if output_schema_obj is not None:
        response_body = ToolResponseBody(
            type=output_schema_obj.type,
            properties=output_schema_obj.properties,
            required=output_schema_obj.required
        )

    spec = UserFlowSpec(name=node_name,
                        input_schema=request_body,
                        output_schema=response_body,
                        owners=owners)
    userflow_node = self._add_node(UserFlow(spec=spec, parent=self))

    return cast(UserFlow, userflow_node)
623
+
624
def validate_model(self) -> bool:
    """Run ``FlowValidator`` over this flow.

    Returns:
        bool: True when the validator reports no error.

    Raises:
        ValueError: when the validator reports errors; the message carries
            the validator's messages.
    """
    validator = FlowValidator(flow=self)
    messages = validator.validate_model()
    if not validator.no_error(messages):
        raise ValueError(f"Invalid flow: {messages}")
    return True
631
+
632
def _check_compiled(self) -> None:
    """Guard for mutating operations: raise ValueError once the flow is compiled."""
    if not self.compiled:
        return
    raise ValueError("Flow has already been compiled.")
635
+
636
def compile(self, **kwargs) -> "CompiledFlow":
    """
    Compile the current Flow model into a CompiledFlow object.

    The flow is validated first unless ``validated`` is already set, then
    marked as compiled and stamped with ``source_kind`` / ``compiled_on``
    metadata. Nothing is deployed; use compile_deploy() for that.

    Parameters:
        **kwargs: forwarded to the CompiledFlow constructor.

    Returns:
        CompiledFlow: wrapper around this flow.

    Raises:
        ValueError: raised by validate_model() when the flow is invalid.
    """
    # NOTE(review): this method reads `validated` but never sets it.
    if not self.validated:
        self.validate_model()

    self.compiled = True
    self.metadata.update({
        "source_kind": "adk/python",
        "compiled_on": datetime.now(pytz.utc).isoformat(),
    })
    return CompiledFlow(flow=self, **kwargs)
660
+
661
async def compile_deploy(self, **kwargs) -> "CompiledFlow":
    """
    Compile the flow and import its JSON model through import_flow_model().

    The returned compiled flow has ``deployed`` set to True and can be used
    to start a flow run.

    Parameters:
        **kwargs: forwarded to compile().

    Returns:
        CompiledFlow: the compiled flow.

    Raises:
        ValueError: raised by compile() when the flow is invalid.
    """
    compiled = self.compile(**kwargs)

    # hand the serialized model over to the engine
    await import_flow_model(self.to_json())

    compiled.deployed = True
    return compiled
687
+
688
def to_json(self) -> dict[str, Any]:
    """Serialize the flow, including nodes, edges, schemas and metadata, to a dict."""
    flow_dict = super().to_json()

    # nodes, keyed by node name
    flow_dict["nodes"] = {
        node_name: node.to_json() for node_name, node in self.nodes.items()
    }

    # edges, in insertion order
    flow_dict["edges"] = [
        edge.model_dump(mode="json", exclude_unset=True, exclude_none=True, by_alias=True)
        for edge in self.edges
    ]

    # registered schemas, keyed by title
    flow_dict["schemas"] = {
        schema_title: _to_json_from_json_schema(schema)
        for schema_title, schema in self.schemas.items()
    }

    # shallow copy so the serialized dict does not alias the live metadata
    flow_dict["metadata"] = dict(self.metadata)
    return flow_dict
713
+
714
def _get_node_id(self, node: Union[str, Node]) -> str:
    """Resolve ``node`` to a node name.

    Node and FlowControl instances resolve to their spec name. The START
    and END markers resolve to themselves, and the corresponding start/end
    node is added to the flow on first use. Any other value is returned
    unchanged.
    """
    if isinstance(node, Node) or isinstance(node, FlowControl):
        return node.spec.name

    if node == START:
        # lazily create the start node
        if START not in self.nodes:
            self._add_node(StartNode(spec=StartNodeSpec(name=START)))
        return START

    if node == END:
        # lazily create the end node
        if END not in self.nodes:
            self._add_node(EndNode(spec=EndNodeSpec(name=END)))
        return END

    return node
733
+
734
def _get_data_map(self, map: DataMap) -> DataMap:
    """Return the data map to attach to a node; currently the input, unchanged."""
    data_map = map
    return data_map
736
+
737
class FlowRunStatus(str, Enum):
    """Lifecycle states of a flow run; members compare equal to their string value."""

    # initial state of a FlowRun
    NOT_STARTED = "not_started"
    # the run was started, or the last event seen was not an interrupting one
    IN_PROGRESS = "in_progress"
    # an ON_FLOW_END event was seen
    COMPLETED = "completed"
    # the last event seen was classified as interrupting
    INTERRUPTED = "interrupted"
    # an ON_FLOW_ERROR event was seen
    FAILED = "failed"
743
+
744
class FlowRun(BaseModel):
    '''Instance of a flow that is running.'''
    # Set to "<flow name>:<instance id>" when the run is started.
    name: str | None = None
    # Instance id taken from the start acknowledgement; None until the run is started.
    id: str | None = None
    flow: Flow
    status: FlowRunStatus = FlowRunStatus.NOT_STARTED
    output: Any = None
    error: Any = None

    debug: bool = False
    # The handlers are optional. They are declared nullable so that passing an
    # explicit None (which is their default) is valid input as well.
    on_flow_end_handler: Callable | None = None
    on_flow_error_handler: Callable | None = None

    model_config = {
        "arbitrary_types_allowed": True
    }

    async def _arun_events(self, input_data: dict = None, filters: Sequence[Union[FlowEventType, TaskEventType]] = None) -> AsyncIterator[FlowEvent]:
        """
        Start the flow and yield the events consumed for this run.

        Args:
            input_data (dict, optional): Input data passed to the flow. Defaults to None.
            filters (Sequence[Union[FlowEventType, TaskEventType]], optional):
                When given, only events whose kind is in this sequence are processed
                and yielded. Defaults to None.

        Yields:
            FlowEvent: The events that were not skipped.

        Raises:
            ValueError: If the run is not in the NOT_STARTED status.
        """
        if self.status is not FlowRunStatus.NOT_STARTED:
            raise ValueError("Flow has already been started")

        # Start the flow
        client: TempusClient = instantiate_client(client=TempusClient)
        ack = client.arun_flow(self.flow.spec.name, input_data)
        self.id = ack["instance_id"]
        self.name = f"{self.flow.spec.name}:{self.id}"
        self.status = FlowRunStatus.IN_PROGRESS

        # Listen for events
        consumer = StreamConsumer(self.id)

        async for event in consumer.consume():
            # Skipped events (empty or filtered out) do not update the status either.
            if not event or (filters and event.kind not in filters):
                continue
            if self.debug:
                logger.debug(f"Flow instance `{self.name}` event: `{event.kind}`")

            self._update_status(event)

            yield event

    def _update_status(self, event: FlowEvent):
        """Update the run status according to the kind of the given event."""
        if event.kind == FlowEventType.ON_FLOW_END:
            self.status = FlowRunStatus.COMPLETED
        elif event.kind == FlowEventType.ON_FLOW_ERROR:
            self.status = FlowRunStatus.FAILED
        elif EVENT_TYPE_MAP.get(event.kind, "unknown") == "interrupting":
            self.status = FlowRunStatus.INTERRUPTED
        else:
            self.status = FlowRunStatus.IN_PROGRESS

        if self.debug:
            logger.debug(f"Flow instance `{self.name}` status change: `{self.status}`")

    async def _arun(self, input_data: dict = None, **kwargs):
        """
        Run the flow until a flow-end or flow-error event is received.

        Args:
            input_data (dict, optional): Input data passed to the flow. Defaults to None.
            **kwargs: Accepted but not used by this method.

        Raises:
            ValueError: If the run is not in the NOT_STARTED status.
        """
        if self.status is not FlowRunStatus.NOT_STARTED:
            raise ValueError("Flow has already been started")

        async for event in self._arun_events(input_data):
            if not event:
                continue

            if event.kind == FlowEventType.ON_FLOW_END:
                # result should come back on the event
                self._on_flow_end(event)
                break
            elif event.kind == FlowEventType.ON_FLOW_ERROR:
                # error should come back on the event
                self._on_flow_error(event)
                break

    def update_state(self, task_id: str, data: dict) -> Self:
        '''Not Implemented Yet'''
        # update task and continue
        return self

    def _on_flow_end(self, event: FlowEvent):
        """Record the flow output from the event and call the end handler, if any."""
        self.status = FlowRunStatus.COMPLETED
        self.output = event.context.data["output"] if "output" in event.context.data else None

        if self.debug:
            logger.debug(f"Flow run `{self.name}`: on_complete handler called. Output: {self.output}")

        if self.on_flow_end_handler:
            self.on_flow_end_handler(self.output)

    def _on_flow_error(self, event: FlowEvent):
        """Record the error from the event and call the error handler, if any."""
        self.status = FlowRunStatus.FAILED
        self.error = event.error

        if self.debug:
            logger.debug(f"Flow run `{self.name}`: on_error handler called. Error: {self.error}")

        if self.on_flow_error_handler:
            self.on_flow_error_handler(self.error)
844
+
845
+
846
class CompiledFlow(BaseModel):
    '''A compiled version of the flow'''
    flow: Flow
    deployed: bool = False

    # Strong references to the tasks started by invoke(). The event loop only
    # keeps weak references to tasks, so a task nobody else references could be
    # garbage collected before it finishes. Underscore-prefixed attributes are
    # private attributes of the model, not fields.
    _background_tasks: set = set()

    async def invoke(self, input_data: dict = None, on_flow_end_handler: Callable = None, on_flow_error_handler: Callable = None, debug: bool = False, **kwargs) -> FlowRun:
        """
        Creates a FlowRun for the current flow and starts running it in the background. This only works for CompiledFlow instances that have been deployed.

        Args:
            input_data (dict, optional): Input data to be passed to the flow. Defaults to None.
            on_flow_end_handler (callable, optional): A callback function to be executed
                when the flow completes successfully. Defaults to None. Takes the flow output as an argument.
            on_flow_error_handler (callable, optional): A callback function to be executed
                when an error occurs during the flow execution. Defaults to None. Takes the error as an argument.
            debug (bool, optional): If True, enables debug mode for the flow run. Defaults to False.
            **kwargs: Additional parameters, passed both to the FlowRun constructor and to the run itself.

        Returns:
            FlowRun: The flow run initialized with the provided handlers
                and additional parameters.

        Raises:
            ValueError: If the flow has not been deployed.
        """

        if self.deployed is False:
            raise ValueError("Flow has not been deployed yet. Please deploy the flow before invoking it by using the Flow.compile_deploy() function.")

        # Only forward the handlers that were actually supplied, so an omitted
        # handler falls back to the FlowRun default instead of an explicit None.
        handlers = {}
        if on_flow_end_handler is not None:
            handlers["on_flow_end_handler"] = on_flow_end_handler
        if on_flow_error_handler is not None:
            handlers["on_flow_error_handler"] = on_flow_error_handler

        flow_run = FlowRun(flow=self.flow, debug=debug, **handlers, **kwargs)

        # Keep the task referenced until it is done.
        task = asyncio.create_task(flow_run._arun(input_data=input_data, **kwargs))
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return flow_run

    async def invoke_events(self, input_data: dict = None, filters: Sequence[Union[FlowEventType, TaskEventType]] = None, debug: bool = False) -> AsyncIterator[Tuple[FlowEvent, FlowRun]]:
        """
        Asynchronously runs the flow and yields events received from the flow for the client to handle. This only works for CompiledFlow instances that have been deployed.

        Args:
            input_data (dict, optional): Input data to be passed to the flow. Defaults to None.
            filters (Sequence[Union[FlowEventType, TaskEventType]], optional):
                A sequence of event types to filter the events. Only events matching these types
                will be yielded. Defaults to None.
            debug (bool, optional): If True, enables debug mode for the flow run. Defaults to False.

        Yields:
            Tuple[FlowEvent, FlowRun]: Each event that matches the specified filters,
                paired with the flow run it belongs to.

        Raises:
            ValueError: If the flow has not been deployed.
        """

        if self.deployed is False:
            raise ValueError("Flow has not been deployed yet. Please deploy the flow before invoking it by using the Flow.compile_deploy() function.")

        flow_run = FlowRun(flow=self.flow, debug=debug)
        async for event in flow_run._arun_events(input_data=input_data, filters=filters):
            yield (event, flow_run)

    def dump_spec(self, file: str) -> None:
        """
        Writes the flow specification to a file as YAML or JSON, chosen by the file extension.

        Args:
            file (str): Path of the file to write. Must end in .json, .yaml, or .yml.

        Raises:
            ValueError: If the file name has a different extension.
        """
        dumped = self.flow.to_json()

        # Check the extension before opening the file: opening in 'w' mode
        # creates or truncates it, which must not happen for a rejected name.
        as_yaml = file.endswith((".yaml", ".yml"))
        if not as_yaml and not file.endswith(".json"):
            raise ValueError('file must end in .json, .yaml, or .yml')

        with open(file, 'w') as f:
            if as_yaml:
                yaml.dump(dumped, f)
            else:
                json.dump(dumped, f, indent=2)

    def dumps_spec(self) -> str:
        """Returns the flow specification as a JSON string indented by 2 spaces."""
        dumped = self.flow.to_json()
        return json.dumps(dumped, indent=2)
910
+
911
+
912
+
913
class FlowFactory(BaseModel):
    '''A factory class to create a Flow model'''

    @staticmethod
    def create_flow(name: str|Callable,
                    display_name: str|None=None,
                    description: str|None=None,
                    initiators: Sequence[str]|None=None,
                    input_schema: type[BaseModel]|None=None,
                    output_schema: type[BaseModel]|None=None) -> Flow:
        """
        Create a Flow either from a callable carrying a flow spec or from the given parameters.

        Args:
            name (str | Callable): The flow name, or a callable that has a
                ``__flow_spec__`` attribute, in which case that spec is used and
                all other arguments are ignored.
            display_name (str, optional): Display name of the flow.
            description (str, optional): Description of the flow.
            initiators (Sequence[str], optional): Initiators of the flow. Defaults to an empty list.
            input_schema (type[BaseModel], optional): Model describing the flow input.
            output_schema (type[BaseModel], optional): Model describing the flow output.

        Returns:
            Flow: The new flow.

        Raises:
            ValueError: If ``name`` is a callable without a flow spec.
        """
        if isinstance(name, Callable):
            attached_spec = getattr(name, "__flow_spec__", None)
            if not attached_spec:
                raise ValueError("Only functions with @flow_spec can be used to create a Flow specification.")
            return Flow(spec = attached_spec)

        # json schema objects for the input and the output
        input_schema_obj = _get_json_schema_obj(parameter_name = "input", type_def = input_schema)
        output_schema_obj = _get_json_schema_obj("output", output_schema)
        initiators = [] if initiators is None else initiators

        # a request / response body is only built when a schema object is available
        request_body = None
        if input_schema_obj:
            request_body = ToolRequestBody(
                type=input_schema_obj.type,
                properties=input_schema_obj.properties,
                required=input_schema_obj.required,
            )

        response_body = None
        if output_schema_obj:
            response_body = ToolResponseBody(
                type=output_schema_obj.type,
                properties=output_schema_obj.properties,
                required=output_schema_obj.required
            )

        new_spec = FlowSpec(
            type="flow",
            name=name,
            display_name=display_name,
            description=description,
            initiators=initiators,
            input_schema=request_body,
            output_schema=response_body,
            output_schema_object = output_schema_obj
        )

        return Flow(spec = new_spec)
955
+
956
+
957
class FlowControl(Node):
    '''Common parent of the flow control nodes; adds nothing to Node itself.'''
960
+
961
class Branch(FlowControl):
    '''A flow control node that routes to other nodes through labelled cases.'''
    containing_flow: Flow = Field(description="The containing flow.")

    def __repr__(self):
        # NOTE(review): the text says "MatchNode" although the class is Branch - confirm intended.
        return f"MatchNode(name='{self.spec.name}', description='{self.spec.description}')"

    def policy(self, kind: MatchPolicy) -> Self:
        '''
        Choose the match policy of this branch.

        Parameters:
        kind (MatchPolicy): The match policy to use.

        Returns:
        Self: This branch, to allow chaining.
        '''
        self.spec.match_policy = kind
        return self

    def _add_case(self, label: str | bool, node: Node)->Self:
        '''
        Register a case on this branch and connect the branch to the target node.

        Parameters:
        label (str | bool): The label of the case.
        node (Node): The node the case leads to.

        Returns:
        Self: This branch, to allow chaining.
        '''
        flow = self.containing_flow
        target_id = flow._get_node_id(node)

        # the node id is used both as the display name and as the target of the case
        self.spec.cases[label] = {"display_name": target_id, "node": target_id}
        flow.edge(self, node)

        return self

    def case(self, label: str | bool, node: Node) -> Self:
        '''
        Add a labelled case to this branch.

        Parameters:
        label (str | bool): The label of the case. "__default__" is reserved.
        node (Node): The node the case leads to.

        Returns:
        Self: This branch, to allow chaining.

        Raises:
        ValueError: If the label is "__default__".
        '''
        if label == "__default__":
            raise ValueError("Cannot have custom label __default__. Use default() instead.")

        return self._add_case(label, node)

    def default(self, node: Node) -> Self:
        '''
        Add the default case to this branch, stored under the "__default__" label.

        Parameters:
        node (Node): The node the default case leads to.

        Returns:
        Self: This branch, to allow chaining.
        '''
        return self._add_case("__default__", node)

    def to_json(self) -> dict[str, Any]:
        '''Return the JSON dictionary produced by the parent class.'''
        return super().to_json()
1032
+
1033
+
1034
class Wait(FlowControl):
    '''
    A node that represents a wait in a pipeline.

    Attributes:
    spec (WaitSpec): The specification of the wait node.

    Methods:
    policy(kind: WaitPolicy) -> Self: Sets the wait policy for the wait node.
    node(node: Node) -> Self: Adds a node to the list of nodes to wait for.
    nodes(nodes: List[Node]) -> Self: Adds a list of nodes to the list of nodes to wait for.
    to_json() -> dict[str, Any]: Converts the wait node to a JSON dictionary.
    '''

    def policy(self, kind: WaitPolicy) -> Self:
        '''
        Sets the wait policy for the wait node.

        Args:
        kind (WaitPolicy): The wait policy to set.

        Returns:
        Self: The wait node object.
        '''
        self.spec.wait_policy = kind
        return self

    def node(self, node: Node) -> Self:
        '''
        Adds a node to the list of nodes to wait for.

        Args:
        node (Node): The node to add.

        Returns:
        Self: The wait node object.
        '''
        self.spec.nodes.append(node.spec.name)
        # Return the wait node, as annotated and documented, so calls can be chained.
        return self

    def nodes(self, nodes: List[Node]) -> Self:
        '''
        Adds a list of nodes to the list of nodes to wait for.

        Args:
        nodes (List[Node]): The list of nodes to add.

        Returns:
        Self: The wait node object.
        '''
        # Nodes are referenced by the name in their spec.
        self.spec.nodes.extend(node.spec.name for node in nodes)
        # Return the wait node, as annotated and documented, so calls can be chained.
        return self

    def to_json(self) -> dict[str, Any]:
        '''Converts the wait node to a JSON dictionary using the parent class.'''
        return super().to_json()
1090
+
1091
class Loop(Flow):
    '''
    A Loop is a Flow that executes a set of steps repeatedly.

    It is constructed with the same keyword arguments as Flow and
    currently adds no behavior of its own.
    '''

    def __init__(self, **kwargs):
        '''Pass all keyword arguments on to the Flow constructor.'''
        super().__init__(**kwargs)

    def to_json(self) -> dict[str, Any]:
        '''Return the JSON dictionary produced by the parent class.'''
        return super().to_json()
1109
+
1110
+
1111
+
1112
class Foreach(Flow):
    '''
    A flow that iterates over a list of items.

    It is constructed with the same keyword arguments as Flow.
    '''

    def __init__(self, **kwargs):
        '''Initialize the flow and turn an object item schema into a schema reference.'''
        super().__init__(**kwargs)

        # refactor item schema: object schemas are replaced by the result of
        # _add_schema_ref, registered under the title of the schema
        item_schema = self.spec.item_schema
        if item_schema.type == "object":
            self.spec.item_schema = self._add_schema_ref(item_schema, item_schema.title)

    def policy(self, kind: ForeachPolicy) -> Self:
        '''
        Choose the policy of the foreach flow.

        Args:
        kind (ForeachPolicy): The policy to use.

        Returns:
        Self: This flow, to allow chaining.
        '''
        self.spec.foreach_policy = kind
        return self

    def to_json(self) -> dict[str, Any]:
        '''Return the JSON dictionary produced by the parent class.'''
        return super().to_json()
1146
+
1147
class FlowValidationKind(str, Enum):
    '''
    This class defines the type of validation for a flow.

    Attributes:
    ERROR (str): Indicates an error in the flow.
    WARNING (str): Indicates a warning in the flow.
    INFO (str): Indicates informational messages related to the flow.
    '''
    # Plain string values, without trailing commas: a trailing comma would turn
    # the assigned value into a one-element tuple.
    ERROR = "ERROR"
    WARNING = "WARNING"
    INFO = "INFO"
1159
+
1160
class FlowValidationMessage(BaseModel):
    '''
    A single validation message for a flow.

    Attributes:
    kind (FlowValidationKind): The kind of the message (ERROR, WARNING or INFO).
    message (str): The text of the message.
    node (Node): The node the message is about.
    '''
    kind: FlowValidationKind
    message: str
    node: Node
1176
+
1177
class FlowValidator(BaseModel):
    '''Validate the flow to ensure it is valid and runnable.'''
    flow: Flow

    def validate_model(self) -> List[FlowValidationMessage]:
        '''Check the model for possible errors.

        No checks are performed at the moment, so the result is always empty.

        Returns:
        List[FlowValidationMessage]: A list of validation messages.
        '''
        return []

    def any_errors(self, messages: List[FlowValidationMessage]) -> bool:
        '''
        Tell whether at least one message is of kind ERROR.

        Args:
        messages (List[FlowValidationMessage]): The validation messages to inspect.

        Returns:
        bool: True as soon as an ERROR message is found, False otherwise.
        '''
        for message in messages:
            if message.kind == FlowValidationKind.ERROR:
                return True
        return False

    def no_error(self, messages: List[FlowValidationMessage]) -> bool:
        '''Tell whether none of the messages is of kind ERROR.

        Args:
        messages (List[FlowValidationMessage]): The validation messages to inspect.

        Returns:
        bool: False as soon as an ERROR message is found, True otherwise.
        '''
        for message in messages:
            if message.kind == FlowValidationKind.ERROR:
                return False
        return True
1211
+
1212
+ class UserFlow(Flow):
1213
+ '''
1214
+ A flow that represents a series of user nodes.
1215
+ A user flow can include other nodes, but not another User Flows.
1216
+ '''
1217
+
1218
def __repr__(self):
    '''Return a short text showing the name and description of the user flow spec.'''
    return f"UserFlow(name='{self.spec.name}', description='{self.spec.description}')"
1220
+
1221
def get_spec(self) -> NodeSpec:
    '''Return the spec of this flow, typed as a UserFlowSpec for type checkers.'''
    # cast() only informs the type checker; the spec object itself is returned as is.
    return cast(UserFlowSpec, self.spec)
1223
+
1224
def to_json(self) -> dict[str, Any]:
    '''Return the JSON dictionary produced by the parent class.'''
    return super().to_json()
1228
+
1229
def field(self,
          name: str,
          kind: UserFieldKind = UserFieldKind.Text,
          display_name: str | None = None,
          description: str | None = None,
          owners: list[str] | None = None,
          default: Any | None = None,
          text: str = None,
          option: UserFieldOption | None = None,
          input_map: DataMap = None,
          custom: dict[str, Any] | None = None) -> UserNode:
    '''
    Create a user node in the flow that holds a single field.

    Args:
    name (str): Name of the field; also used as the node name and schema title.
    kind (UserFieldKind): Kind of the field. Defaults to UserFieldKind.Text.
    display_name (str, optional): Display name of the node.
    description (str, optional): Description of the field and the node.
    owners (list[str], optional): Owners of the node. Defaults to an empty list.
    default (Any, optional): Default value of the field.
    text (str, optional): Text of the node.
    option (UserFieldOption, optional): Option passed to the schema property conversion.
    input_map (DataMap, optional): Input map of the node.
    custom (dict[str, Any], optional): Custom settings passed to the schema
        property conversion. Defaults to an empty dict.

    Returns:
    UserNode: The node created through user().

    Raises:
    AssertionError: If name is empty.
    '''
    if not name:
        raise AssertionError("name cannot be empty")

    # Use fresh containers per call instead of mutable default arguments,
    # which would be shared between all calls.
    if owners is None:
        owners = []
    if custom is None:
        custom = {}

    # create a json schema object based on the single field
    schema_obj = JsonSchemaObject(type="object",
                                  title=name,
                                  description=description)

    schema_obj.properties = {
        name: UserFieldKind.convert_kind_to_schema_property(kind, name, description, default, option, custom)
    }

    return self.user(name,
                     display_name=display_name,
                     description=description,
                     owners=owners,
                     text=text,
                     output_schema=schema_obj,
                     input_map=input_map)
1259
+
1260
def user(
    self,
    name: str | None = None,
    display_name: str | None = None,
    description: str | None = None,
    owners: list[str] | None = None,
    text: str | None = None,
    output_schema: type[BaseModel] | JsonSchemaObject | None = None,
    input_map: DataMap = None,
) -> UserNode:
    '''
    Create a user node in the flow.

    Args:
    name (str, optional): Name of the node.
    display_name (str, optional): Display name of the node.
    description (str, optional): Description of the node.
    owners (list[str], optional): Owners of the node. When empty or not
        given, [ANY_USER] is used.
    text (str, optional): Text of the node.
    output_schema (type[BaseModel] | JsonSchemaObject, optional): Output
        schema, either as a class (converted to a json schema object) or as
        a ready json schema object.
    input_map (DataMap, optional): Input map of the node.

    Returns:
    UserNode: The node that was added to the flow.
    '''
    # input and output is always the same in an user node, so only the
    # output schema is handled; a class is converted to a json schema object
    output_schema_obj = output_schema
    if inspect.isclass(output_schema):
        output_schema_obj = _get_json_schema_obj(parameter_name = "output", type_def = output_schema)

    # identify owner (None and an empty list both mean "any user")
    if not owners:
        owners = [ANY_USER]

    # NOTE(review): output_schema defaults to None, but the attribute accesses
    # below need a schema object - confirm whether a schema is mandatory here.
    # Create the user node spec
    task_spec = UserNodeSpec(
        name=name,
        display_name=display_name,
        description=description,
        owners=owners,
        input_schema=None,
        output_schema=ToolResponseBody(
            type=output_schema_obj.type,
            properties=output_schema_obj.properties,
            required=output_schema_obj.required
        ),
        text=text,
        output_schema_object = output_schema_obj
    )

    task_spec.setup_fields()

    node = UserNode(spec = task_spec)

    # setup input map
    if input_map:
        node.input_map = self._get_data_map(input_map)

    node = self._add_node(node)
    return cast(UserNode, node)
1309
+
1310
+