lmnr 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lmnr might be problematic. Click here for more details.
- lmnr/__init__.py +5 -1
- lmnr/cli/__main__.py +4 -0
- lmnr/cli/cli.py +97 -0
- lmnr/cli/cookiecutter.json +9 -0
- lmnr/cli/parser/__init__.py +0 -0
- lmnr/cli/parser/nodes/__init__.py +50 -0
- lmnr/cli/parser/nodes/types.py +156 -0
- lmnr/cli/parser/parser.py +58 -0
- lmnr/cli/parser/utils.py +25 -0
- lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/__init__.py +0 -0
- lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/engine/__init__.py +1 -0
- lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/engine/action.py +14 -0
- lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/engine/engine.py +261 -0
- lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/engine/state.py +69 -0
- lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/engine/task.py +38 -0
- lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/pipelines/{{cookiecutter.pipeline_dir_name}}/__init__.py +1 -0
- lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/pipelines/{{cookiecutter.pipeline_dir_name}}/nodes/functions.py +149 -0
- lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/pipelines/{{cookiecutter.pipeline_dir_name}}/{{cookiecutter.pipeline_dir_name}}.py +87 -0
- lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/types.py +50 -0
- lmnr/sdk/endpoint.py +166 -0
- lmnr/types.py +93 -0
- lmnr-0.1.3.dist-info/LICENSE +72 -0
- lmnr-0.1.3.dist-info/METADATA +78 -0
- lmnr-0.1.3.dist-info/RECORD +26 -0
- lmnr-0.1.3.dist-info/entry_points.txt +3 -0
- lmnr/endpoint.py +0 -43
- lmnr/model.py +0 -39
- lmnr-0.1.1.dist-info/LICENSE +0 -7
- lmnr-0.1.1.dist-info/METADATA +0 -37
- lmnr-0.1.1.dist-info/RECORD +0 -7
- {lmnr-0.1.1.dist-info → lmnr-0.1.3.dist-info}/WHEEL +0 -0
lmnr/__init__.py
CHANGED
lmnr/cli/__main__.py
ADDED
lmnr/cli/cli.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
import requests
|
|
2
|
+
from dotenv import load_dotenv
|
|
3
|
+
import os
|
|
4
|
+
import click
|
|
5
|
+
import logging
|
|
6
|
+
from cookiecutter.main import cookiecutter
|
|
7
|
+
from importlib import resources as importlib_resources
|
|
8
|
+
from pydantic.alias_generators import to_pascal
|
|
9
|
+
|
|
10
|
+
from .parser.parser import runnable_graph_to_template_vars
|
|
11
|
+
import lmnr
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@click.group()
@click.version_option()
def cli():
    """CLI for Laminar AI Engine"""
    # Root command group: subcommands register themselves on it through
    # ``@cli.command``. The docstring above is what click shows as help text,
    # so its wording is part of the CLI output.
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@cli.command(name="pull")
@click.argument("pipeline_name")
@click.argument("pipeline_version_name")
@click.option(
    "-p",
    "--project-api-key",
    help="Project API key",
)
@click.option(
    "-l",
    "--loglevel",
    help="Sets logging level",
)
def pull(pipeline_name, pipeline_version_name, project_api_key, loglevel):
    # Fetch one pipeline version from the Laminar API and render it into the
    # current directory through the cookiecutter template under ``lmnr/cli``.
    #
    # Arguments (all supplied by click):
    #   pipeline_name          -- name of the pipeline to pull
    #   pipeline_version_name  -- name of the version to pull
    #   project_api_key        -- ``-p`` value; falls back to the
    #                             LMNR_PROJECT_API_KEY environment variable,
    #                             then to a ``.env`` file
    #   loglevel               -- ``-l`` value: DEBUG/INFO/WARNING/ERROR/CRITICAL
    #
    # Raises ValueError when no API key can be found or when the API answers
    # with a non-200 status.
    #
    # NOTE: this is deliberately a comment and not a docstring; click would
    # print a docstring as the command's ``--help`` text.
    loglevel_str_to_val = {
        "DEBUG": logging.DEBUG,
        "INFO": logging.INFO,
        "WARNING": logging.WARNING,
        "ERROR": logging.ERROR,
        "CRITICAL": logging.CRITICAL,
    }
    logging.basicConfig()
    # Match the level name case-insensitively so ``-l debug`` behaves like
    # ``-l DEBUG``; a missing or unknown name still falls back to WARNING.
    level_name = (loglevel or "").upper()
    logging.getLogger().setLevel(
        loglevel_str_to_val.get(level_name, logging.WARNING)
    )

    # Key lookup order: CLI option, process environment, then .env file.
    project_api_key = project_api_key or os.environ.get("LMNR_PROJECT_API_KEY")
    if not project_api_key:
        load_dotenv()
        project_api_key = os.environ.get("LMNR_PROJECT_API_KEY")
    if not project_api_key:
        raise ValueError("LMNR_PROJECT_API_KEY is not set")

    headers = {"Authorization": f"Bearer {project_api_key}"}
    params = {
        "pipelineName": pipeline_name,
        "pipelineVersionName": pipeline_version_name,
    }
    # Without a timeout requests waits indefinitely on an unresponsive
    # server, which would leave the CLI hanging.
    request_timeout_seconds = 60
    res = requests.get(
        "https://api.lmnr.ai/v1/pipeline-version-by-name",
        headers=headers,
        params=params,
        timeout=request_timeout_seconds,
    )
    if res.status_code != 200:
        # Best effort: prefer the decoded JSON error body, fall back to the
        # raw text when the body is not valid JSON.
        try:
            details = res.json()
        except Exception:
            details = res.text
        raise ValueError(
            f"Error in fetching pipeline version: {res.status_code}\n{details}"
        )

    pipeline_version = res.json()

    class_name = to_pascal(pipeline_name.replace(" ", "_"))

    context = {
        "pipeline_name": pipeline_name,
        "pipeline_version_name": pipeline_version_name,
        "class_name": class_name,
        # _tasks starts from underscore because we don't want it to be templated
        # some tasks contains LLM nodes which have prompts
        # which we don't want to be rendered by cookiecutter
        "_tasks": runnable_graph_to_template_vars(pipeline_version["runnableGraph"]),
    }

    logger.info(f"Context:\n{context}")
    cookiecutter(
        str(importlib_resources.files(lmnr)),
        output_dir=".",
        config_file=None,
        extra_context=context,
        directory="cli",
        no_input=True,
        overwrite_if_exists=True,
    )
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
{
|
|
2
|
+
"lmnr_pipelines_dir_name": "lmnr_engine",
|
|
3
|
+
"pipeline_name": "Laminar Pipeline",
|
|
4
|
+
"pipeline_dir_name": "{{ cookiecutter['pipeline_name'].lower().replace('-', '_').replace(' ', '_') }}",
|
|
5
|
+
"class_name": "LaminarPipeline",
|
|
6
|
+
"pipeline_version_name": "main",
|
|
7
|
+
"_tasks": {},
|
|
8
|
+
"_jinja2_env_vars": {"lstrip_blocks": true, "trim_blocks": true}
|
|
9
|
+
}
|
|
File without changes
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from abc import ABCMeta, abstractmethod
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Optional
|
|
4
|
+
import uuid
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
# Kind of value a handle carries: "String" | "ChatMessageList" | "Any"
HandleType = str


