chainlit 0.7.604rc2__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chainlit might be problematic. Click here for more details.
- chainlit/__init__.py +32 -23
- chainlit/auth.py +9 -10
- chainlit/cache.py +3 -3
- chainlit/cli/__init__.py +12 -2
- chainlit/config.py +22 -13
- chainlit/context.py +7 -3
- chainlit/data/__init__.py +375 -9
- chainlit/data/acl.py +6 -5
- chainlit/element.py +86 -123
- chainlit/emitter.py +117 -50
- chainlit/frontend/dist/assets/index-6aee009a.js +697 -0
- chainlit/frontend/dist/assets/{react-plotly-16f7de12.js → react-plotly-2f07c02a.js} +1 -1
- chainlit/frontend/dist/index.html +1 -1
- chainlit/haystack/callbacks.py +45 -43
- chainlit/hello.py +1 -1
- chainlit/langchain/callbacks.py +135 -120
- chainlit/llama_index/callbacks.py +68 -48
- chainlit/message.py +179 -207
- chainlit/oauth_providers.py +39 -34
- chainlit/playground/provider.py +44 -30
- chainlit/playground/providers/anthropic.py +4 -4
- chainlit/playground/providers/huggingface.py +2 -2
- chainlit/playground/providers/langchain.py +8 -10
- chainlit/playground/providers/openai.py +19 -13
- chainlit/server.py +155 -99
- chainlit/session.py +109 -40
- chainlit/socket.py +54 -38
- chainlit/step.py +393 -0
- chainlit/types.py +78 -21
- chainlit/user.py +32 -0
- chainlit/user_session.py +1 -5
- {chainlit-0.7.604rc2.dist-info → chainlit-1.0.0rc0.dist-info}/METADATA +12 -31
- chainlit-1.0.0rc0.dist-info/RECORD +60 -0
- chainlit/client/base.py +0 -169
- chainlit/client/cloud.py +0 -500
- chainlit/frontend/dist/assets/index-c58dbd4b.js +0 -871
- chainlit/prompt.py +0 -40
- chainlit-0.7.604rc2.dist-info/RECORD +0 -61
- {chainlit-0.7.604rc2.dist-info → chainlit-1.0.0rc0.dist-info}/WHEEL +0 -0
- {chainlit-0.7.604rc2.dist-info → chainlit-1.0.0rc0.dist-info}/entry_points.txt +0 -0
chainlit/socket.py
CHANGED
|
@@ -6,7 +6,7 @@ from chainlit.action import Action
|
|
|
6
6
|
from chainlit.auth import get_current_user, require_login
|
|
7
7
|
from chainlit.config import config
|
|
8
8
|
from chainlit.context import init_ws_context
|
|
9
|
-
from chainlit.data import
|
|
9
|
+
from chainlit.data import get_data_layer
|
|
10
10
|
from chainlit.logger import logger
|
|
11
11
|
from chainlit.message import ErrorMessage, Message
|
|
12
12
|
from chainlit.server import socket
|
|
@@ -27,39 +27,33 @@ def restore_existing_session(sid, session_id, emit_fn, ask_user_fn):
|
|
|
27
27
|
return False
|
|
28
28
|
|
|
29
29
|
|
|
30
|
-
async def persist_user_session(
|
|
31
|
-
if
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
await chainlit_client.update_conversation_metadata(
|
|
35
|
-
conversation_id=conversation_id, metadata=metadata
|
|
36
|
-
)
|
|
30
|
+
async def persist_user_session(thread_id: str, metadata: Dict):
|
|
31
|
+
if data_layer := get_data_layer():
|
|
32
|
+
await data_layer.update_thread(thread_id=thread_id, metadata=metadata)
|
|
37
33
|
|
|
38
34
|
|
|
39
|
-
async def
|
|
40
|
-
|
|
35
|
+
async def resume_thread(session: WebsocketSession):
|
|
36
|
+
data_layer = get_data_layer()
|
|
37
|
+
if not data_layer or not session.user or not session.thread_id_to_resume:
|
|
38
|
+
return
|
|
39
|
+
thread = await data_layer.get_thread(thread_id=session.thread_id_to_resume)
|
|
40
|
+
if not thread:
|
|
41
41
|
return
|
|
42
42
|
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
)
|
|
43
|
+
author = thread.get("user").get("identifier") if thread["user"] else None
|
|
44
|
+
user_is_author = author == session.user.identifier
|
|
46
45
|
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
user_is_author = author == session.user.username
|
|
51
|
-
|
|
52
|
-
if conversation and user_is_author:
|
|
53
|
-
metadata = conversation["metadata"] or {}
|
|
54
|
-
user_sessions[session.id] = metadata
|
|
46
|
+
if user_is_author:
|
|
47
|
+
metadata = thread["metadata"] or {}
|
|
48
|
+
user_sessions[session.id] = metadata.copy()
|
|
55
49
|
if chat_profile := metadata.get("chat_profile"):
|
|
56
50
|
session.chat_profile = chat_profile
|
|
57
51
|
if chat_settings := metadata.get("chat_settings"):
|
|
58
52
|
session.chat_settings = chat_settings
|
|
59
53
|
|
|
60
|
-
trace_event("
|
|
54
|
+
trace_event("thread_resumed")
|
|
61
55
|
|
|
62
|
-
return
|
|
56
|
+
return thread
|
|
63
57
|
|
|
64
58
|
|
|
65
59
|
def load_user_env(user_env):
|
|
@@ -128,9 +122,8 @@ async def connect(sid, environ, auth):
|
|
|
128
122
|
user=user,
|
|
129
123
|
token=token,
|
|
130
124
|
chat_profile=environ.get("HTTP_X_CHAINLIT_CHAT_PROFILE"),
|
|
131
|
-
|
|
125
|
+
thread_id=environ.get("HTTP_X_CHAINLIT_THREAD_ID"),
|
|
132
126
|
)
|
|
133
|
-
|
|
134
127
|
trace_event("connection_successful")
|
|
135
128
|
return True
|
|
136
129
|
|
|
@@ -142,22 +135,27 @@ async def connection_successful(sid):
|
|
|
142
135
|
if context.session.restored:
|
|
143
136
|
return
|
|
144
137
|
|
|
145
|
-
if context.session.
|
|
146
|
-
|
|
147
|
-
if
|
|
138
|
+
if context.session.thread_id_to_resume and config.code.on_chat_resume:
|
|
139
|
+
thread = await resume_thread(context.session)
|
|
140
|
+
if thread:
|
|
148
141
|
context.session.has_user_message = True
|
|
149
|
-
await
|
|
150
|
-
await context.emitter.
|
|
142
|
+
await context.emitter.clear_ask()
|
|
143
|
+
await context.emitter.resume_thread(thread)
|
|
144
|
+
await config.code.on_chat_resume(thread)
|
|
151
145
|
return
|
|
152
146
|
|
|
153
147
|
if config.code.on_chat_start:
|
|
154
148
|
"""Call the on_chat_start function provided by the developer."""
|
|
149
|
+
await context.emitter.clear_ask()
|
|
155
150
|
await config.code.on_chat_start()
|
|
156
151
|
|
|
157
152
|
|
|
158
153
|
@socket.on("clear_session")
|
|
159
154
|
async def clean_session(sid):
|
|
160
155
|
if session := WebsocketSession.get(sid):
|
|
156
|
+
if config.code.on_chat_end:
|
|
157
|
+
init_ws_context(session)
|
|
158
|
+
await config.code.on_chat_end()
|
|
161
159
|
# Clean up the user session
|
|
162
160
|
if session.id in user_sessions:
|
|
163
161
|
user_sessions.pop(session.id)
|
|
@@ -169,13 +167,14 @@ async def clean_session(sid):
|
|
|
169
167
|
@socket.on("disconnect")
|
|
170
168
|
async def disconnect(sid):
|
|
171
169
|
session = WebsocketSession.get(sid)
|
|
172
|
-
if
|
|
170
|
+
if session:
|
|
173
171
|
init_ws_context(session)
|
|
174
|
-
|
|
172
|
+
|
|
173
|
+
if config.code.on_chat_end and session:
|
|
175
174
|
await config.code.on_chat_end()
|
|
176
175
|
|
|
177
|
-
if session and session.
|
|
178
|
-
await persist_user_session(session.
|
|
176
|
+
if session and session.thread_id and session.has_user_message:
|
|
177
|
+
await persist_user_session(session.thread_id, session.to_persistable())
|
|
179
178
|
|
|
180
179
|
async def disconnect_on_timeout(sid):
|
|
181
180
|
await asyncio.sleep(config.project.session_timeout)
|
|
@@ -195,7 +194,9 @@ async def stop(sid):
|
|
|
195
194
|
trace_event("stop_task")
|
|
196
195
|
|
|
197
196
|
init_ws_context(session)
|
|
198
|
-
await Message(
|
|
197
|
+
await Message(
|
|
198
|
+
author="System", content="Task stopped by the user.", disable_feedback=True
|
|
199
|
+
).send()
|
|
199
200
|
|
|
200
201
|
session.should_stop = True
|
|
201
202
|
|
|
@@ -235,7 +236,8 @@ async def message(sid, payload: UIMessagePayload):
|
|
|
235
236
|
async def process_action(action: Action):
|
|
236
237
|
callback = config.code.action_callbacks.get(action.name)
|
|
237
238
|
if callback:
|
|
238
|
-
await callback(action)
|
|
239
|
+
res = await callback(action)
|
|
240
|
+
return res
|
|
239
241
|
else:
|
|
240
242
|
logger.warning("No callback found for action %s", action.name)
|
|
241
243
|
|
|
@@ -243,11 +245,25 @@ async def process_action(action: Action):
|
|
|
243
245
|
@socket.on("action_call")
|
|
244
246
|
async def call_action(sid, action):
|
|
245
247
|
"""Handle an action call from the UI."""
|
|
246
|
-
init_ws_context(sid)
|
|
248
|
+
context = init_ws_context(sid)
|
|
247
249
|
|
|
248
250
|
action = Action(**action)
|
|
249
251
|
|
|
250
|
-
|
|
252
|
+
try:
|
|
253
|
+
res = await process_action(action)
|
|
254
|
+
await context.emitter.send_action_response(
|
|
255
|
+
id=action.id, status=True, response=res if isinstance(res, str) else None
|
|
256
|
+
)
|
|
257
|
+
|
|
258
|
+
except InterruptedError:
|
|
259
|
+
await context.emitter.send_action_response(
|
|
260
|
+
id=action.id, status=False, response="Action interrupted by the user"
|
|
261
|
+
)
|
|
262
|
+
except Exception as e:
|
|
263
|
+
logger.exception(e)
|
|
264
|
+
await context.emitter.send_action_response(
|
|
265
|
+
id=action.id, status=False, response="An error occured"
|
|
266
|
+
)
|
|
251
267
|
|
|
252
268
|
|
|
253
269
|
@socket.on("chat_settings_change")
|
chainlit/step.py
ADDED
|
@@ -0,0 +1,393 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import inspect
|
|
3
|
+
import json
|
|
4
|
+
import uuid
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from functools import wraps
|
|
7
|
+
from typing import Callable, Dict, List, Optional, TypedDict, Union
|
|
8
|
+
|
|
9
|
+
from chainlit.config import config
|
|
10
|
+
from chainlit.context import context
|
|
11
|
+
from chainlit.data import get_data_layer
|
|
12
|
+
from chainlit.element import Element
|
|
13
|
+
from chainlit.logger import logger
|
|
14
|
+
from chainlit.telemetry import trace_event
|
|
15
|
+
from chainlit.types import FeedbackDict
|
|
16
|
+
from chainlit_client import BaseGeneration
|
|
17
|
+
from chainlit_client.step import StepType, TrueStepType
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class StepDict(TypedDict, total=False):
    """Serialized (camelCase) form of a step, as returned by ``Step.to_dict``.

    ``total=False`` makes every key optional at the type level.
    """

    name: str
    type: StepType
    id: str
    threadId: str
    parentId: Optional[str]
    disableFeedback: bool
    streaming: bool
    waitForAnswer: Optional[bool]  # not populated by Step.to_dict in this file
    isError: Optional[bool]
    metadata: Dict
    input: str  # already stringified (dicts are JSON-dumped by Step)
    output: str  # already stringified (dicts are JSON-dumped by Step)
    createdAt: Optional[str]  # Step sets this from datetime.utcnow().isoformat()
    start: Optional[str]
    end: Optional[str]
    generation: Optional[Dict]  # result of BaseGeneration.to_dict(), or None
    showInput: Optional[Union[bool, str]]
    language: Optional[str]
    indent: Optional[int]  # not populated by Step.to_dict in this file
    feedback: Optional[FeedbackDict]  # not populated by Step.to_dict in this file
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def step(
    original_function: Optional[Callable] = None,
    *,
    name: Optional[str] = "",
    type: TrueStepType = "undefined",
    id: Optional[str] = None,
    disable_feedback: bool = True,
    root: bool = False,
    language: Optional[str] = None,
    show_input: Union[bool, str] = False,
):
    """Step decorator for async and sync functions.

    Usable bare (``@step``) or with keyword arguments (``@step(type="tool")``).
    Every call of the decorated function runs inside a fresh ``Step``:
    the call's positional/keyword arguments are recorded as the step input, and
    the function's return value becomes the step output when it is truthy and
    the function did not set an output itself.
    """

    def wrapper(func: Callable):
        # Resolve the display name per decorated function. The previous code
        # rebound the enclosing ``name`` via ``nonlocal``, so a reused
        # ``step(...)`` factory labelled every later function with the first
        # function's name.
        step_name = name or func.__name__

        def make_step():
            # Built at call time: Step.__init__ reads the current session context.
            return Step(
                type=type,
                name=step_name,
                id=id,
                disable_feedback=disable_feedback,
                root=root,
                language=language,
                show_input=show_input,
            )

        def record_input(current, args, kwargs):
            # Best effort: an unrepresentable input must never break the call.
            try:
                current.input = {"args": args, "kwargs": kwargs}
            except Exception:
                logger.debug("Could not record step input", exc_info=True)

        def record_output(current, result):
            # Best effort: keep an output the function set explicitly.
            try:
                if result and not current.output:
                    current.output = result
            except Exception:
                logger.debug("Could not record step output", exc_info=True)

        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                async with make_step() as current:
                    record_input(current, args, kwargs)
                    result = await func(*args, **kwargs)
                    record_output(current, result)
                    return result

            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            with make_step() as current:
                record_input(current, args, kwargs)
                result = func(*args, **kwargs)
                record_output(current, result)
                return result

        return sync_wrapper

    # ``@step`` passes the function directly; ``@step(...)`` returns the decorator.
    if original_function is None:
        return wrapper
    return wrapper(original_function)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class Step:
    """A unit of work shown in the UI and persisted when a data layer is configured.

    Usable as an async or sync context manager (``async with Step(...)`` /
    ``with Step(...)``) or as a decorator (``@Step(...)``). Entering records
    ``start``, attaches the step to the current step (or the session's root
    message) unless ``root`` is set, and sends it. Exiting records ``end``,
    flags errors, and sends an update.
    """

    # Constructor
    name: str
    type: TrueStepType
    id: str
    parent_id: Optional[str]
    disable_feedback: bool

    streaming: bool
    persisted: bool

    root: bool
    show_input: Union[bool, str]

    is_error: Optional[bool]
    metadata: Dict
    thread_id: str
    created_at: Union[str, None]
    start: Union[str, None]
    end: Union[str, None]
    generation: Optional[BaseGeneration]
    language: Optional[str]
    elements: Optional[List[Element]]
    fail_on_persist_error: bool

    # Strong references to fire-and-forget tasks. The event loop only keeps weak
    # references, so an unreferenced task may be collected before it finishes.
    _background_tasks = set()

    def __init__(
        self,
        name: Optional[str] = config.ui.name,
        type: TrueStepType = "undefined",
        id: Optional[str] = None,
        parent_id: Optional[str] = None,
        elements: Optional[List[Element]] = None,
        disable_feedback: bool = True,
        root: bool = False,
        language: Optional[str] = None,
        show_input: Union[bool, str] = False,
    ):
        trace_event(f"init {self.__class__.__name__} {type}")
        self._input = ""
        self._output = ""
        self.thread_id = context.session.thread_id
        self.name = name or ""
        self.type = type
        self.id = id or str(uuid.uuid4())
        self.disable_feedback = disable_feedback
        self.metadata = {}
        self.is_error = False
        self.show_input = show_input
        self.parent_id = parent_id
        self.root = root

        self.language = language
        self.generation = None
        self.elements = elements or []

        self.created_at = datetime.utcnow().isoformat()
        self.start = None
        self.end = None

        self.streaming = False
        self.persisted = False
        # When True, data-layer failures are awaited and re-raised to the caller;
        # otherwise persistence runs in the background and failures are logged.
        self.fail_on_persist_error = False

    def _process_content(self, content, set_language=False):
        """Stringify ``content`` for display; optionally infer ``self.language``."""
        if content is None:
            return ""
        if isinstance(content, dict):
            try:
                processed_content = json.dumps(content, indent=4, ensure_ascii=False)
                if set_language:
                    self.language = "json"
            except TypeError:
                processed_content = str(content)
                if set_language:
                    self.language = "text"
        elif isinstance(content, str):
            processed_content = content
        else:
            processed_content = str(content)
            if set_language:
                self.language = "text"
        return processed_content

    @property
    def input(self):
        return self._input

    @input.setter
    def input(self, content: Union[Dict, str]):
        self._input = self._process_content(content, set_language=False)

    @property
    def output(self):
        return self._output

    @output.setter
    def output(self, content: Union[Dict, str]):
        self._output = self._process_content(content, set_language=True)

    def to_dict(self) -> StepDict:
        _dict: StepDict = {
            "name": self.name,
            "type": self.type,
            "id": self.id,
            "threadId": self.thread_id,
            "parentId": self.parent_id,
            "disableFeedback": self.disable_feedback,
            "streaming": self.streaming,
            "metadata": self.metadata,
            "input": self.input,
            "isError": self.is_error,
            "output": self.output,
            "createdAt": self.created_at,
            "start": self.start,
            "end": self.end,
            "language": self.language,
            "showInput": self.show_input,
            "generation": self.generation.to_dict() if self.generation else None,
        }
        return _dict

    @classmethod
    def _spawn(cls, coro) -> "asyncio.Task":
        """Schedule ``coro`` as a background task and keep it referenced until done."""
        task = asyncio.create_task(coro)
        cls._background_tasks.add(task)
        task.add_done_callback(cls._background_tasks.discard)
        return task

    async def _persist(self, coro, action: str):
        """Run a data-layer coroutine for this step.

        With ``fail_on_persist_error`` the coroutine is awaited so failures
        propagate. Otherwise it runs in the background and a failure is logged
        when the task completes. Previously the try/except only guarded task
        creation, so coroutine errors were never logged or raised.
        """
        if self.fail_on_persist_error:
            await coro
            return

        task = self._spawn(coro)

        def log_failure(done: "asyncio.Task"):
            if done.cancelled():
                return
            exc = done.exception()
            if exc is not None:
                logger.error(f"Failed to persist step {action}: {str(exc)}")

        task.add_done_callback(log_failure)

    def _ui_payload(self, step_dict: StepDict) -> StepDict:
        """Return a copy of ``step_dict`` for the UI (generation hidden unless the playground is on).

        A copy is required because the same dict may still be read by a
        background persistence task.
        """
        payload = step_dict.copy()
        if not config.features.prompt_playground:
            payload.pop("generation", None)
        return payload

    async def update(self):
        """
        Update a step already sent to the UI.
        """
        trace_event("update_step")

        self.streaming = False

        step_dict = self.to_dict()
        data_layer = get_data_layer()

        if data_layer:
            await self._persist(data_layer.update_step(step_dict), "update")

        await asyncio.gather(*[el.send(for_id=self.id) for el in self.elements])

        await context.emitter.update_step(self._ui_payload(step_dict))

        return True

    async def remove(self):
        """
        Remove a step already sent to the UI.
        """
        trace_event("remove_step")

        step_dict = self.to_dict()
        data_layer = get_data_layer()

        if data_layer:
            await self._persist(data_layer.delete_step(self.id), "deletion")

        await context.emitter.delete_step(step_dict)

        return True

    async def send(self):
        """Send the step to the UI (and persist it); a persisted step is not re-sent."""
        if self.persisted:
            return

        if config.code.author_rename:
            self.name = await config.code.author_rename(self.name)

        self.streaming = False

        step_dict = self.to_dict()
        data_layer = get_data_layer()

        if data_layer:
            await self._persist(data_layer.create_step(step_dict), "creation")
            self.persisted = True

        await asyncio.gather(*[el.send(for_id=self.id) for el in self.elements])

        await context.emitter.send_step(self._ui_payload(step_dict))

        return self.id

    async def stream_token(self, token: str, is_sequence=False):
        """
        Sends a token to the UI.
        Once all tokens have been streamed, call .send() to end the stream and persist the step if persistence is enabled.
        """

        if not self.streaming:
            self.streaming = True
            step_dict = self.to_dict()
            await context.emitter.stream_start(step_dict)

        if is_sequence:
            self.output = token
        else:
            self.output += token

        assert self.id
        await context.emitter.send_token(
            id=self.id, token=token, is_sequence=is_sequence
        )

    # Handle parameter less decorator
    def __call__(self, func):
        # Forward only arguments that ``step`` accepts: it has no ``parent_id``
        # or ``thread_id`` parameter, and passing them raised a TypeError.
        # NOTE(review): every call of the decorated function reuses ``self.id``.
        return step(
            original_function=func,
            type=self.type,
            name=self.name,
            id=self.id,
            disable_feedback=self.disable_feedback,
            root=self.root,
            language=self.language,
            show_input=self.show_input,
        )

    def _enter(self):
        """Shared enter logic: stamp start, resolve the parent, push onto the active stack."""
        self.start = datetime.utcnow().isoformat()
        if not self.parent_id and not self.root:
            if current_step := context.current_step:
                self.parent_id = current_step.id
            elif context.session.root_message:
                self.parent_id = context.session.root_message.id
        context.session.active_steps.append(self)

    def _exit(self, exc_type, exc_val):
        """Shared exit logic: stamp end, flag errors, remove this step from the active stack."""
        self.end = datetime.utcnow().isoformat()
        if exc_type is not None:
            self.is_error = True
            if not self.output:
                self.output = str(exc_val)
        # Remove this step specifically rather than popping the top: concurrent
        # steps in one session can exit out of order.
        # presumably context.current_step reads the top of this list — confirm.
        active_steps = context.session.active_steps
        if self in active_steps:
            active_steps.remove(self)

    # Handle Context Manager Protocol
    async def __aenter__(self):
        self._enter()
        await self.send()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        self._exit(exc_type, exc_val)
        await self.update()

    def __enter__(self):
        self._enter()
        # Requires a running event loop, as before.
        self._spawn(self.send())
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._exit(exc_type, exc_val)
        self._spawn(self.update())
|