lmnr 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.

Files changed (31)
  1. lmnr/__init__.py +5 -1
  2. lmnr/cli/__main__.py +4 -0
  3. lmnr/cli/cli.py +97 -0
  4. lmnr/cli/cookiecutter.json +9 -0
  5. lmnr/cli/parser/__init__.py +0 -0
  6. lmnr/cli/parser/nodes/__init__.py +50 -0
  7. lmnr/cli/parser/nodes/types.py +156 -0
  8. lmnr/cli/parser/parser.py +58 -0
  9. lmnr/cli/parser/utils.py +25 -0
  10. lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/__init__.py +0 -0
  11. lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/engine/__init__.py +1 -0
  12. lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/engine/action.py +14 -0
  13. lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/engine/engine.py +261 -0
  14. lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/engine/state.py +69 -0
  15. lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/engine/task.py +38 -0
  16. lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/pipelines/{{cookiecutter.pipeline_dir_name}}/__init__.py +1 -0
  17. lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/pipelines/{{cookiecutter.pipeline_dir_name}}/nodes/functions.py +149 -0
  18. lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/pipelines/{{cookiecutter.pipeline_dir_name}}/{{cookiecutter.pipeline_dir_name}}.py +87 -0
  19. lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/types.py +50 -0
  20. lmnr/sdk/endpoint.py +166 -0
  21. lmnr/types.py +93 -0
  22. lmnr-0.1.3.dist-info/LICENSE +72 -0
  23. lmnr-0.1.3.dist-info/METADATA +78 -0
  24. lmnr-0.1.3.dist-info/RECORD +26 -0
  25. lmnr-0.1.3.dist-info/entry_points.txt +3 -0
  26. lmnr/endpoint.py +0 -43
  27. lmnr/model.py +0 -39
  28. lmnr-0.1.1.dist-info/LICENSE +0 -7
  29. lmnr-0.1.1.dist-info/METADATA +0 -37
  30. lmnr-0.1.1.dist-info/RECORD +0 -7
  31. {lmnr-0.1.1.dist-info → lmnr-0.1.3.dist-info}/WHEEL +0 -0
lmnr/__init__.py CHANGED
@@ -1 +1,5 @@
-from .endpoint import Laminar
+from .sdk.endpoint import Laminar
+from .types import (
+    ChatMessage, EndpointRunError, EndpointRunResponse,
+    NodeInput, SDKError,
+)
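After this change, the SDK's public surface is importable from the package root. A minimal sketch, hedged because `lmnr/sdk/endpoint.py` and `lmnr/types.py` are added in this release but their bodies are not shown here; the `Laminar` constructor signature and `ChatMessage` fields below are assumptions:

```python
# Hypothetical usage; only the import names themselves are confirmed by this diff.
from lmnr import Laminar, ChatMessage

client = Laminar("<LMNR_PROJECT_API_KEY>")  # assumed: takes the project API key
messages = [ChatMessage(role="user", content="Hello")]  # assumed field names
```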
lmnr/cli/__main__.py ADDED
@@ -0,0 +1,4 @@
+from .cli import cli
+
+if __name__ == "__main__":
+    cli()
lmnr/cli/cli.py ADDED
@@ -0,0 +1,97 @@
+import requests
+from dotenv import load_dotenv
+import os
+import click
+import logging
+from cookiecutter.main import cookiecutter
+from importlib import resources as importlib_resources
+from pydantic.alias_generators import to_pascal
+
+from .parser.parser import runnable_graph_to_template_vars
+import lmnr
+
+logger = logging.getLogger(__name__)
+
+
+@click.group()
+@click.version_option()
+def cli():
+    "CLI for Laminar AI Engine"
+
+
+@cli.command(name="pull")
+@click.argument("pipeline_name")
+@click.argument("pipeline_version_name")
+@click.option(
+    "-p",
+    "--project-api-key",
+    help="Project API key",
+)
+@click.option(
+    "-l",
+    "--loglevel",
+    help="Sets logging level",
+)
+def pull(pipeline_name, pipeline_version_name, project_api_key, loglevel):
+    loglevel_str_to_val = {
+        "DEBUG": logging.DEBUG,
+        "INFO": logging.INFO,
+        "WARNING": logging.WARNING,
+        "ERROR": logging.ERROR,
+        "CRITICAL": logging.CRITICAL,
+    }
+    logging.basicConfig()
+    logging.getLogger().setLevel(loglevel_str_to_val.get(loglevel, logging.WARNING))
+
+    project_api_key = project_api_key or os.environ.get("LMNR_PROJECT_API_KEY")
+    if not project_api_key:
+        load_dotenv()
+        project_api_key = os.environ.get("LMNR_PROJECT_API_KEY")
+    if not project_api_key:
+        raise ValueError("LMNR_PROJECT_API_KEY is not set")
+
+    headers = {"Authorization": f"Bearer {project_api_key}"}
+    params = {
+        "pipelineName": pipeline_name,
+        "pipelineVersionName": pipeline_version_name,
+    }
+    res = requests.get(
+        "https://api.lmnr.ai/v1/pipeline-version-by-name",
+        headers=headers,
+        params=params,
+    )
+    if res.status_code != 200:
+        try:
+            res_json = res.json()
+        except Exception:
+            raise ValueError(
+                f"Error in fetching pipeline version: {res.status_code}\n{res.text}"
+            )
+        raise ValueError(
+            f"Error in fetching pipeline version: {res.status_code}\n{res_json}"
+        )
+
+    pipeline_version = res.json()
+
+    class_name = to_pascal(pipeline_name.replace(" ", "_"))
+
+    context = {
+        "pipeline_name": pipeline_name,
+        "pipeline_version_name": pipeline_version_name,
+        "class_name": class_name,
+        # _tasks starts with an underscore because we don't want it templated:
+        # some tasks contain LLM nodes with prompts
+        # that we don't want rendered by cookiecutter
+        "_tasks": runnable_graph_to_template_vars(pipeline_version["runnableGraph"]),
+    }
+
+    logger.info(f"Context:\n{context}")
+    cookiecutter(
+        str(importlib_resources.files(lmnr)),
+        output_dir=".",
+        config_file=None,
+        extra_context=context,
+        directory="cli",
+        no_input=True,
+        overwrite_if_exists=True,
+    )
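With the `__main__` shim above, the command can be run without an installed entry point, e.g. `python -m lmnr.cli pull "My Pipeline" "my_version" -p <project-api-key>`; the console-script alias registered in `lmnr-0.1.3.dist-info/entry_points.txt` is not shown in this diff. On success, `pull` renders the packaged `cli/` cookiecutter template into the current directory using the fetched `runnableGraph`.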
lmnr/cli/cookiecutter.json ADDED
@@ -0,0 +1,9 @@
+{
+    "lmnr_pipelines_dir_name": "lmnr_engine",
+    "pipeline_name": "Laminar Pipeline",
+    "pipeline_dir_name": "{{ cookiecutter['pipeline_name'].lower().replace('-', '_').replace(' ', '_') }}",
+    "class_name": "LaminarPipeline",
+    "pipeline_version_name": "main",
+    "_tasks": {},
+    "_jinja2_env_vars": {"lstrip_blocks": true, "trim_blocks": true}
+}
lmnr/cli/parser/__init__.py ADDED
File without changes
lmnr/cli/parser/nodes/__init__.py ADDED
@@ -0,0 +1,50 @@
+from abc import ABCMeta, abstractmethod
+from dataclasses import dataclass
+from typing import Optional
+import uuid
+
+
+HandleType = str  # "String" | "ChatMessageList" | "Any"
+
+
+@dataclass
+class Handle:
+    id: uuid.UUID
+    name: Optional[str]
+    type: HandleType
+
+    @classmethod
+    def from_dict(cls, dict: dict) -> "Handle":
+        return cls(
+            id=uuid.UUID(dict["id"]),
+            name=(dict["name"] if "name" in dict else None),
+            type=dict["type"],
+        )
+
+
+@abstractmethod
+class NodeFunctions(metaclass=ABCMeta):
+    @abstractmethod
+    def handles_mapping(
+        self, output_handle_id_to_node_name: dict[str, str]
+    ) -> list[tuple[str, str]]:
+        """
+        Returns a list of tuples mapping from this node's input
+        handle name to the unique name of the previous node.
+
+        Assumes the previous node has only one output.
+        """
+        pass
+
+    @abstractmethod
+    def node_type(self) -> str:
+        pass
+
+    @abstractmethod
+    def config(self) -> dict:
+        """
+        Returns a dictionary of node-specific configuration.
+
+        E.g. prompt and model name for an LLM node.
+        """
+        pass
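A quick check of `Handle.from_dict` above; the UUID is arbitrary and `name` is optional in the payload:

```python
from lmnr.cli.parser.nodes import Handle

handle = Handle.from_dict(
    {"id": "00000000-0000-0000-0000-000000000001", "type": "String"}
)
# "name" was absent from the dict, so it defaults to None
assert handle.name is None and handle.type == "String"
```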
lmnr/cli/parser/nodes/types.py ADDED
@@ -0,0 +1,156 @@
+from dataclasses import dataclass
+from typing import Any, Optional, Union
+import uuid
+from lmnr.cli.parser.nodes import Handle, HandleType, NodeFunctions
+from lmnr.cli.parser.utils import map_handles
+from lmnr.types import NodeInput, ChatMessage
+
+def node_input_from_json(json_val: Any) -> NodeInput:
+    if isinstance(json_val, str):
+        return json_val
+    elif isinstance(json_val, list):
+        return [ChatMessage.model_validate(msg) for msg in json_val]
+    else:
+        raise ValueError(f"Invalid NodeInput value: {json_val}")
+
+
+# TODO: Convert to Pydantic
+@dataclass
+class InputNode(NodeFunctions):
+    id: uuid.UUID
+    name: str
+    outputs: list[Handle]
+    input: Optional[NodeInput]
+    input_type: HandleType
+
+    def handles_mapping(
+        self, output_handle_id_to_node_name: dict[str, str]
+    ) -> list[tuple[str, str]]:
+        return []
+
+    def node_type(self) -> str:
+        return "Input"
+
+    def config(self) -> dict:
+        return {}
+
+
+# TODO: Convert to Pydantic
+@dataclass
+class LLMNode(NodeFunctions):
+    id: uuid.UUID
+    name: str
+    inputs: list[Handle]
+    dynamic_inputs: list[Handle]
+    outputs: list[Handle]
+    inputs_mappings: dict[uuid.UUID, uuid.UUID]
+    prompt: str
+    model: str
+    model_params: Optional[str]
+    stream: bool
+    structured_output_enabled: bool
+    structured_output_max_retries: int
+    structured_output_schema: Optional[str]
+    structured_output_schema_target: Optional[str]
+
+    def handles_mapping(
+        self, output_handle_id_to_node_name: dict[str, str]
+    ) -> list[tuple[str, str]]:
+        combined_inputs = self.inputs + self.dynamic_inputs
+        return map_handles(
+            combined_inputs, self.inputs_mappings, output_handle_id_to_node_name
+        )
+
+    def node_type(self) -> str:
+        return "LLM"
+
+    def config(self) -> dict:
+        # For easier access in the template, separate the provider and model here
+        provider, model = self.model.split(":", maxsplit=1)
+
+        return {
+            "prompt": self.prompt,
+            "provider": provider,
+            "model": model,
+            "model_params": self.model_params,
+            "stream": self.stream,
+            "structured_output_enabled": self.structured_output_enabled,
+            "structured_output_max_retries": self.structured_output_max_retries,
+            "structured_output_schema": self.structured_output_schema,
+            "structured_output_schema_target": self.structured_output_schema_target,
+        }
+
+
+# TODO: Convert to Pydantic
+@dataclass
+class OutputNode(NodeFunctions):
+    id: uuid.UUID
+    name: str
+    inputs: list[Handle]
+    outputs: list[Handle]
+    inputs_mappings: dict[uuid.UUID, uuid.UUID]
+
+    def handles_mapping(
+        self, output_handle_id_to_node_name: dict[str, str]
+    ) -> list[tuple[str, str]]:
+        return map_handles(
+            self.inputs, self.inputs_mappings, output_handle_id_to_node_name
+        )
+
+    def node_type(self) -> str:
+        return "Output"
+
+    def config(self) -> dict:
+        return {}
+
+
+Node = Union[InputNode, OutputNode, LLMNode]
+
+
+def node_from_dict(node_dict: dict) -> Node:
+    if node_dict["type"] == "Input":
+        return InputNode(
+            id=uuid.UUID(node_dict["id"]),
+            name=node_dict["name"],
+            outputs=[Handle.from_dict(handle) for handle in node_dict["outputs"]],
+            input=node_input_from_json(node_dict["input"]),
+            input_type=node_dict["inputType"],
+        )
+    elif node_dict["type"] == "Output":
+        return OutputNode(
+            id=uuid.UUID(node_dict["id"]),
+            name=node_dict["name"],
+            inputs=[Handle.from_dict(handle) for handle in node_dict["inputs"]],
+            outputs=[Handle.from_dict(handle) for handle in node_dict["outputs"]],
+            inputs_mappings={
+                uuid.UUID(k): uuid.UUID(v)
+                for k, v in node_dict["inputsMappings"].items()
+            },
+        )
+    elif node_dict["type"] == "LLM":
+        return LLMNode(
+            id=uuid.UUID(node_dict["id"]),
+            name=node_dict["name"],
+            inputs=[Handle.from_dict(handle) for handle in node_dict["inputs"]],
+            dynamic_inputs=[
+                Handle.from_dict(handle) for handle in node_dict["dynamicInputs"]
+            ],
+            outputs=[Handle.from_dict(handle) for handle in node_dict["outputs"]],
+            inputs_mappings={
+                uuid.UUID(k): uuid.UUID(v)
+                for k, v in node_dict["inputsMappings"].items()
+            },
+            prompt=node_dict["prompt"],
+            model=node_dict["model"],
+            model_params=(
+                node_dict["modelParams"] if "modelParams" in node_dict else None
+            ),
+            stream=False,
+            # TODO: Implement structured output
+            structured_output_enabled=False,
+            structured_output_max_retries=3,
+            structured_output_schema=None,
+            structured_output_schema_target=None,
+        )
+    else:
+        raise ValueError(f"Node type {node_dict['type']} not supported")
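For reference, a minimal payload accepted by the `Input` branch of `node_from_dict`; all values are illustrative:

```python
from lmnr.cli.parser.nodes.types import node_from_dict

node = node_from_dict(
    {
        "type": "Input",
        "id": "00000000-0000-0000-0000-000000000001",
        "name": "question",
        "outputs": [{"id": "00000000-0000-0000-0000-000000000002", "type": "String"}],
        "input": "hello",  # a plain string is a valid NodeInput
        "inputType": "String",
    }
)
assert node.node_type() == "Input"
```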
lmnr/cli/parser/parser.py ADDED
@@ -0,0 +1,58 @@
+from .nodes.types import node_from_dict
+
+
+def runnable_graph_to_template_vars(graph: dict) -> dict:
+    """
+    Convert a runnable graph to template vars to be rendered in a cookiecutter context.
+    """
+    node_id_to_node_name = {}
+    output_handle_id_to_node_name: dict[str, str] = {}
+    for node in graph["nodes"].values():
+        node_id_to_node_name[node["id"]] = node["name"]
+        for handle in node["outputs"]:
+            output_handle_id_to_node_name[handle["id"]] = node["name"]
+
+    tasks = []
+    for node_obj in graph["nodes"].values():
+        node = node_from_dict(node_obj)
+        handles_mapping = node.handles_mapping(output_handle_id_to_node_name)
+        node_type = node.node_type()
+        tasks.append(
+            {
+                "name": node.name,
+                "function_name": f"run_{node.name}",
+                "node_type": node_type,
+                "handles_mapping": handles_mapping,
+                # since the mapping goes from 'to' to 'from', the 'to' names won't repeat
+                "input_handle_names": [
+                    handle_name for (handle_name, _) in handles_mapping
+                ],
+                "handle_args": ", ".join(
+                    [
+                        f"{handle_name}: NodeInput"
+                        for (handle_name, _) in handles_mapping
+                    ]
+                ),
+                "prev": [],
+                "next": [],
+                "config": node.config(),
+            }
+        )
+
+    for to, from_ in graph["pred"].items():
+        # TODO: Make "tasks" a hashmap from node id (as str!) to task
+        to_task = [task for task in tasks if task["name"] == node_id_to_node_name[to]][
+            0
+        ]
+        from_tasks = []
+        for f in from_:
+            from_tasks.append(
+                [task for task in tasks if task["name"] == node_id_to_node_name[f]][0]
+            )
+
+        for from_task in from_tasks:
+            to_task["prev"].append(from_task["name"])
+            from_task["next"].append(node_id_to_node_name[to])
+
+    # Return as a hashmap due to cookiecutter limitations; investigate later.
+    return {task["name"]: task for task in tasks}
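A worked example of the transformation on a two-node graph in the shape this function expects: `nodes` keyed by node id, and `pred` mapping each node id to the ids of its predecessors. The UUIDs are arbitrary:

```python
from lmnr.cli.parser.parser import runnable_graph_to_template_vars

IN = "00000000-0000-0000-0000-00000000000a"
OUT = "00000000-0000-0000-0000-00000000000b"
IN_HANDLE = "00000000-0000-0000-0000-0000000000a1"
OUT_HANDLE = "00000000-0000-0000-0000-0000000000b1"

graph = {
    "nodes": {
        IN: {
            "type": "Input", "id": IN, "name": "question",
            "outputs": [{"id": IN_HANDLE, "type": "String"}],
            "input": "", "inputType": "String",
        },
        OUT: {
            "type": "Output", "id": OUT, "name": "answer",
            "inputs": [{"id": OUT_HANDLE, "name": "output", "type": "String"}],
            "outputs": [],
            "inputsMappings": {OUT_HANDLE: IN_HANDLE},
        },
    },
    "pred": {OUT: [IN]},
}

tasks = runnable_graph_to_template_vars(graph)
assert tasks["answer"]["prev"] == ["question"]
assert tasks["question"]["next"] == ["answer"]
assert tasks["answer"]["handle_args"] == "output: NodeInput"
```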
lmnr/cli/parser/utils.py ADDED
@@ -0,0 +1,25 @@
+# Convert a list of handles to a mapping from each input handle's name
+# to the name of the node it is connected from
+import uuid
+from .nodes import Handle
+
+
+def map_handles(
+    inputs: list[Handle],
+    inputs_mappings: dict[uuid.UUID, uuid.UUID],
+    output_handle_id_to_node_name: dict[str, str],
+) -> list[tuple[str, str]]:
+    mapping = []
+
+    for to, from_ in inputs_mappings.items():
+        for input in inputs:
+            if input.id == to:
+                mapping.append((input.name, from_))
+                break
+        else:
+            raise ValueError(f"Input handle {to} not found in inputs")
+
+    return [
+        (input_name, output_handle_id_to_node_name[str(output_id)])
+        for input_name, output_id in mapping
+    ]
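In isolation the helper behaves like this; UUIDs are generated on the fly:

```python
import uuid

from lmnr.cli.parser.nodes import Handle
from lmnr.cli.parser.utils import map_handles

to_id, from_id = uuid.uuid4(), uuid.uuid4()
inputs = [Handle(id=to_id, name="chat_messages", type="ChatMessageList")]

pairs = map_handles(
    inputs,
    {to_id: from_id},  # this node's input handle -> predecessor's output handle
    {str(from_id): "my_llm_node"},  # output handle id (as str) -> node name
)
assert pairs == [("chat_messages", "my_llm_node")]
```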
lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/engine/__init__.py ADDED
@@ -0,0 +1 @@
+from .engine import Engine
lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/engine/action.py ADDED
@@ -0,0 +1,14 @@
+from dataclasses import dataclass
+from typing import Union
+
+from lmnr_engine.types import NodeInput
+
+
+@dataclass
+class RunOutput:
+    status: str  # "Success" | "Termination"  TODO: Turn into an Enum
+    output: Union[NodeInput, None]
+
+
+class NodeRunError(Exception):
+    pass
lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/engine/engine.py ADDED
@@ -0,0 +1,261 @@
+from concurrent.futures import ThreadPoolExecutor
+import datetime
+import logging
+from typing import Optional
+import uuid
+from dataclasses import dataclass
+import queue
+
+from .task import Task
+from .action import NodeRunError, RunOutput
+from .state import State
+from lmnr_engine.types import Message, NodeInput
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ScheduledTask:
+    status: str  # "Task" | "Err"  TODO: Use an enum
+    task_name: Optional[str]
+
+
+class RunError(Exception):
+    outputs: dict[str, Message]
+
+
+@dataclass
+class Engine:
+    tasks: dict[str, Task]
+    active_tasks: set[str]
+    depths: dict[str, int]
+    outputs: dict[str, Message]
+    env: dict[str, str]
+    thread_pool_executor: ThreadPoolExecutor
+    # TODO: Store the thread pool executor's Futures here to have control
+    # over them (e.g. cancel them)
+
+    @classmethod
+    def new(
+        cls, thread_pool_executor: ThreadPoolExecutor, env: dict[str, str] = {}
+    ) -> "Engine":
+        return cls(
+            tasks={},
+            active_tasks=set(),
+            depths={},
+            outputs={},
+            env=env,
+            thread_pool_executor=thread_pool_executor,
+        )
+
+    @classmethod
+    def with_tasks(
+        cls,
+        tasks: list[Task],
+        thread_pool_executor: ThreadPoolExecutor,
+        env: dict[str, str] = {},
+    ) -> "Engine":
+        dag = cls.new(thread_pool_executor, env=env)
+
+        for task in tasks:
+            dag.tasks[task.name] = task
+            dag.depths[task.name] = 0
+
+        return dag
+
+    def override_inputs(self, inputs: dict[str, NodeInput]) -> None:
+        for task in self.tasks.values():
+            # TODO: Check that it's an Input-type task
+            if not task.prev:
+                task.value = inputs[task.name]
+
+    def run(self, inputs: dict[str, NodeInput]) -> dict[str, Message]:
+        self.override_inputs(inputs)
+
+        q = queue.Queue()
+
+        input_tasks = []
+        for task in self.tasks.values():
+            if len(task.prev) == 0:
+                input_tasks.append(task.name)
+
+        for task_id in input_tasks:
+            q.put(ScheduledTask(status="Task", task_name=task_id))
+
+        while True:
+            logger.info("Waiting for task from queue")
+            scheduled_task: ScheduledTask = q.get()
+            logger.info(f"Got task from queue: {scheduled_task}")
+            if scheduled_task.status == "Err":
+                # TODO: Abort all other threads
+                raise RunError(self.outputs)
+
+            task: Task = self.tasks[scheduled_task.task_name]  # type: ignore
+            logger.info(f"Task next: {task.next}")
+
+            if not task.next:
+                try:
+                    fut = self.execute_task(task, q)
+                    fut.result()
+                    if not self.active_tasks:
+                        return self.outputs
+                except Exception:
+                    raise RunError(self.outputs)
+            else:
+                self.execute_task(task, q)
+
+    def execute_task_inner(
+        self,
+        task: Task,
+        queue: queue.Queue,
+    ) -> None:
+        task_id = task.name
+        next = task.next
+        input_states = task.input_states
+        active_tasks = self.active_tasks
+        tasks = self.tasks
+        depths = self.depths
+        depth = depths[task.name]
+        outputs = self.outputs
+
+        inputs: dict[str, NodeInput] = {}
+        input_messages = []
+
+        # Wait for the inputs of this task to be set
+        for handle_name, input_state in input_states.items():
+            logger.info(f"Task {task_id} waiting for semaphore for {handle_name}")
+            input_state.semaphore.acquire()
+            logger.info(f"Task {task_id} acquired semaphore for {handle_name}")
+
+            # Set the outputs of predecessors as inputs of the current task
+            output = input_state.get_state()
+            # If at least one of the inputs is a termination,
+            # also terminate this task early and set its state to termination
+            if output.status == "Termination":
+                return
+            message = output.get_out()
+
+            inputs[handle_name] = message.value
+            input_messages.append(message)
+
+        start_time = datetime.datetime.now()
+
+        try:
+            if callable(task.value):
+                res = task.value(**inputs, _env=self.env)
+            else:
+                res = RunOutput(status="Success", output=task.value)
+
+            if res.status == "Success":
+                id = uuid.uuid4()
+                state = State.new(
+                    Message(
+                        id=id,
+                        value=res.output,  # type: ignore
+                        start_time=start_time,
+                        end_time=datetime.datetime.now(),
+                    )
+                )
+            else:
+                assert res.status == "Termination"
+                state = State.termination()
+
+            is_termination = state.is_termination()
+            logger.info(f"Task {task_id} executed")
+
+            # remove the task from active tasks once it's done
+            if task_id in active_tasks:
+                active_tasks.remove(task_id)
+
+            if depth > 0:
+                # propagate reset once we enter the loop
+                # TODO: Implement this for cycles
+                raise NotImplementedError()
+
+            if depth == 10:
+                # TODO: Implement this for cycles
+                raise NotImplementedError()
+
+            if not next:
+                # if there are no next tasks, we can terminate the graph
+                outputs[task.name] = state.get_out()
+
+            # push next tasks to the queue only if
+            # the current task is not a termination
+            for next_task_name in next:
+                # we set the inputs of the next tasks
+                # to the outputs of the current task
+                next_task = tasks[next_task_name]
+
+                # in the majority of cases there will be only one handle name;
+                # however, we need to handle the case when a single output
+                # is mapped to multiple inputs on the next node
+                handle_names = []
+                for k, v in next_task.handles_mapping:
+                    if v == task.name:
+                        handle_names.append(k)
+
+                for handle_name in handle_names:
+                    next_state = next_task.input_states[handle_name]
+                    next_state.set_state_and_permits(state, 1)
+
+                # push next tasks to the queue only if the task is not active
+                # and the current task is not a termination
+                if not (next_task_name in active_tasks) and not is_termination:
+                    active_tasks.add(next_task_name)
+                    queue.put(
+                        ScheduledTask(
+                            status="Task",
+                            task_name=next_task_name,
+                        )
+                    )
+
+            # increment depth of the finished task
+            depths[task_id] = depth + 1
+        except NodeRunError as e:
+            logger.exception(f"Execution failed [id: {task_id}]")
+
+            error = Message(
+                id=uuid.uuid4(),
+                value=str(e),
+                start_time=start_time,
+                end_time=datetime.datetime.now(),
+            )
+
+            outputs[task.name] = error
+
+            # terminate the entire graph by sending an err task
+            queue.put(
+                ScheduledTask(
+                    status="Err",
+                    task_name=None,
+                )
+            )
+
+        except Exception:
+            logger.exception(f"Execution failed [id: {task_id}]")
+            error = Message(
+                id=uuid.uuid4(),
+                value="Unexpected server error",
+                start_time=start_time,
+                end_time=datetime.datetime.now(),
+            )
+            outputs[task.name] = error
+            queue.put(
+                ScheduledTask(
+                    status="Err",
+                    task_name=None,
+                )
+            )
+
+    def execute_task(
+        self,
+        task: Task,
+        queue: queue.Queue,
+    ):
+        return self.thread_pool_executor.submit(
+            self.execute_task_inner,
+            task,
+            queue,
+        )
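A rough sketch of how the generated code is expected to drive the engine. This is hedged: `Task` and `State` are defined in `task.py` and `state.py`, whose bodies are not shown in this diff, so the `Task` fields are inferred from the attribute accesses in `execute_task_inner` (`name`, `prev`, `next`, `value`, `input_states`, `handles_mapping`), and the task/input names are illustrative:

```python
from concurrent.futures import ThreadPoolExecutor

from lmnr_engine.engine import Engine

# `tasks` is a list[Task] built by the generated pipeline module (not shown).
# Tasks with no predecessors act as inputs: Engine.run overrides their
# `value` with the caller-supplied inputs, keyed by task name.
with ThreadPoolExecutor() as pool:
    engine = Engine.with_tasks(tasks, pool, env={"OPENAI_API_KEY": "..."})
    outputs = engine.run({"question": "What is Laminar?"})
    # `outputs` maps terminal task names to Message objects
```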