@dataclass
class Handle:
    """An input or output connection point of a pipeline node."""

    id: uuid.UUID
    name: Optional[str]
    type: HandleType

    @classmethod
    def from_dict(cls, dict: dict) -> "Handle":
        """Build a Handle from its JSON-style dictionary form.

        ``"id"`` (a UUID string) and ``"type"`` are required keys; ``"name"``
        is optional and becomes ``None`` when absent.
        """
        handle_id = uuid.UUID(dict["id"])
        handle_name = dict.get("name")
        return cls(id=handle_id, name=handle_name, type=dict["type"])
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class NodeFunctions(metaclass=ABCMeta):
    """Interface every parsed pipeline node implements.

    The class is abstract through its ``ABCMeta`` metaclass and the abstract
    methods below. ``@abstractmethod`` is meant for methods only, so it is not
    applied to the class itself: there it would merely set a misleading
    ``__isabstractmethod__`` attribute on the class object.
    """

    @abstractmethod
    def handles_mapping(
        self, output_handle_id_to_node_name: dict[str, str]
    ) -> list[tuple[str, str]]:
        """
        Returns a list of tuples mapping from this node's input
        handle name to the unique name of the previous node.

        Assumes previous node has only one output.
        """

    @abstractmethod
    def node_type(self) -> str:
        """Returns the node's type name as a string."""

    @abstractmethod
    def config(self) -> dict:
        """
        Returns a dictionary of node-specific configuration.

        E.g. prompt and model name for LLM node.
        """
