lmnr-0.2.14-py3-none-any.whl → lmnr-0.3.0-py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (43)
  1. lmnr/__init__.py +4 -4
  2. lmnr/sdk/client.py +161 -0
  3. lmnr/sdk/collector.py +177 -0
  4. lmnr/sdk/constants.py +1 -0
  5. lmnr/sdk/context.py +456 -0
  6. lmnr/sdk/decorators.py +277 -0
  7. lmnr/sdk/interface.py +339 -0
  8. lmnr/sdk/providers/__init__.py +2 -0
  9. lmnr/sdk/providers/base.py +28 -0
  10. lmnr/sdk/providers/fallback.py +131 -0
  11. lmnr/sdk/providers/openai.py +140 -0
  12. lmnr/sdk/providers/utils.py +33 -0
  13. lmnr/sdk/tracing_types.py +197 -0
  14. lmnr/sdk/types.py +69 -0
  15. lmnr/sdk/utils.py +102 -0
  16. lmnr-0.3.0.dist-info/METADATA +185 -0
  17. lmnr-0.3.0.dist-info/RECORD +21 -0
  18. lmnr/cli/__init__.py +0 -0
  19. lmnr/cli/__main__.py +0 -4
  20. lmnr/cli/cli.py +0 -232
  21. lmnr/cli/parser/__init__.py +0 -0
  22. lmnr/cli/parser/nodes/__init__.py +0 -45
  23. lmnr/cli/parser/nodes/code.py +0 -36
  24. lmnr/cli/parser/nodes/condition.py +0 -30
  25. lmnr/cli/parser/nodes/input.py +0 -25
  26. lmnr/cli/parser/nodes/json_extractor.py +0 -29
  27. lmnr/cli/parser/nodes/llm.py +0 -56
  28. lmnr/cli/parser/nodes/output.py +0 -27
  29. lmnr/cli/parser/nodes/router.py +0 -37
  30. lmnr/cli/parser/nodes/semantic_search.py +0 -53
  31. lmnr/cli/parser/nodes/types.py +0 -153
  32. lmnr/cli/parser/parser.py +0 -62
  33. lmnr/cli/parser/utils.py +0 -49
  34. lmnr/cli/zip.py +0 -16
  35. lmnr/sdk/endpoint.py +0 -186
  36. lmnr/sdk/registry.py +0 -29
  37. lmnr/sdk/remote_debugger.py +0 -148
  38. lmnr/types.py +0 -101
  39. lmnr-0.2.14.dist-info/METADATA +0 -187
  40. lmnr-0.2.14.dist-info/RECORD +0 -28
  41. {lmnr-0.2.14.dist-info → lmnr-0.3.0.dist-info}/LICENSE +0 -0
  42. {lmnr-0.2.14.dist-info → lmnr-0.3.0.dist-info}/WHEEL +0 -0
  43. {lmnr-0.2.14.dist-info → lmnr-0.3.0.dist-info}/entry_points.txt +0 -0
lmnr/cli/parser/nodes/output.py DELETED
@@ -1,27 +0,0 @@
- from dataclasses import dataclass
- import uuid
-
- from lmnr.cli.parser.nodes import Handle, NodeFunctions
- from lmnr.cli.parser.utils import map_handles
-
-
- @dataclass
- class OutputNode(NodeFunctions):
-     id: uuid.UUID
-     name: str
-     inputs: list[Handle]
-     outputs: list[Handle]
-     inputs_mappings: dict[uuid.UUID, uuid.UUID]
-
-     def handles_mapping(
-         self, output_handle_id_to_node_name: dict[str, str]
-     ) -> list[tuple[str, str]]:
-         return map_handles(
-             self.inputs, self.inputs_mappings, output_handle_id_to_node_name
-         )
-
-     def node_type(self) -> str:
-         return "Output"
-
-     def config(self) -> dict:
-         return {}
lmnr/cli/parser/nodes/router.py DELETED
@@ -1,37 +0,0 @@
- from dataclasses import dataclass
- import uuid
-
- from lmnr.cli.parser.nodes import Handle, NodeFunctions
- from lmnr.cli.parser.utils import map_handles
-
-
- @dataclass
- class Route:
-     name: str
-
-
- @dataclass
- class RouterNode(NodeFunctions):
-     id: uuid.UUID
-     name: str
-     inputs: list[Handle]
-     outputs: list[Handle]
-     inputs_mappings: dict[uuid.UUID, uuid.UUID]
-     routes: list[Route]
-     has_default_route: bool
-
-     def handles_mapping(
-         self, output_handle_id_to_node_name: dict[str, str]
-     ) -> list[tuple[str, str]]:
-         return map_handles(
-             self.inputs, self.inputs_mappings, output_handle_id_to_node_name
-         )
-
-     def node_type(self) -> str:
-         return "Router"
-
-     def config(self) -> dict:
-         return {
-             "routes": str([route.name for route in self.routes]),
-             "has_default_route": str(self.has_default_route),
-         }
lmnr/cli/parser/nodes/semantic_search.py DELETED
@@ -1,53 +0,0 @@
- from dataclasses import dataclass
-
- import uuid
-
- from lmnr.cli.parser.nodes import Handle, NodeFunctions
- from lmnr.cli.parser.utils import map_handles
-
-
- @dataclass
- class Dataset:
-     id: uuid.UUID
-     # created_at: datetime
-     # project_id: uuid.UUID
-     # name: str
-     # indexed_on: Optional[str]
-
-     @classmethod
-     def from_dict(cls, dataset_dict: dict) -> "Dataset":
-         return cls(
-             id=uuid.UUID(dataset_dict["id"]),
-         )
-
-
- @dataclass
- class SemanticSearchNode(NodeFunctions):
-     id: uuid.UUID
-     name: str
-     inputs: list[Handle]
-     outputs: list[Handle]
-     inputs_mappings: dict[uuid.UUID, uuid.UUID]
-     limit: int
-     threshold: float
-     template: str
-     datasets: list[Dataset]
-
-     def handles_mapping(
-         self, output_handle_id_to_node_name: dict[str, str]
-     ) -> list[tuple[str, str]]:
-         return map_handles(
-             self.inputs, self.inputs_mappings, output_handle_id_to_node_name
-         )
-
-     def node_type(self) -> str:
-         return "SemanticSearch"
-
-     def config(self) -> dict:
-         return {
-             "limit": self.limit,
-             "threshold": self.threshold,
-             "template": self.template,
-             "datasource_ids": [str(dataset.id) for dataset in self.datasets],
-             "datasource_ids_list": str([str(dataset.id) for dataset in self.datasets]),
-         }
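
Note on the deleted config() above: the dataset ids appear twice, once as a list of strings and once as that list's repr, presumably so the code templates can consume either form. A minimal illustration (not from the package; the id and values are made up):

import uuid

ds_id = uuid.uuid4()
# Mirrors the deleted config() body for limit=10, threshold=0.5:
config = {
    "limit": 10,
    "threshold": 0.5,
    "template": "{{query}}",
    "datasource_ids": [str(ds_id)],            # the ids as a list of strings
    "datasource_ids_list": str([str(ds_id)]),  # the same list, stringified
}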
lmnr/cli/parser/nodes/types.py DELETED
@@ -1,153 +0,0 @@
- from typing import Any, Union
- import uuid
-
- from lmnr.cli.parser.nodes import Handle
- from lmnr.cli.parser.nodes.code import CodeNode
- from lmnr.cli.parser.nodes.condition import ConditionNode
- from lmnr.cli.parser.nodes.input import InputNode
- from lmnr.cli.parser.nodes.json_extractor import JsonExtractorNode
- from lmnr.cli.parser.nodes.llm import LLMNode
- from lmnr.cli.parser.nodes.output import OutputNode
- from lmnr.cli.parser.nodes.router import Route, RouterNode
- from lmnr.cli.parser.nodes.semantic_search import (
-     Dataset,
-     SemanticSearchNode,
- )
- from lmnr.types import NodeInput, ChatMessage
-
-
- def node_input_from_json(json_val: Any) -> NodeInput:
-     if isinstance(json_val, str):
-         return json_val
-     elif isinstance(json_val, list):
-         return [ChatMessage.model_validate(msg) for msg in json_val]
-     else:
-         raise ValueError(f"Invalid NodeInput value: {json_val}")
-
-
- Node = Union[
-     InputNode,
-     OutputNode,
-     ConditionNode,
-     LLMNode,
-     RouterNode,
-     SemanticSearchNode,
-     CodeNode,
-     JsonExtractorNode,
- ]
-
-
- def node_from_dict(node_dict: dict) -> Node:
-     if node_dict["type"] == "Input":
-         return InputNode(
-             id=uuid.UUID(node_dict["id"]),
-             name=node_dict["name"],
-             outputs=[Handle.from_dict(handle) for handle in node_dict["outputs"]],
-             input=node_input_from_json(node_dict["input"]),
-         )
-     elif node_dict["type"] == "Output":
-         return OutputNode(
-             id=uuid.UUID(node_dict["id"]),
-             name=node_dict["name"],
-             inputs=[Handle.from_dict(handle) for handle in node_dict["inputs"]],
-             outputs=[Handle.from_dict(handle) for handle in node_dict["outputs"]],
-             inputs_mappings={
-                 uuid.UUID(k): uuid.UUID(v)
-                 for k, v in node_dict["inputsMappings"].items()
-             },
-         )
-     elif node_dict["type"] == "Condition":
-         return ConditionNode(
-             id=uuid.UUID(node_dict["id"]),
-             name=node_dict["name"],
-             inputs=[Handle.from_dict(handle) for handle in node_dict["inputs"]],
-             outputs=[Handle.from_dict(handle) for handle in node_dict["outputs"]],
-             inputs_mappings={
-                 uuid.UUID(k): uuid.UUID(v)
-                 for k, v in node_dict["inputsMappings"].items()
-             },
-             condition=node_dict["condition"],
-         )
-     elif node_dict["type"] == "LLM":
-         return LLMNode(
-             id=uuid.UUID(node_dict["id"]),
-             name=node_dict["name"],
-             inputs=[Handle.from_dict(handle) for handle in node_dict["inputs"]],
-             dynamic_inputs=[
-                 Handle.from_dict(handle) for handle in node_dict["dynamicInputs"]
-             ],
-             outputs=[Handle.from_dict(handle) for handle in node_dict["outputs"]],
-             inputs_mappings={
-                 uuid.UUID(k): uuid.UUID(v)
-                 for k, v in node_dict["inputsMappings"].items()
-             },
-             prompt=node_dict["prompt"],
-             model=node_dict["model"],
-             model_params=(
-                 node_dict["modelParams"] if "modelParams" in node_dict else None
-             ),
-             stream=False,
-             structured_output_enabled=node_dict.get("structuredOutputEnabled", False),
-             structured_output_max_retries=node_dict.get(
-                 "structuredOutputMaxRetries", 0
-             ),
-             structured_output_schema=node_dict.get("structuredOutputSchema", None),
-             structured_output_schema_target=node_dict.get(
-                 "structuredOutputSchemaTarget", None
-             ),
-         )
-     elif node_dict["type"] == "Router":
-         return RouterNode(
-             id=uuid.UUID(node_dict["id"]),
-             name=node_dict["name"],
-             inputs=[Handle.from_dict(handle) for handle in node_dict["inputs"]],
-             outputs=[Handle.from_dict(handle) for handle in node_dict["outputs"]],
-             inputs_mappings={
-                 uuid.UUID(k): uuid.UUID(v)
-                 for k, v in node_dict["inputsMappings"].items()
-             },
-             routes=[Route(name=route["name"]) for route in node_dict["routes"]],
-             has_default_route=node_dict["hasDefaultRoute"],
-         )
-     elif node_dict["type"] == "SemanticSearch":
-         return SemanticSearchNode(
-             id=uuid.UUID(node_dict["id"]),
-             name=node_dict["name"],
-             inputs=[Handle.from_dict(handle) for handle in node_dict["inputs"]],
-             outputs=[Handle.from_dict(handle) for handle in node_dict["outputs"]],
-             inputs_mappings={
-                 uuid.UUID(k): uuid.UUID(v)
-                 for k, v in node_dict["inputsMappings"].items()
-             },
-             limit=node_dict["limit"],
-             threshold=node_dict["threshold"],
-             template=node_dict["template"],
-             datasets=[Dataset.from_dict(ds) for ds in node_dict["datasets"]],
-         )
-     elif node_dict["type"] == "Code":
-         return CodeNode(
-             id=uuid.UUID(node_dict["id"]),
-             name=node_dict["name"],
-             inputs=[Handle.from_dict(handle) for handle in node_dict["inputs"]],
-             outputs=[Handle.from_dict(handle) for handle in node_dict["outputs"]],
-             inputs_mappings={
-                 uuid.UUID(k): uuid.UUID(v)
-                 for k, v in node_dict["inputsMappings"].items()
-             },
-             code=node_dict["code"],
-             fn_name=node_dict["fnName"],
-         )
-     elif node_dict["type"] == "JsonExtractor":
-         return JsonExtractorNode(
-             id=uuid.UUID(node_dict["id"]),
-             name=node_dict["name"],
-             inputs=[Handle.from_dict(handle) for handle in node_dict["inputs"]],
-             outputs=[Handle.from_dict(handle) for handle in node_dict["outputs"]],
-             inputs_mappings={
-                 uuid.UUID(k): uuid.UUID(v)
-                 for k, v in node_dict["inputsMappings"].items()
-             },
-             template=node_dict["template"],
-         )
-     else:
-         raise ValueError(f"Node type {node_dict['type']} not supported")
lmnr/cli/parser/parser.py DELETED
@@ -1,62 +0,0 @@
- from lmnr.cli.parser.utils import replace_spaces_with_underscores
- from .nodes.types import node_from_dict
-
-
- def runnable_graph_to_template_vars(graph: dict) -> dict:
-     """
-     Convert a runnable graph to template vars to be rendered in a cookiecutter context.
-     """
-     node_id_to_node_name = {}
-     output_handle_id_to_node_name: dict[str, str] = {}
-     for node in graph["nodes"].values():
-         # override node names in the graph itself to be safe
-         node["name"] = replace_spaces_with_underscores(node["name"])
-
-         node_id_to_node_name[node["id"]] = node["name"]
-         for handle in node["outputs"]:
-             output_handle_id_to_node_name[handle["id"]] = node["name"]
-
-     tasks = []
-     for node_obj in graph["nodes"].values():
-         node = node_from_dict(node_obj)
-         handles_mapping = node.handles_mapping(output_handle_id_to_node_name)
-         node_type = node.node_type()
-
-         unique_handles = set([handle_name for (handle_name, _) in handles_mapping])
-
-         tasks.append(
-             {
-                 "name": node.name,
-                 "function_name": f"run_{node.name}",
-                 "node_type": node_type,
-                 "handles_mapping": handles_mapping,
-                 # since we map from to to from, all 'to's won't repeat
-                 "input_handle_names": [
-                     handle_name for (handle_name, _) in handles_mapping
-                 ],
-                 "handle_args": ", ".join(
-                     [f"{handle_name}: NodeInput" for handle_name in unique_handles]
-                 ),
-                 "prev": [],
-                 "next": [],
-                 "config": node.config(),
-             }
-         )
-
-     for to, from_ in graph["pred"].items():
-         # TODO: Make "tasks" a hashmap from node id (as str!) to task
-         to_task = [task for task in tasks if task["name"] == node_id_to_node_name[to]][
-             0
-         ]
-         from_tasks = []
-         for f in from_:
-             from_tasks.append(
-                 [task for task in tasks if task["name"] == node_id_to_node_name[f]][0]
-             )
-
-         for from_task in from_tasks:
-             to_task["prev"].append(from_task["name"])
-             from_task["next"].append(node_id_to_node_name[to])
-
-     # Return as a hashmap due to cookiecutter limitations, investigate later.
-     return {task["name"]: task for task in tasks}
lmnr/cli/parser/utils.py DELETED
@@ -1,49 +0,0 @@
- # Convert a list of handles to a map of input handle names
- # to their respective values
- import uuid
- from .nodes import Handle
-
-
- def map_handles(
-     inputs: list[Handle],
-     inputs_mappings: dict[uuid.UUID, uuid.UUID],
-     output_handle_id_to_node_name: dict[str, str],
- ) -> list[tuple[str, str]]:
-     mapping = []
-
-     for to, from_ in inputs_mappings.items():
-         for input in inputs:
-             if input.id == to:
-                 mapping.append((input.name, from_))
-                 break
-         else:
-             raise ValueError(f"Input handle {to} not found in inputs")
-
-     return [
-         (input_name, output_handle_id_to_node_name[str(output_id)])
-         for input_name, output_id in mapping
-     ]
-
-
- def replace_spaces_with_underscores(s: str):
-     spaces = [
-         "\u0020",
-         "\u00A0",
-         "\u1680",
-         "\u2000",
-         "\u2001",
-         "\u2002",
-         "\u2003",
-         "\u2004",
-         "\u2005",
-         "\u2006",
-         "\u2007",
-         "\u2008",
-         "\u2009",
-         "\u200A",
-         "\u200B",
-         "\u202F",
-         "\u205F",
-         "\u3000",
-     ]
-     return s.translate({ord(space): "_" for space in spaces})
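
Two details worth noting in the deleted helpers: map_handles relies on Python's for/else, so a mapping that references a missing input handle raises immediately, and replace_spaces_with_underscores uses str.translate over a table of Unicode space codepoints. A self-contained sketch of the translate trick (abbreviated space list; the example string is made up):

spaces = ["\u0020", "\u00A0", "\u3000"]  # abbreviated list for the example
table = {ord(space): "_" for space in spaces}
assert "my\u00A0node name".translate(table) == "my_node_name"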
lmnr/cli/zip.py DELETED
@@ -1,16 +0,0 @@
- import os
- from pathlib import Path
- import zipfile
-
-
- def zip_directory(directory_path: Path, zip_file_path: Path):
-     with zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_DEFLATED) as zipf:
-         for root, _, files in os.walk(directory_path):
-             for file in files:
-                 # Don't include the zip file itself, otherwise goes to infinite loop
-                 if file == zip_file_path.name:
-                     continue
-
-                 file_path = os.path.join(root, file)
-                 arcname = os.path.relpath(file_path, directory_path)
-                 zipf.write(file_path, arcname)
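
Hypothetical usage of the removed helper, showing why the self-exclusion check existed: the archive may be written into the very directory being zipped. The paths are illustrative:

from pathlib import Path

build_dir = Path("build")
archive = build_dir / "pipeline.zip"   # lives inside build_dir, hence the name check
# Against 0.2.14: zip_directory(build_dir, archive)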
lmnr/sdk/endpoint.py DELETED
@@ -1,186 +0,0 @@
- import json
- from pydantic.alias_generators import to_snake
- import pydantic
- import requests
- from lmnr.types import (
-     EndpointRunError, EndpointRunResponse, NodeInput, EndpointRunRequest,
-     ToolCallError, ToolCallRequest, ToolCallResponse, SDKError
- )
- from typing import Callable, Optional
- from websockets.sync.client import connect
-
- class Laminar:
-     project_api_key: Optional[str] = None
-     def __init__(self, project_api_key: str):
-         """Initialize the Laminar object with your project API key
-
-         Args:
-             project_api_key (str):
-                 Project api key. Generate or view your keys
-                 in the project settings in the Laminar dashboard.
-         """
-         self.project_api_key = project_api_key
-         self.url = 'https://api.lmnr.ai/v2/endpoint/run'
-         self.ws_url = 'wss://api.lmnr.ai/v2/endpoint/ws'
-
-     def run (
-         self,
-         endpoint: str,
-         inputs: dict[str, NodeInput],
-         env: dict[str, str] = {},
-         metadata: dict[str, str] = {},
-         tools: list[Callable[..., NodeInput]] = [],
-     ) -> EndpointRunResponse:
-         """Runs the endpoint with the given inputs
-
-         Args:
-             endpoint (str): name of the Laminar endpoint
-             inputs (dict[str, NodeInput]):
-                 inputs to the endpoint's target pipeline.
-                 Keys in the dictionary must match input node names
-             env (dict[str, str], optional):
-                 Environment variables for the pipeline execution.
-                 Defaults to {}.
-             metadata (dict[str, str], optional):
-                 any custom metadata to be stored
-                 with execution trace. Defaults to {}.
-             tools (list[Callable[..., NodeInput]], optional):
-                 List of callable functions the execution can call as tools.
-                 If specified and non-empty, a bidirectional communication
-                 with Laminar API through websocket will be established.
-                 Defaults to [].
-
-         Returns:
-             EndpointRunResponse: response object containing the outputs
-
-         Raises:
-             ValueError: if project API key is not set
-             EndpointRunError: if the endpoint run fails
-             SDKError: if an error occurs on client side during the execution
-         """
-         if self.project_api_key is None:
-             raise ValueError(
-                 'Please initialize the Laminar object with'
-                 ' your project API key'
-             )
-         if tools:
-             return self._run_websocket(endpoint, inputs, env, metadata, tools)
-         return self._run(endpoint, inputs, env, metadata)
-
-     def _run(
-         self,
-         endpoint: str,
-         inputs: dict[str, NodeInput],
-         env: dict[str, str] = {},
-         metadata: dict[str, str] = {}
-     ) -> EndpointRunResponse:
-         try:
-             request = EndpointRunRequest(
-                 inputs = inputs,
-                 endpoint = endpoint,
-                 env = env,
-                 metadata = metadata
-             )
-         except Exception as e:
-             raise ValueError(f'Invalid request: {e}')
-         response = requests.post(
-             self.url,
-             json=json.loads(request.model_dump_json()),
-             headers={'Authorization': f'Bearer {self.project_api_key}'}
-         )
-         if response.status_code != 200:
-             raise EndpointRunError(response)
-         try:
-             resp_json = response.json()
-             keys = list(resp_json.keys())
-             for key in keys:
-                 value = resp_json[key]
-                 del resp_json[key]
-                 resp_json[to_snake(key)] = value
-             return EndpointRunResponse(**resp_json)
-         except:
-             raise EndpointRunError(response)
-
-     def _run_websocket(
-         self,
-         endpoint: str,
-         inputs: dict[str, NodeInput],
-         env: dict[str, str] = {},
-         metadata: dict[str, str] = {},
-         tools: list[Callable[..., NodeInput]] = [],
-     ) -> EndpointRunResponse:
-         try:
-             request = EndpointRunRequest(
-                 inputs = inputs,
-                 endpoint = endpoint,
-                 env = env,
-                 metadata = metadata
-             )
-         except Exception as e:
-             raise ValueError(f'Invalid request: {e}')
-
-         with connect(
-             self.ws_url,
-             additional_headers={
-                 'Authorization': f'Bearer {self.project_api_key}'
-             }
-         ) as websocket:
-             websocket.send(request.model_dump_json())
-             req_id = None
-
-             while True:
-                 message = websocket.recv()
-                 try:
-                     tool_call = ToolCallRequest.model_validate_json(message)
-                     req_id = tool_call.req_id
-                     matching_tools = [
-                         tool for tool in tools
-                         if tool.__name__ == tool_call.toolCall.function.name
-                     ]
-                     if not matching_tools:
-                         raise SDKError(
-                             f'Tool {tool_call.toolCall.function.name} not found.'
-                             ' Registered tools: '
-                             f'{", ".join([tool.__name__ for tool in tools])}'
-                         )
-                     tool = matching_tools[0]
-                     # default the arguments to an empty dictionary
-                     if tool.__name__ == tool_call.toolCall.function.name:
-                         arguments = {}
-                         try:
-                             arguments = json.loads(tool_call.toolCall.function.arguments)
-                         except:
-                             pass
-                         try:
-                             response = tool(**arguments)
-                         except Exception as e:
-                             error_message = 'Error occurred while running tool' +\
-                                 f'{tool.__name__}: {e}'
-                             e = ToolCallError(error=error_message, reqId=req_id)
-                             websocket.send(e.model_dump_json())
-                         formatted_response = None
-                         try:
-                             formatted_response = ToolCallResponse(
-                                 reqId=tool_call.reqId,
-                                 response=response
-                             )
-                         except pydantic.ValidationError as e:
-                             formatted_response = ToolCallResponse(
-                                 reqId=tool_call.reqId,
-                                 response=str(response)
-                             )
-                         websocket.send(
-                             formatted_response.model_dump_json()
-                         )
-                 except pydantic.ValidationError as e:
-                     message_json = json.loads(message)
-                     keys = list(message_json.keys())
-                     for key in keys:
-                         value = message_json[key]
-                         del message_json[key]
-                         message_json[to_snake(key)] = value
-                     result = EndpointRunResponse.model_validate(message_json)
-                     websocket.close()
-                     return result
-                 except Exception as e:
-                     raise SDKError(f'Error communicating to backend through websocket {e}')
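
For context, a hedged usage sketch of this removed 0.2.x API, based only on the signature and docstring above. The import path is assumed, and the endpoint name, inputs, and tool are made up; a non-empty tools list is what routed run() to the websocket path:

from lmnr import Laminar   # assumed 0.2.x import path

def get_weather(city: str) -> str:
    # A tool the pipeline can call back into over the websocket.
    return f"Sunny in {city}"

laminar = Laminar(project_api_key="<YOUR_PROJECT_API_KEY>")
result = laminar.run(
    endpoint="my_endpoint",
    inputs={"user_query": "What's the weather in Paris?"},
    env={"OPENAI_API_KEY": "<openai-key>"},
    tools=[get_weather],   # non-empty tools -> _run_websocket
)
print(result)              # an EndpointRunResponse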
lmnr/sdk/registry.py DELETED
@@ -1,29 +0,0 @@
- from typing import Callable
-
- from lmnr.types import NodeFunction, NodeInput
-
-
- class Registry:
-     """
-     Class to register and resolve node functions based on their node names.
-
-     Node names cannot have space in their name.
-     """
-
-     functions: dict[str, NodeFunction]
-
-     def __init__(self):
-         self.functions = {}
-
-     def add(self, node_name: str, function: Callable[..., NodeInput]):
-         self.functions[node_name] = NodeFunction(node_name, function)
-
-     def func(self, node_name: str):
-         def decorator(f: Callable[..., NodeInput]):
-             self.add(node_name, f)
-             return f
-
-         return decorator
-
-     def get(self, node_name: str) -> Callable[..., NodeInput]:
-         return self.functions[node_name].function
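
A short usage sketch of the removed Registry, valid against 0.2.14 only; the node name and function are made up:

from lmnr.sdk.registry import Registry   # module as it existed in 0.2.14

registry = Registry()

@registry.func("summarize")              # register under a node name
def summarize(text: str) -> str:
    return text[:100]

fn = registry.get("summarize")           # resolve back to the callable
assert fn("hello world") == "hello world"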