lmnr 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lmnr might be problematic; see the registry's advisory for this release for more details.
- lmnr/__init__.py +5 -1
- lmnr/cli/__main__.py +4 -0
- lmnr/cli/cli.py +97 -0
- lmnr/cli/cookiecutter.json +9 -0
- lmnr/cli/parser/__init__.py +0 -0
- lmnr/cli/parser/nodes/__init__.py +50 -0
- lmnr/cli/parser/nodes/types.py +156 -0
- lmnr/cli/parser/parser.py +58 -0
- lmnr/cli/parser/utils.py +25 -0
- lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/__init__.py +0 -0
- lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/engine/__init__.py +1 -0
- lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/engine/action.py +14 -0
- lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/engine/engine.py +261 -0
- lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/engine/state.py +69 -0
- lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/engine/task.py +38 -0
- lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/pipelines/{{cookiecutter.pipeline_dir_name}}/__init__.py +1 -0
- lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/pipelines/{{cookiecutter.pipeline_dir_name}}/nodes/functions.py +149 -0
- lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/pipelines/{{cookiecutter.pipeline_dir_name}}/{{cookiecutter.pipeline_dir_name}}.py +87 -0
- lmnr/cli/{{cookiecutter.lmnr_pipelines_dir_name}}/types.py +50 -0
- lmnr/sdk/endpoint.py +166 -0
- lmnr/types.py +93 -0
- lmnr-0.1.3.dist-info/LICENSE +72 -0
- lmnr-0.1.3.dist-info/METADATA +78 -0
- lmnr-0.1.3.dist-info/RECORD +26 -0
- lmnr-0.1.3.dist-info/entry_points.txt +3 -0
- lmnr/endpoint.py +0 -43
- lmnr/model.py +0 -39
- lmnr-0.1.1.dist-info/LICENSE +0 -7
- lmnr-0.1.1.dist-info/METADATA +0 -37
- lmnr-0.1.1.dist-info/RECORD +0 -7
- {lmnr-0.1.1.dist-info → lmnr-0.1.3.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
import threading
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Union
|
|
4
|
+
|
|
5
|
+
from lmnr_engine.types import Message
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class State:
    """Outcome of a node execution as seen by the engine.

    `status` is one of "Success", "Empty" or "Termination". The termination
    state is built without a message; the other constructors attach one.
    """

    status: str  # "Success", "Empty", "Termination" # TODO: Turn into Enum
    message: Union[Message, None]

    @classmethod
    def new(cls, val: Message) -> "State":
        """Build a successful state carrying `val`."""
        return cls("Success", val)

    @classmethod
    def empty(cls) -> "State":
        """Build a placeholder state carrying an empty message."""
        return cls("Empty", Message.empty())

    @classmethod
    def termination(cls) -> "State":
        """Build the termination state, which carries no message."""
        return cls("Termination", None)

    def is_success(self) -> bool:
        """True when the status is "Success"."""
        success = self.status == "Success"
        return success

    def is_termination(self) -> bool:
        """True when the status is "Termination"."""
        terminated = self.status == "Termination"
        return terminated

    def get_out(self) -> Message:
        """Return the carried message.

        Raises:
            ValueError: if no message is carried (the termination state).
        """
        carried = self.message
        if carried is not None:
            return carried
        raise ValueError("Cannot get message from a termination state")
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclass
class ExecState:
    """One input slot of a task: the latest state plus a semaphore.

    The producer writes the state and releases permits via
    `set_state_and_permits`; per the notes below, readers are expected to
    acquire the semaphore before calling `get_state`.
    """

    output: State
    semaphore: threading.Semaphore

    @classmethod
    def new(cls) -> "ExecState":
        """Create a slot holding an empty state and zero available permits."""
        return cls(
            output=State.empty(),
            semaphore=threading.Semaphore(0),
        )

    # Assume this is called by the caller who doesn't need to acquire semaphore
    def set_state(self, output: State):
        """Store `output` without releasing any permits."""
        self.output = output

    # Assume the caller is smart to call this after acquiring the semaphore
    def get_state(self) -> State:
        """Return the stored state."""
        return self.output

    def set_state_and_permits(self, output: State, permits: int):
        """Store `output`, then release `permits` semaphore permits.

        The state is written before the release so that a thread woken by the
        semaphore reads the new value.

        Args:
            output: state to store.
            permits: number of permits to release. Zero releases nothing;
                negative values raise ValueError (from Semaphore.release).
        """
        self.output = output
        # threading.Semaphore.release(n) raises ValueError for n < 1, so a
        # zero count (nobody to wake) must not reach it.
        if permits != 0:
            self.semaphore.release(permits)
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from typing import Callable, Union
|
|
2
|
+
|
|
3
|
+
from .action import RunOutput
|
|
4
|
+
from .state import ExecState
|
|
5
|
+
from lmnr_engine.types import NodeInput
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Task:
    """A pipeline node as scheduled by the engine."""

    # unique identifier
    name: str
    # mapping from current node's handle name to previous node's unique name
    # assumes nodes have only one output
    handles_mapping: list[tuple[str, str]]
    # Value or a function that returns a value
    # Usually a function which waits for inputs from previous nodes
    value: Union[NodeInput, Callable[..., RunOutput]]  # TODO: Type this fully
    # unique node names of previous nodes
    prev: list[str]
    # unique node names of next nodes
    next: list[str]
    # one ExecState slot per input handle, keyed by handle name
    input_states: dict[str, ExecState]

    def __init__(
        self,
        name: str,
        handles_mapping: list[tuple[str, str]],
        value: Union[NodeInput, Callable[..., RunOutput]],
        prev: list[str],
        next: list[str],
    ) -> None:
        """Store the node description and create one input slot per handle.

        Args:
            name: unique node name.
            handles_mapping: (handle name, previous node name) pairs.
            value: a constant value or a callable producing the output.
            prev: unique names of the previous nodes.
            next: unique names of the next nodes.
        """
        self.name = name
        self.value = value
        self.handles_mapping = handles_mapping
        self.prev = prev
        self.next = next

        slots = {}
        for handle, _source_node in self.handles_mapping:
            slots[handle] = ExecState.new()
        self.input_states = slots
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .{{ cookiecutter.pipeline_dir_name }} import {{ cookiecutter.class_name }}
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
import requests
|
|
2
|
+
import json
|
|
3
|
+
|
|
4
|
+
from lmnr_engine.engine.action import NodeRunError, RunOutput
|
|
5
|
+
from lmnr_engine.types import ChatMessage, NodeInput
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
{% for task in cookiecutter._tasks.values() %}
{% if task.node_type == "LLM" %}
def {{task.function_name}}({{ task.handle_args }}, _env: dict[str, str]) -> RunOutput:
    {# Jinja2 scoping: a set inside a for loop is not visible after the loop,
       so a flag set in a loop always read as false and the chat_messages
       input was silently dropped. Test membership directly instead. #}
    {% if "chat_messages" in task.input_handle_names %}
    input_chat_messages = chat_messages
    {% else %}
    input_chat_messages = []
    {% endif %}

    rendered_prompt = """{{task.config.prompt}}"""
    {% set prompt_variables = task.input_handle_names|reject("equalto", "chat_messages") %}
    {% for prompt_variable in prompt_variables %}
    # TODO: Fix this. Using double curly braces in quotes because normal double curly braces
    # get replaced during rendering by Cookiecutter. This is a hacky solution.
    rendered_prompt = rendered_prompt.replace("{{'{{'}}{{prompt_variable}}{{'}}'}}", {{prompt_variable}}) # type: ignore
    {% endfor %}

    {% if task.config.model_params == none %}
    params = {}
    {% else %}
    params = json.loads(
        """{{task.config.model_params}}"""
    )
    {% endif %}

    messages = [ChatMessage(role="system", content=rendered_prompt)]
    messages.extend(input_chat_messages)

    {% if task.config.provider == "openai" %}
    message_jsons = [
        {"role": message.role, "content": message.content} for message in messages
    ]

    data = {
        "model": "{{task.config.model}}",
        "messages": message_jsons,
    }
    data.update(params)

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {_env['OPENAI_API_KEY']}",
    }
    res = requests.post(
        "https://api.openai.com/v1/chat/completions", json=data, headers=headers
    )

    if res.status_code != 200:
        # Error bodies are normally {"error": {"message": ...}}; fall back to
        # the raw text so a non-JSON body does not hide the real failure.
        try:
            error_message = res.json()["error"]["message"]
        except (ValueError, KeyError, TypeError):
            error_message = res.text
        raise NodeRunError(f"OpenAI completions request failed: {error_message}")

    chat_completion = res.json()

    completion_message = chat_completion["choices"][0]["message"]["content"]

    meta_log = {}
    meta_log["node_chunk_id"] = None  # TODO: Add node chunk id
    meta_log["model"] = "{{task.config.model}}"
    meta_log["prompt"] = rendered_prompt
    meta_log["input_message_count"] = len(messages)
    meta_log["input_token_count"] = chat_completion["usage"]["prompt_tokens"]
    meta_log["output_token_count"] = chat_completion["usage"]["completion_tokens"]
    meta_log["total_token_count"] = (
        chat_completion["usage"]["prompt_tokens"] + chat_completion["usage"]["completion_tokens"]
    )
    meta_log["approximate_cost"] = None  # TODO: Add approximate cost
    {% elif task.config.provider == "anthropic" %}
    data = {
        "model": "{{task.config.model}}",
        "max_tokens": 4096,
    }
    data.update(params)

    # TODO: Generate appropriate code based on this if-else block
    if len(messages) == 1 and messages[0].role == "system":
        messages[0].role = "user"
        message_jsons = [
            {"role": message.role, "content": message.content} for message in messages
        ]
        data["messages"] = message_jsons
    else:
        data["system"] = messages[0].content
        message_jsons = [
            {"role": message.role, "content": message.content} for message in messages[1:]
        ]
        data["messages"] = message_jsons

    headers = {
        "Content-Type": "application/json",
        "X-Api-Key": _env['ANTHROPIC_API_KEY'],
        "Anthropic-Version": "2023-06-01",
    }
    res = requests.post(
        "https://api.anthropic.com/v1/messages", json=data, headers=headers
    )

    if res.status_code != 200:
        raise NodeRunError(f"Anthropic message request failed: {res.text}")

    chat_completion = res.json()

    completion_message = chat_completion["content"][0]["text"]

    meta_log = {}
    meta_log["node_chunk_id"] = None  # TODO: Add node chunk id
    meta_log["model"] = "{{task.config.model}}"
    meta_log["prompt"] = rendered_prompt
    meta_log["input_message_count"] = len(messages)
    meta_log["input_token_count"] = chat_completion["usage"]["input_tokens"]
    meta_log["output_token_count"] = chat_completion["usage"]["output_tokens"]
    meta_log["total_token_count"] = (
        chat_completion["usage"]["input_tokens"] + chat_completion["usage"]["output_tokens"]
    )
    meta_log["approximate_cost"] = None  # TODO: Add approximate cost
    {% else %}
    {# Previously this branch was empty, so the return below failed with a
       NameError on completion_message. Fail with a clear node error instead. #}
    raise NodeRunError("Unsupported LLM provider: {{task.config.provider}}")
    {% endif %}

    return RunOutput(status="Success", output=completion_message)


{% elif task.node_type == "Output" %}
def {{task.function_name}}(output: NodeInput, _env: dict[str, str]) -> RunOutput:
    return RunOutput(status="Success", output=output)


{% elif task.node_type == "Input" %}
{# Do nothing for Input tasks #}
{% else %}
def {{task.function_name}}(output: NodeInput, _env: dict[str, str]) -> RunOutput:
    return RunOutput(status="Success", output=output)


{% endif %}
{% endfor %}
|
|
149
|
+
# Other functions can be added here
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Optional, Union
|
|
5
|
+
|
|
6
|
+
from lmnr_engine.types import ChatMessageList, ChatMessage
|
|
7
|
+
from lmnr_engine.engine import Engine
|
|
8
|
+
{% set function_names = cookiecutter._tasks.values() | selectattr('node_type', '!=', 'Input') | map(attribute='function_name') | join(', ') %}
|
|
9
|
+
from .nodes.functions import {{ function_names }}
|
|
10
|
+
from lmnr_engine.engine.task import Task
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)


class PipelineRunnerError(Exception):
    """Error raised when running the generated pipeline fails."""
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class PipelineRunOutput:
    """Value returned for one named output of a pipeline run."""

    # plain text, or a list of chat messages
    value: Union[str, ChatMessageList]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# This class is not imported in other files and can be renamed to desired name
class {{ cookiecutter.class_name }}:
    thread_pool_executor: ThreadPoolExecutor

    def __init__(
        self, thread_pool_executor: Optional[ThreadPoolExecutor] = None
    ) -> None:
        # Set max workers to hard-coded value for now
        self.thread_pool_executor = (
            ThreadPoolExecutor(max_workers=10)
            if thread_pool_executor is None
            else thread_pool_executor
        )

    def run(
        self,
        inputs: dict[str, Union[str, list]],
        env: Optional[dict[str, str]] = None,
    ) -> dict[str, PipelineRunOutput]:
        """
        Run the pipeline with the given graph

        Args:
            inputs: pipeline inputs keyed by input name. Each value is either
                a string or a list of dicts with "role" and "content" keys.
            env: environment variables handed to the engine. Defaults to a
                fresh empty dict on every call.

        Raises:
            PipelineRunnerError: if there is an error running the pipeline
                (including an input value that is neither a str nor a list)
        """
        logger.info("Running pipeline {{ cookiecutter.pipeline_name }}, pipeline_version: {{ cookiecutter.pipeline_version_name }}")

        # A shared `{}` default would be reused across calls; build a new one.
        env = {} if env is None else env

        run_inputs = {}
        for inp_name, inp in inputs.items():
            if isinstance(inp, str):
                run_inputs[inp_name] = inp
            elif isinstance(inp, list):
                run_inputs[inp_name] = [ChatMessage.from_dict(msg) for msg in inp]
            else:
                # Explicit raise instead of `assert`, which is stripped under -O.
                raise PipelineRunnerError(f"Invalid input type: {type(inp)}")

        tasks = []
        {% for task in cookiecutter._tasks.values() %}
        tasks.append(
            Task(
                name="{{ task.name }}",
                value={{ "''" if task.node_type == "Input" else task.function_name }},
                handles_mapping={{ task.handles_mapping }},
                prev=[
                    {% for prev in task.prev %}
                    "{{ prev }}",
                    {% endfor %}
                ],
                next=[
                    {% for next in task.next %}
                    "{{ next }}",
                    {% endfor %}
                ],
            )
        )
        {% endfor %}
        engine = Engine.with_tasks(tasks, self.thread_pool_executor, env=env)

        # TODO: Raise PipelineRunnerError with node_errors
        run_res = engine.run(run_inputs)
        return {
            name: PipelineRunOutput(value=output.value)
            for name, output in run_res.items()
        }
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Union
|
|
3
|
+
import uuid
|
|
4
|
+
import datetime
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass
class ChatMessage:
    """A single chat message exchanged between pipeline nodes."""

    role: str
    content: str

    @classmethod
    def from_dict(cls, data: dict) -> "ChatMessage":
        """Build a ChatMessage from a dict with "role" and "content" keys.

        Other keys are ignored; a missing key raises KeyError.
        """
        return cls(
            role=data["role"],
            content=data["content"],
        )


# Generated pipelines do `from lmnr_engine.types import ChatMessageList`,
# so this alias must exist here or the pipeline module fails to import.
ChatMessageList = list[ChatMessage]

NodeInput = Union[str, ChatMessageList]  # TODO: Add conditioned value


@dataclass
class Message:
    """A value produced by a node, with its execution timestamps."""

    id: uuid.UUID
    # output value of producing node in form of NodeInput
    # for the following consumer
    value: NodeInput
    start_time: datetime.datetime
    end_time: datetime.datetime
    # TODO: planned (currently disabled) fields: input_messages (all input
    # messages to this node, accumulating previous ones), node_id, node_name,
    # node_type, and meta_log (per-run metadata logged at end of execution).

    @classmethod
    def empty(cls) -> "Message":
        """Build a message with a fresh id, an empty value and equal timestamps."""
        # One timestamp so start_time and end_time of an empty message agree.
        now = datetime.datetime.now()
        return cls(
            id=uuid.uuid4(),
            value="",
            start_time=now,
            end_time=now,
        )
|
lmnr/sdk/endpoint.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from pydantic.alias_generators import to_snake
|
|
3
|
+
import pydantic
|
|
4
|
+
import requests
|
|
5
|
+
from ..types import (
|
|
6
|
+
EndpointRunError, EndpointRunResponse, NodeInput, EndpointRunRequest,
|
|
7
|
+
ToolCall, SDKError
|
|
8
|
+
)
|
|
9
|
+
from typing import Callable, Optional
|
|
10
|
+
from websockets.sync.client import connect
|
|
11
|
+
|
|
12
|
+
class Laminar:
    """Client for running Laminar endpoints over HTTP or websocket."""

    project_api_key: Optional[str] = None

    def __init__(self, project_api_key: str):
        """Initialize the Laminar object with your project API key

        Args:
            project_api_key (str):
                Project api key. Generate or view your keys
                in the project settings in the Laminar dashboard.
        """
        self.project_api_key = project_api_key
        self.url = 'https://api.lmnr.ai/v2/endpoint/run'
        # TLS websocket: the project API key is sent in the Authorization
        # header of the handshake and must not travel in cleartext.
        self.ws_url = 'wss://api.lmnr.ai/v2/endpoint/ws'

    def run(
        self,
        endpoint: str,
        inputs: dict[str, NodeInput],
        env: Optional[dict[str, str]] = None,
        metadata: Optional[dict[str, str]] = None,
        tools: Optional[list[Callable[..., NodeInput]]] = None,
    ) -> EndpointRunResponse:
        """Runs the endpoint with the given inputs

        Args:
            endpoint (str): name of the Laminar endpoint
            inputs (dict[str, NodeInput]):
                inputs to the endpoint's target pipeline.
                Keys in the dictionary must match input node names
            env (dict[str, str], optional):
                Environment variables for the pipeline execution.
                Defaults to {}.
            metadata (dict[str, str], optional):
                any custom metadata to be stored
                with execution trace. Defaults to {}.
            tools (list[Callable[..., NodeInput]], optional):
                List of callable functions the execution can call as tools.
                If specified and non-empty, a bidirectional communication
                with Laminar API through websocket will be established.
                Defaults to [].

        Returns:
            EndpointRunResponse: response object containing the outputs

        Raises:
            ValueError: if project API key is not set
            EndpointRunError: if the endpoint run fails
            SDKError: if an error occurs on client side during the execution
        """
        if self.project_api_key is None:
            raise ValueError(
                'Please initialize the Laminar object with'
                ' your project API key'
            )
        if tools:
            return self._run_websocket(endpoint, inputs, env, metadata, tools)
        return self._run(endpoint, inputs, env, metadata)

    def _run(
        self,
        endpoint: str,
        inputs: dict[str, NodeInput],
        env: Optional[dict[str, str]] = None,
        metadata: Optional[dict[str, str]] = None,
    ) -> EndpointRunResponse:
        """Run the endpoint with a single HTTP POST (no tools).

        Raises:
            ValueError: if the request arguments fail validation.
            EndpointRunError: on a non-200 status or an unparsable body.
        """
        request = self._build_request(endpoint, inputs, env, metadata)
        response = requests.post(
            self.url,
            json=json.loads(request.model_dump_json()),
            headers={'Authorization': f'Bearer {self.project_api_key}'},
        )
        if response.status_code != 200:
            raise EndpointRunError(response)
        try:
            return EndpointRunResponse(**self._snake_case_keys(response.json()))
        except Exception as e:
            raise EndpointRunError(response) from e

    def _run_websocket(
        self,
        endpoint: str,
        inputs: dict[str, NodeInput],
        env: Optional[dict[str, str]] = None,
        metadata: Optional[dict[str, str]] = None,
        tools: Optional[list[Callable[..., NodeInput]]] = None,
    ) -> EndpointRunResponse:
        """Run the endpoint over a websocket, answering tool calls locally.

        Every received message that parses as a ToolCall is executed with the
        matching function from `tools` and its JSON-encoded result is sent
        back; the first message that is not a tool call is parsed as the final
        EndpointRunResponse. The connection is closed when the `with` block
        exits.

        Raises:
            ValueError: if the request arguments fail validation.
            SDKError: if a tool is missing or fails, the result cannot be sent,
                or the final message is not a valid response.
        """
        request = self._build_request(endpoint, inputs, env, metadata)
        registered_tools = tools or []

        with connect(
            self.ws_url,
            additional_headers={
                'Authorization': f'Bearer {self.project_api_key}'
            }
        ) as websocket:
            websocket.send(request.model_dump_json())

            while True:
                message = websocket.recv()
                try:
                    tool_call = ToolCall.model_validate_json(message)
                except pydantic.ValidationError:
                    # Anything that is not a tool call is the final result.
                    return self._parse_ws_result(message)

                tool_output = self._call_tool(tool_call, registered_tools)
                try:
                    websocket.send(json.dumps(tool_output))
                except Exception as e:
                    raise SDKError(
                        'Error communicating to backend through websocket'
                    ) from e

    @staticmethod
    def _build_request(
        endpoint: str,
        inputs: dict[str, NodeInput],
        env: Optional[dict[str, str]],
        metadata: Optional[dict[str, str]],
    ) -> EndpointRunRequest:
        """Validate the arguments into a request model; None means empty dict.

        Raises:
            ValueError: if pydantic validation of the request fails.
        """
        try:
            return EndpointRunRequest(
                inputs=inputs,
                endpoint=endpoint,
                env=env if env is not None else {},
                metadata=metadata if metadata is not None else {},
            )
        except pydantic.ValidationError as e:
            raise ValueError(f'Invalid request: {e}') from e

    @staticmethod
    def _snake_case_keys(data: dict) -> dict:
        """Return a copy of `data` with top-level keys converted to snake_case."""
        return {to_snake(key): value for key, value in data.items()}

    @classmethod
    def _parse_ws_result(cls, message) -> EndpointRunResponse:
        """Parse the final websocket message into an EndpointRunResponse."""
        try:
            return EndpointRunResponse.model_validate(
                cls._snake_case_keys(json.loads(message))
            )
        except Exception as e:
            raise SDKError(
                f'Unexpected message from backend through websocket: {message}'
            ) from e

    @staticmethod
    def _call_tool(
        tool_call: ToolCall,
        tools: list[Callable[..., NodeInput]],
    ) -> NodeInput:
        """Invoke the registered tool named in `tool_call` and return its result.

        Raises:
            SDKError: if no tool matches, the arguments are not a JSON object,
                or the tool itself raises.
        """
        name = tool_call.function.name
        tool = next((t for t in tools if t.__name__ == name), None)
        if tool is None:
            raise SDKError(
                f'Tool {name} not found.'
                ' Registered tools: '
                f'{", ".join([t.__name__ for t in tools])}'
            )

        try:
            arguments = json.loads(tool_call.function.arguments)
        except json.JSONDecodeError:
            # Best effort, as before: unparsable/empty arguments mean "no arguments".
            arguments = {}
        if not isinstance(arguments, dict):
            raise SDKError(f'Arguments for tool {name} must be a JSON object')

        try:
            return tool(**arguments)
        except Exception as e:
            raise SDKError(f'Tool {name} raised an error: {e}') from e
|
lmnr/types.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
import requests
|
|
2
|
+
import pydantic
|
|
3
|
+
from typing import Any, Union, Optional
|
|
4
|
+
|
|
5
|
+
class ChatMessage(pydantic.BaseModel):
    """One message of a chat conversation.

    Attributes:
        role (str):
            Who sent the message: 'user', 'assistant', or 'system'.
        content (str): Text of the message.
    """

    # sender role: 'user', 'assistant', or 'system'
    role: str
    # message text
    content: str
|
|
16
|
+
|
|
17
|
+
"""NodeInput is a common type for values shared
|
|
18
|
+
between nodes in Laminar pipelines."""
|
|
19
|
+
NodeInput = Union[str, list[ChatMessage]] # TypeAlias
|
|
20
|
+
|
|
21
|
+
class EndpointRunRequest(pydantic.BaseModel):
    """Payload sent to the Laminar API to run an endpoint.

    Attributes:
        inputs: values for the target pipeline's input nodes, keyed by name.
        endpoint: name of the Laminar endpoint.
        env: environment variables for the execution; empty by default.
        metadata: custom metadata stored with the execution trace; empty by
            default.
    """

    inputs: dict[str, NodeInput]
    endpoint: str
    # default_factory gives each request its own dict instance
    env: dict[str, str] = pydantic.Field(default_factory=dict)
    metadata: dict[str, str] = pydantic.Field(default_factory=dict)
|
|
26
|
+
|
|
27
|
+
class EndpointRunResponse(pydantic.BaseModel):
    """Result of an endpoint run.

    Attributes:
        outputs (dict[str, dict[str, NodeInput]]):
            Output node name -> a dictionary holding that output; the
            produced value is stored under the 'value' key.
        run_id (str): Stringified UUID of the run. Useful to find traces.
    """

    outputs: dict[str, dict[str, NodeInput]]
    run_id: str

    def get_output(self, output_name: str) -> Optional[NodeInput]:
        """Return the value of the named output node.

        Args:
            output_name (str): must match the output node name

        Returns:
            Optional[NodeInput]: the output's 'value', or None when the
            output is absent or empty
        """
        entry = self.outputs.get(output_name)
        if not entry:
            return None
        return entry.get('value')
|
|
51
|
+
|
|
52
|
+
class EndpointRunError(Exception):
    """Raised when an endpoint run request fails.

    If the response body is a JSON object with 'error_code' and
    'error_message', those become attributes and the message is the
    exception text; otherwise the raw response text is used and the
    attributes may be unset.
    """

    error_code: str
    error_message: str

    def __init__(self, response: requests.Response):
        try:
            resp_json = response.json()
            self.error_code = resp_json['error_code']
            self.error_message = resp_json['error_message']
        except (ValueError, KeyError, TypeError):
            # ValueError: body is not JSON (requests' JSONDecodeError is a
            # ValueError); KeyError/TypeError: JSON of an unexpected shape.
            super().__init__(response.text)
        else:
            super().__init__(self.error_message)

    def __str__(self) -> str:
        try:
            return str({
                'error_code': self.error_code,
                'error_message': self.error_message
            })
        except AttributeError:
            # The attributes are missing when the body could not be parsed.
            return super().__str__()
|
|
73
|
+
|
|
74
|
+
class SDKError(Exception):
    """Error raised on the client side while executing a Laminar call.

    Attributes:
        error_message (str): human-readable description; also the exception text.
    """

    error_message: str

    def __init__(self, error_message: str):
        super().__init__(error_message)
        self.error_message = error_message

    def __str__(self) -> str:
        # Same text as Exception: the message passed to the constructor.
        return super().__str__()
|
|
83
|
+
|
|
84
|
+
class ToolFunctionCall(pydantic.BaseModel):
    """Function part of a tool call requested by the backend.

    Attributes:
        name: name of the tool function to invoke.
        arguments: the call's arguments as a JSON-encoded string.
    """

    name: str
    arguments: str
|
|
87
|
+
|
|
88
|
+
class ToolCall(pydantic.BaseModel):
    """Tool call message received from the backend.

    Attributes:
        id: identifier of the call.
        type: kind of call, as sent by the backend.
        function: the function to invoke and its arguments.
    """

    id: str
    type: str
    function: ToolFunctionCall


# A tool's return value has the same shape as a node input.
ToolResponse = NodeInput  # TypeAlias
|