|
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Any, Optional, Union
|
|
3
|
+
import uuid
|
|
4
|
+
from lmnr.cli.parser.nodes import Handle, HandleType, NodeFunctions
|
|
5
|
+
from lmnr.cli.parser.utils import map_handles
|
|
6
|
+
from lmnr.types import NodeInput, ChatMessage
|
|
7
|
+
|
|
8
|
+
def node_input_from_json(json_val: Any) -> NodeInput:
    """Convert a decoded JSON value into a node input.

    A string is returned unchanged; a list is converted element by element
    with ``ChatMessage.model_validate``.

    Raises:
        ValueError: if ``json_val`` is neither a string nor a list.
    """
    if isinstance(json_val, str):
        return json_val
    if not isinstance(json_val, list):
        raise ValueError(f"Invalid NodeInput value: {json_val}")

    messages = []
    for raw_message in json_val:
        messages.append(ChatMessage.model_validate(raw_message))
    return messages
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# TODO: Convert to Pydantic
|
|
18
|
+
@dataclass
class InputNode(NodeFunctions):
    """Pipeline node of type "Input"; it has outputs but no input handles."""

    id: uuid.UUID
    name: str
    outputs: list[Handle]
    input: Optional[NodeInput]
    input_type: HandleType

    def handles_mapping(
        self, output_handle_id_to_node_name: dict[str, str]
    ) -> list[tuple[str, str]]:
        """Return an empty mapping: this node declares no input handles."""
        return list()

    def node_type(self) -> str:
        """Return the node type name, "Input"."""
        return "Input"

    def config(self) -> dict:
        """Return an empty configuration dictionary."""
        return dict()
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# TODO: Convert to Pydantic
|
|
39
|
+
@dataclass
class LLMNode(NodeFunctions):
    """Pipeline node of type "LLM": a prompt plus model settings."""

    id: uuid.UUID
    name: str
    inputs: list[Handle]
    dynamic_inputs: list[Handle]
    outputs: list[Handle]
    inputs_mappings: dict[uuid.UUID, uuid.UUID]
    prompt: str
    model: str
    model_params: Optional[str]
    stream: bool
    structured_output_enabled: bool
    structured_output_max_retries: int
    structured_output_schema: Optional[str]
    structured_output_schema_target: Optional[str]

    def handles_mapping(
        self, output_handle_id_to_node_name: dict[str, str]
    ) -> list[tuple[str, str]]:
        """Map every static and dynamic input handle name to the name of the
        node that produces its value."""
        all_inputs = [*self.inputs, *self.dynamic_inputs]
        return map_handles(
            all_inputs, self.inputs_mappings, output_handle_id_to_node_name
        )

    def node_type(self) -> str:
        """Return the node type name, "LLM"."""
        return "LLM"

    def config(self) -> dict:
        """Return the LLM settings as a dictionary for the template.

        ``self.model`` is split at its first ``:`` into provider and model so
        the template can use them separately; a ``model`` value without a
        ``:`` makes the unpacking raise ``ValueError``.
        """
        provider, model = self.model.split(":", maxsplit=1)

        settings = {
            "prompt": self.prompt,
            "provider": provider,
            "model": model,
            "model_params": self.model_params,
            "stream": self.stream,
        }
        # The structured-output fields are copied under their own names.
        for field_name in (
            "structured_output_enabled",
            "structured_output_max_retries",
            "structured_output_schema",
            "structured_output_schema_target",
        ):
            settings[field_name] = getattr(self, field_name)
        return settings
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# TODO: Convert to Pydantic
|
|
85
|
+
@dataclass
class OutputNode(NodeFunctions):
    """Pipeline node of type "Output"."""

    id: uuid.UUID
    name: str
    inputs: list[Handle]
    outputs: list[Handle]
    inputs_mappings: dict[uuid.UUID, uuid.UUID]

    def handles_mapping(
        self, output_handle_id_to_node_name: dict[str, str]
    ) -> list[tuple[str, str]]:
        """Map each input handle name to the name of the node feeding it."""
        resolved = map_handles(
            self.inputs, self.inputs_mappings, output_handle_id_to_node_name
        )
        return resolved

    def node_type(self) -> str:
        """Return the node type name, "Output"."""
        return "Output"

    def config(self) -> dict:
        """Return an empty configuration dictionary."""
        return dict()
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
Node = Union[InputNode, OutputNode, LLMNode]


def node_from_dict(node_dict: dict) -> Node:
    """Build the typed node object described by a JSON-style dictionary.

    ``node_dict["type"]`` selects the class: "Input", "Output" or "LLM".

    Raises:
        ValueError: for any other node type.
    """
    kind = node_dict["type"]

    def parse_handles(key: str) -> list:
        # Turn the list of handle dictionaries stored under ``key`` into
        # Handle objects.
        return [Handle.from_dict(raw_handle) for raw_handle in node_dict[key]]

    def parse_inputs_mappings() -> dict:
        # Both keys and values arrive as UUID strings.
        return {
            uuid.UUID(to_id): uuid.UUID(from_id)
            for to_id, from_id in node_dict["inputsMappings"].items()
        }

    if kind == "Input":
        return InputNode(
            id=uuid.UUID(node_dict["id"]),
            name=node_dict["name"],
            outputs=parse_handles("outputs"),
            input=node_input_from_json(node_dict["input"]),
            input_type=node_dict["inputType"],
        )

    if kind == "Output":
        return OutputNode(
            id=uuid.UUID(node_dict["id"]),
            name=node_dict["name"],
            inputs=parse_handles("inputs"),
            outputs=parse_handles("outputs"),
            inputs_mappings=parse_inputs_mappings(),
        )

    if kind == "LLM":
        return LLMNode(
            id=uuid.UUID(node_dict["id"]),
            name=node_dict["name"],
            inputs=parse_handles("inputs"),
            dynamic_inputs=parse_handles("dynamicInputs"),
            outputs=parse_handles("outputs"),
            inputs_mappings=parse_inputs_mappings(),
            prompt=node_dict["prompt"],
            model=node_dict["model"],
            model_params=node_dict.get("modelParams"),
            stream=False,
            # TODO: Implement structured output
            structured_output_enabled=False,
            structured_output_max_retries=3,
            structured_output_schema=None,
            structured_output_schema_target=None,
        )

    raise ValueError(f"Node type {kind} not supported")
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from .nodes.types import node_from_dict
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def runnable_graph_to_template_vars(graph: dict) -> dict:
    """
    Convert a runnable graph to template vars to be rendered in a cookiecutter context.

    Args:
        graph: dictionary with a ``"nodes"`` mapping (values are node
            dictionaries carrying at least ``"id"``, ``"name"`` and
            ``"outputs"``) and a ``"pred"`` mapping from a node id to the ids
            of its predecessor nodes.

    Returns:
        A dictionary from task name to task description. Each description
        holds ``name``, ``function_name``, ``node_type``, ``handles_mapping``,
        ``input_handle_names``, ``handle_args``, ``prev``, ``next`` and
        ``config``.
    """
    node_id_to_node_name = {}
    output_handle_id_to_node_name: dict[str, str] = {}
    for node in graph["nodes"].values():
        node_id_to_node_name[node["id"]] = node["name"]
        for handle in node["outputs"]:
            output_handle_id_to_node_name[handle["id"]] = node["name"]

    tasks = []
    for node_obj in graph["nodes"].values():
        node = node_from_dict(node_obj)
        handles_mapping = node.handles_mapping(output_handle_id_to_node_name)
        # since we map from to to from, all 'to's won't repeat
        input_handle_names = [handle_name for (handle_name, _) in handles_mapping]
        tasks.append(
            {
                "name": node.name,
                "function_name": f"run_{node.name}",
                "node_type": node.node_type(),
                "handles_mapping": handles_mapping,
                "input_handle_names": input_handle_names,
                "handle_args": ", ".join(
                    f"{handle_name}: NodeInput" for handle_name in input_handle_names
                ),
                "prev": [],
                "next": [],
                "config": node.config(),
            }
        )

    # Index tasks by name once instead of rescanning the list for every edge.
    # ``setdefault`` keeps the first task registered under a name, which is
    # the task a front-to-back list scan would have found.
    task_by_name = {}
    for task in tasks:
        task_by_name.setdefault(task["name"], task)

    for to, from_ in graph["pred"].items():
        to_name = node_id_to_node_name[to]
        to_task = task_by_name[to_name]
        # Resolve every predecessor before touching any prev/next list.
        from_tasks = [task_by_name[node_id_to_node_name[f]] for f in from_]

        for from_task in from_tasks:
            to_task["prev"].append(from_task["name"])
            from_task["next"].append(to_name)

    # Return as a hashmap due to cookiecutter limitations, investigate later.
    return {task["name"]: task for task in tasks}
|
lmnr/cli/parser/utils.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
# Convert a list of handles to a map of input handle names
|
|
2
|
+
# to their respective values
|
|
3
|
+
import uuid
|
|
4
|
+
from .nodes import Handle
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def map_handles(
    inputs: list[Handle],
    inputs_mappings: dict[uuid.UUID, uuid.UUID],
    output_handle_id_to_node_name: dict[str, str],
) -> list[tuple[str, str]]:
    """Pair each mapped input handle's name with the name of its source node.

    Args:
        inputs: the input handles of a node.
        inputs_mappings: input handle id -> id of the output handle feeding it.
        output_handle_id_to_node_name: output handle id (as a string) -> name
            of the node owning that output handle.

    Returns:
        ``(input handle name, source node name)`` tuples, in the iteration
        order of ``inputs_mappings``.

    Raises:
        ValueError: if a key of ``inputs_mappings`` matches no handle in
            ``inputs``.
    """
    resolved = []
    for target_id, source_id in inputs_mappings.items():
        # First handle whose id equals the mapping key, if any.
        found = next(
            (handle for handle in inputs if handle.id == target_id), None
        )
        if found is None:
            raise ValueError(f"Input handle {target_id} not found in inputs")
        resolved.append((found.name, source_id))

    # The lookup table is keyed by string ids, hence the str() conversion.
    return [
        (handle_name, output_handle_id_to_node_name[str(source_id)])
        for handle_name, source_id in resolved
    ]
|
|
File without changes
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .engine import Engine
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Union
|
|
3
|
+
|
|
4
|
+
from lmnr_engine.types import NodeInput
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass
class RunOutput:
    """Result returned by a node function: a status and an optional value."""

    # "Success" | "Termination" TODO: Turn into Enum
    status: str
    # Produced value; the type allows None
    output: Union[NodeInput, None]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class NodeRunError(Exception):
    """Error raised when running a single node fails."""
|
|
@@ -0,0 +1,261 @@
|
|
|
1
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
2
|
+
import datetime
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Optional
|
|
5
|
+
import uuid
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
import queue
|
|
8
|
+
|
|
9
|
+
from .task import Task
|
|
10
|
+
from .action import NodeRunError, RunOutput
|
|
11
|
+
from .state import State
|
|
12
|
+
from lmnr_engine.types import Message, NodeInput
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class ScheduledTask:
    """Item placed on the engine's scheduling queue."""

    # "Task" | "Err" TODO: Use an enum
    status: str
    # Name of the task to run; the type allows None
    task_name: Optional[str]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class RunError(Exception):
    """Raised when a pipeline run fails.

    The first positional argument is the mapping of outputs collected before
    the failure. It is kept in ``args`` as for any exception and is also
    exposed as the ``outputs`` attribute.
    """

    def __init__(self, *args) -> None:
        super().__init__(*args)
        # Assign the attribute the class advertises; a bare class-level
        # annotation alone would leave ``err.outputs`` undefined.
        # (An annotation inside a function body is not evaluated at runtime.)
        self.outputs: dict[str, Message] = args[0] if args else {}
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class Engine:
    """Runs a graph of named tasks on a thread pool.

    ``run`` schedules the tasks that have no predecessors; each finished task
    hands its result to the input states of its successor tasks and schedules
    them through a queue.
    """

    # All tasks of the graph, keyed by task name
    tasks: dict[str, Task]
    # Names of tasks that were scheduled and have not finished yet
    active_tasks: set[str]
    # Per task name: 0 initially, incremented when the task finishes
    depths: dict[str, int]
    # Messages of tasks without successors, plus error messages
    outputs: dict[str, Message]
    # Passed to every callable task as the ``_env`` keyword argument
    env: dict[str, str]
    thread_pool_executor: ThreadPoolExecutor
    # TODO: Store thread pool executor's Futures here to have control
    # over them (e.g. cancel them)

    @classmethod
    def new(
        cls,
        thread_pool_executor: ThreadPoolExecutor,
        env: Optional[dict[str, str]] = None,
    ) -> "Engine":
        """Create an engine without tasks.

        ``env`` defaults to a fresh empty dictionary per engine. A literal
        ``{}`` default would be one dictionary object shared by every engine
        created without an explicit ``env``.
        """
        return cls(
            tasks={},
            active_tasks=set(),
            depths={},
            outputs={},
            env={} if env is None else env,
            thread_pool_executor=thread_pool_executor,
        )

    @classmethod
    def with_tasks(
        cls,
        tasks: list[Task],
        thread_pool_executor: ThreadPoolExecutor,
        env: Optional[dict[str, str]] = None,
    ) -> "Engine":
        """Create an engine and register ``tasks`` by name, each at depth 0."""
        dag = cls.new(thread_pool_executor, env=env)

        for task in tasks:
            dag.tasks[task.name] = task
            dag.depths[task.name] = 0

        return dag

    def override_inputs(self, inputs: dict[str, NodeInput]) -> None:
        """Set the value of every task without predecessors from ``inputs``.

        ``inputs`` is indexed by task name, so it needs an entry for each
        such task.
        """
        for task in self.tasks.values():
            # TODO: Check that it's the Input type task
            if not task.prev:
                task.value = inputs[task.name]

    def run(self, inputs: dict[str, NodeInput]) -> dict[str, Message]:
        """Execute the graph and return the collected outputs.

        Args:
            inputs: values for the tasks without predecessors, by task name.

        Returns:
            ``self.outputs`` once a task without successors has finished and
            no task is active any more.

        Raises:
            RunError: when a task reports an error through the queue, or when
                waiting on a task without successors raises.
        """
        self.override_inputs(inputs)

        q = queue.Queue()

        # Seed the queue with every task that has no predecessors.
        input_tasks = [
            task.name for task in self.tasks.values() if len(task.prev) == 0
        ]
        for task_id in input_tasks:
            q.put(ScheduledTask(status="Task", task_name=task_id))

        while True:
            logger.info("Waiting for task from queue")
            scheduled_task: ScheduledTask = q.get()
            logger.info(f"Got task from queue: {scheduled_task}")
            if scheduled_task.status == "Err":
                # TODO: Abort all other threads
                raise RunError(self.outputs)

            task: Task = self.tasks[scheduled_task.task_name]  # type: ignore
            logger.info(f"Task next: {task.next}")

            if not task.next:
                # A task without successors is waited for synchronously so
                # the loop can tell when the run is complete.
                try:
                    fut = self.execute_task(task, q)
                    fut.result()
                    if not self.active_tasks:
                        return self.outputs
                except Exception:
                    raise RunError(self.outputs)
            else:
                self.execute_task(task, q)

    def execute_task_inner(
        self,
        task: Task,
        queue: queue.Queue,
    ) -> None:
        """Run one task and hand its result to its successors.

        Waits for all of the task's input states, runs the task, stores the
        message in ``self.outputs`` when the task has no successors, and
        otherwise passes the state on and schedules successors on ``queue``.
        Any failure is recorded in ``self.outputs`` under the task's name and
        reported by putting an "Err" item on ``queue``.
        """
        task_id = task.name
        next_task_names = task.next
        input_states = task.input_states
        active_tasks = self.active_tasks
        tasks = self.tasks
        depths = self.depths
        depth = depths[task.name]
        outputs = self.outputs

        inputs: dict[str, NodeInput] = {}

        # Wait for inputs for this task to be set
        for handle_name, input_state in input_states.items():
            logger.info(f"Task {task_id} waiting for semaphore for {handle_name}")
            input_state.semaphore.acquire()
            logger.info(f"Task {task_id} acquired semaphore for {handle_name}")

            # Set the outputs of predecessors as inputs of the current
            output = input_state.get_state()
            # If at least one of the inputs is termination,
            # also terminate this task early and set its state to termination
            if output.status == "Termination":
                # NOTE(review): this return neither removes the task from
                # active_tasks nor puts anything on the queue - confirm that
                # run() cannot end up blocked in q.get() afterwards.
                return
            message = output.get_out()

            inputs[handle_name] = message.value

        start_time = datetime.datetime.now()

        try:
            # A callable value is invoked with the gathered inputs; any other
            # value is treated as an already available successful output.
            if callable(task.value):
                res = task.value(**inputs, _env=self.env)
            else:
                res = RunOutput(status="Success", output=task.value)

            if res.status == "Success":
                state = State.new(
                    Message(
                        id=uuid.uuid4(),
                        value=res.output,  # type: ignore
                        start_time=start_time,
                        end_time=datetime.datetime.now(),
                    )
                )
            else:
                assert res.status == "Termination"
                state = State.termination()

            is_termination = state.is_termination()
            logger.info(f"Task {task_id} executed")

            # remove the task from active tasks once it's done
            if task_id in active_tasks:
                active_tasks.remove(task_id)

            if depth > 0:
                # propagate reset once we enter the loop
                # TODO: Implement this for cycles
                raise NotImplementedError()

            # Unreachable while the depth > 0 check above raises first.
            if depth == 10:
                # TODO: Implement this for cycles
                raise NotImplementedError()

            if not next_task_names:
                # if there are no next tasks, we can terminate the graph
                outputs[task.name] = state.get_out()

            for next_task_name in next_task_names:
                # we set the inputs of the next tasks
                # to the outputs of the current task
                next_task = tasks[next_task_name]

                # in majority of cases there will be only one handle name
                # however we need to handle the case when single output
                # is mapped to multiple inputs on the next node
                handle_names = [
                    k for k, v in next_task.handles_mapping if v == task.name
                ]

                for handle_name in handle_names:
                    next_state = next_task.input_states[handle_name]
                    next_state.set_state_and_permits(state, 1)

                # push next tasks to the channel only if the task is not active
                # and current task is not a termination
                if next_task_name not in active_tasks and not is_termination:
                    active_tasks.add(next_task_name)
                    queue.put(
                        ScheduledTask(
                            status="Task",
                            task_name=next_task_name,
                        )
                    )

            # increment depth of the finished task
            depths[task_id] = depth + 1
        except NodeRunError as e:
            logger.exception(f"Execution failed [id: {task_id}]")

            error = Message(
                id=uuid.uuid4(),
                value=str(e),
                start_time=start_time,
                end_time=datetime.datetime.now(),
            )

            outputs[task.name] = error

            # terminate entire graph by sending err task
            queue.put(
                ScheduledTask(
                    status="Err",
                    task_name=None,
                )
            )

        except Exception:
            # Anything other than NodeRunError is reported with a generic
            # message; the details go to the log only.
            logger.exception(f"Execution failed [id: {task_id}]")
            error = Message(
                id=uuid.uuid4(),
                value="Unexpected server error",
                start_time=start_time,
                end_time=datetime.datetime.now(),
            )
            outputs[task.name] = error
            queue.put(
                ScheduledTask(
                    status="Err",
                    task_name=None,
                )
            )

    def execute_task(
        self,
        task: Task,
        queue: queue.Queue,
    ):
        """Submit ``execute_task_inner`` to the thread pool; return its Future."""
        return self.thread_pool_executor.submit(
            self.execute_task_inner,
            task,
            queue,
        )
|