chainlit 0.2.110__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chainlit might be problematic. Click here for more details.
- chainlit/__init__.py +48 -32
- chainlit/action.py +12 -12
- chainlit/cache.py +20 -0
- chainlit/cli/__init__.py +45 -29
- chainlit/cli/mock.py +48 -0
- chainlit/client.py +111 -89
- chainlit/config.py +22 -3
- chainlit/element.py +95 -124
- chainlit/{sdk.py → emitter.py} +55 -64
- chainlit/frontend/dist/assets/index-0b7e367e.js +717 -0
- chainlit/frontend/dist/assets/index-0cc9e355.css +1 -0
- chainlit/frontend/dist/index.html +3 -3
- chainlit/hello.py +3 -3
- chainlit/lc/__init__.py +11 -0
- chainlit/lc/agent.py +32 -0
- chainlit/lc/callbacks.py +411 -0
- chainlit/message.py +72 -96
- chainlit/server.py +280 -195
- chainlit/session.py +4 -2
- chainlit/sync.py +37 -0
- chainlit/types.py +18 -1
- chainlit/user_session.py +16 -16
- {chainlit-0.2.110.dist-info → chainlit-0.3.0.dist-info}/METADATA +15 -14
- chainlit-0.3.0.dist-info/RECORD +37 -0
- chainlit/frontend/dist/assets/index-36bf9cab.js +0 -713
- chainlit/frontend/dist/assets/index-bdffdaa0.css +0 -1
- chainlit/lc/chainlit_handler.py +0 -271
- chainlit/lc/monkey.py +0 -28
- chainlit/lc/new_monkey.py +0 -167
- chainlit/lc/old_monkey.py +0 -119
- chainlit/lc/utils.py +0 -38
- chainlit/watch.py +0 -54
- chainlit-0.2.110.dist-info/RECORD +0 -38
- {chainlit-0.2.110.dist-info → chainlit-0.3.0.dist-info}/WHEEL +0 -0
- {chainlit-0.2.110.dist-info → chainlit-0.3.0.dist-info}/entry_points.txt +0 -0
chainlit/server.py
CHANGED
|
@@ -5,35 +5,115 @@ mimetypes.add_type("text/css", ".css")
|
|
|
5
5
|
|
|
6
6
|
import os
|
|
7
7
|
import json
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
from
|
|
11
|
-
from
|
|
12
|
-
|
|
8
|
+
import webbrowser
|
|
9
|
+
|
|
10
|
+
from contextlib import asynccontextmanager
|
|
11
|
+
from watchfiles import awatch
|
|
12
|
+
|
|
13
|
+
from fastapi import FastAPI
|
|
14
|
+
from fastapi.responses import (
|
|
15
|
+
HTMLResponse,
|
|
16
|
+
JSONResponse,
|
|
17
|
+
FileResponse,
|
|
18
|
+
PlainTextResponse,
|
|
19
|
+
)
|
|
20
|
+
from fastapi.staticfiles import StaticFiles
|
|
21
|
+
from fastapi_socketio import SocketManager
|
|
22
|
+
from starlette.middleware.cors import CORSMiddleware
|
|
23
|
+
import asyncio
|
|
24
|
+
|
|
25
|
+
from chainlit.config import config, load_module, DEFAULT_HOST
|
|
13
26
|
from chainlit.session import Session, sessions
|
|
14
27
|
from chainlit.user_session import user_sessions
|
|
15
28
|
from chainlit.client import CloudClient
|
|
16
|
-
from chainlit.
|
|
29
|
+
from chainlit.emitter import ChainlitEmitter
|
|
17
30
|
from chainlit.markdown import get_markdown_str
|
|
18
31
|
from chainlit.action import Action
|
|
19
32
|
from chainlit.message import Message, ErrorMessage
|
|
20
|
-
from chainlit.telemetry import
|
|
33
|
+
from chainlit.telemetry import trace_event
|
|
21
34
|
from chainlit.logger import logger
|
|
35
|
+
from chainlit.types import CompletionRequest
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@asynccontextmanager
|
|
39
|
+
async def lifespan(app: FastAPI):
|
|
40
|
+
host = config.run_settings.host
|
|
41
|
+
port = config.run_settings.port
|
|
42
|
+
|
|
43
|
+
if not config.run_settings.headless:
|
|
44
|
+
if host == DEFAULT_HOST:
|
|
45
|
+
url = f"http://localhost:{port}"
|
|
46
|
+
else:
|
|
47
|
+
url = f"http://{host}:{port}"
|
|
48
|
+
|
|
49
|
+
logger.info(f"Your app is available at {url}")
|
|
50
|
+
webbrowser.open(url)
|
|
51
|
+
|
|
52
|
+
watch_task = None
|
|
53
|
+
stop_event = asyncio.Event()
|
|
54
|
+
|
|
55
|
+
if config.run_settings.watch:
|
|
56
|
+
|
|
57
|
+
async def watch_files_for_changes():
|
|
58
|
+
async for changes in awatch(config.root, stop_event=stop_event):
|
|
59
|
+
for change_type, file_path in changes:
|
|
60
|
+
file_name = os.path.basename(file_path)
|
|
61
|
+
file_ext = os.path.splitext(file_name)[1]
|
|
62
|
+
|
|
63
|
+
if file_ext.lower() == ".py" or file_name.lower() == "chainlit.md":
|
|
64
|
+
logger.info(f"File {change_type.name}: {file_name}")
|
|
65
|
+
|
|
66
|
+
# Reload the module if the module name is specified in the config
|
|
67
|
+
if config.module_name:
|
|
68
|
+
load_module(config.module_name)
|
|
69
|
+
|
|
70
|
+
await socket.emit("reload", {})
|
|
71
|
+
|
|
72
|
+
break
|
|
73
|
+
|
|
74
|
+
watch_task = asyncio.create_task(watch_files_for_changes())
|
|
75
|
+
|
|
76
|
+
try:
|
|
77
|
+
yield
|
|
78
|
+
except KeyboardInterrupt:
|
|
79
|
+
logger.error("KeyboardInterrupt received, stopping the watch task...")
|
|
80
|
+
finally:
|
|
81
|
+
if watch_task:
|
|
82
|
+
stop_event.set()
|
|
83
|
+
await watch_task
|
|
84
|
+
|
|
22
85
|
|
|
23
86
|
root_dir = os.path.dirname(os.path.abspath(__file__))
|
|
24
87
|
build_dir = os.path.join(root_dir, "frontend/dist")
|
|
25
88
|
|
|
26
|
-
app =
|
|
27
|
-
|
|
28
|
-
|
|
89
|
+
app = FastAPI(lifespan=lifespan)
|
|
90
|
+
app.mount("/static", StaticFiles(directory=build_dir), name="static")
|
|
91
|
+
app.add_middleware(
|
|
92
|
+
CORSMiddleware,
|
|
93
|
+
allow_origins=["*"],
|
|
94
|
+
allow_credentials=True,
|
|
95
|
+
allow_methods=["*"],
|
|
96
|
+
allow_headers=["*"],
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
# Define max HTTP data size to 100 MB
|
|
100
|
+
max_http_data_size = 100 * 1024 * 1024
|
|
101
|
+
|
|
102
|
+
socket = SocketManager(
|
|
29
103
|
app,
|
|
30
|
-
cors_allowed_origins=
|
|
31
|
-
async_mode="
|
|
32
|
-
max_http_buffer_size=
|
|
104
|
+
cors_allowed_origins=[],
|
|
105
|
+
async_mode="asgi",
|
|
106
|
+
max_http_buffer_size=max_http_data_size,
|
|
33
107
|
)
|
|
34
108
|
|
|
109
|
+
"""
|
|
110
|
+
-------------------------------------------------------------------------------
|
|
111
|
+
HTTP HANDLERS
|
|
112
|
+
-------------------------------------------------------------------------------
|
|
113
|
+
"""
|
|
114
|
+
|
|
35
115
|
|
|
36
|
-
def
|
|
116
|
+
def get_html_template():
|
|
37
117
|
PLACEHOLDER = "<!-- TAG INJECTION PLACEHOLDER -->"
|
|
38
118
|
|
|
39
119
|
default_url = "https://github.com/Chainlit/chainlit"
|
|
@@ -47,226 +127,226 @@ def inject_html_tags():
|
|
|
47
127
|
<meta property="og:image" content="https://chainlit-cloud.s3.eu-west-3.amazonaws.com/logo/chainlit_banner.png">
|
|
48
128
|
<meta property="og:url" content="{url}">"""
|
|
49
129
|
|
|
50
|
-
|
|
51
|
-
injected_index_html_file_path = os.path.join(app.static_folder, "_index.html")
|
|
130
|
+
index_html_file_path = os.path.join(build_dir, "index.html")
|
|
52
131
|
|
|
53
|
-
with open(
|
|
132
|
+
with open(index_html_file_path, "r", encoding="utf-8") as f:
|
|
54
133
|
content = f.read()
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
with open(injected_index_html_file_path, "w", encoding="utf-8") as f:
|
|
58
|
-
f.write(content)
|
|
134
|
+
content = content.replace(PLACEHOLDER, tags)
|
|
135
|
+
return content
|
|
59
136
|
|
|
60
137
|
|
|
61
|
-
|
|
138
|
+
html_template = get_html_template()
|
|
62
139
|
|
|
63
140
|
|
|
64
|
-
@app.
|
|
65
|
-
|
|
66
|
-
def serve(path):
|
|
67
|
-
"""Serve the UI."""
|
|
68
|
-
if path != "" and os.path.exists(app.static_folder + "/" + path):
|
|
69
|
-
return send_from_directory(app.static_folder, path)
|
|
70
|
-
else:
|
|
71
|
-
return send_from_directory(app.static_folder, "_index.html")
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
@app.route("/completion", methods=["POST"])
|
|
75
|
-
@trace
|
|
76
|
-
def completion():
|
|
141
|
+
@app.post("/completion")
|
|
142
|
+
async def completion(completion: CompletionRequest):
|
|
77
143
|
"""Handle a completion request from the prompt playground."""
|
|
78
144
|
|
|
79
145
|
import openai
|
|
80
146
|
|
|
81
|
-
|
|
82
|
-
llm_settings = data["settings"]
|
|
83
|
-
user_env = data.get("userEnv", {})
|
|
147
|
+
trace_event("completion")
|
|
84
148
|
|
|
85
|
-
api_key =
|
|
149
|
+
api_key = completion.userEnv.get("OPENAI_API_KEY", os.environ.get("OPENAI_API_KEY"))
|
|
86
150
|
|
|
87
|
-
model_name =
|
|
88
|
-
stop =
|
|
151
|
+
model_name = completion.settings.model_name
|
|
152
|
+
stop = completion.settings.stop
|
|
89
153
|
# OpenAI doesn't support an empty stop array, clear it
|
|
90
154
|
if isinstance(stop, list) and len(stop) == 0:
|
|
91
155
|
stop = None
|
|
92
156
|
|
|
93
157
|
if model_name in ["gpt-3.5-turbo", "gpt-4"]:
|
|
94
|
-
response = openai.ChatCompletion.
|
|
158
|
+
response = await openai.ChatCompletion.acreate(
|
|
95
159
|
api_key=api_key,
|
|
96
160
|
model=model_name,
|
|
97
|
-
messages=[{"role": "user", "content":
|
|
161
|
+
messages=[{"role": "user", "content": completion.prompt}],
|
|
98
162
|
stop=stop,
|
|
99
|
-
**
|
|
163
|
+
**completion.settings.to_settings_dict(),
|
|
100
164
|
)
|
|
101
|
-
return response["choices"][0]["message"]["content"]
|
|
165
|
+
return PlainTextResponse(content=response["choices"][0]["message"]["content"])
|
|
102
166
|
else:
|
|
103
|
-
response = openai.Completion.
|
|
167
|
+
response = await openai.Completion.acreate(
|
|
104
168
|
api_key=api_key,
|
|
105
169
|
model=model_name,
|
|
106
|
-
prompt=
|
|
170
|
+
prompt=completion.prompt,
|
|
107
171
|
stop=stop,
|
|
108
|
-
**
|
|
172
|
+
**completion.settings.to_settings_dict(),
|
|
109
173
|
)
|
|
110
|
-
return response["choices"][0]["text"]
|
|
174
|
+
return PlainTextResponse(content=response["choices"][0]["text"])
|
|
111
175
|
|
|
112
176
|
|
|
113
|
-
@app.
|
|
114
|
-
def project_settings():
|
|
177
|
+
@app.get("/project/settings")
|
|
178
|
+
async def project_settings():
|
|
115
179
|
"""Return project settings. This is called by the UI before the establishing the websocket connection."""
|
|
116
|
-
return
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
180
|
+
return JSONResponse(
|
|
181
|
+
content={
|
|
182
|
+
"public": config.public,
|
|
183
|
+
"projectId": config.project_id,
|
|
184
|
+
"chainlitServer": config.chainlit_server,
|
|
185
|
+
"userEnv": config.user_env,
|
|
186
|
+
"hideCot": config.hide_cot,
|
|
187
|
+
"chainlitMd": get_markdown_str(config.root),
|
|
188
|
+
"prod": bool(config.chainlit_prod_url),
|
|
189
|
+
"appTitle": config.chatbot_name,
|
|
190
|
+
"github": config.github,
|
|
191
|
+
}
|
|
192
|
+
)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
@app.get("/{path:path}")
|
|
196
|
+
async def serve(path: str):
|
|
197
|
+
"""Serve the UI."""
|
|
198
|
+
path_to_file = os.path.join(build_dir, path)
|
|
199
|
+
if path != "" and os.path.exists(path_to_file):
|
|
200
|
+
return FileResponse(path_to_file)
|
|
201
|
+
else:
|
|
202
|
+
return HTMLResponse(content=html_template, status_code=200)
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
"""
|
|
206
|
+
-------------------------------------------------------------------------------
|
|
207
|
+
WEBSOCKET HANDLERS
|
|
208
|
+
-------------------------------------------------------------------------------
|
|
209
|
+
"""
|
|
135
210
|
|
|
136
|
-
if config.user_env:
|
|
137
|
-
# Check if requested user environment variables are provided
|
|
138
|
-
if request.headers.get("user-env"):
|
|
139
|
-
user_env = json.loads(request.headers.get("user-env"))
|
|
140
|
-
for key in config.user_env:
|
|
141
|
-
if key not in user_env:
|
|
142
|
-
trace_event("missing_user_env")
|
|
143
|
-
raise ConnectionRefusedError(
|
|
144
|
-
"Missing user environment variable: " + key
|
|
145
|
-
)
|
|
146
211
|
|
|
147
|
-
|
|
148
|
-
|
|
212
|
+
def need_session(id: str):
|
|
213
|
+
"""Return the session with the given id."""
|
|
214
|
+
|
|
215
|
+
session = sessions.get(id)
|
|
216
|
+
if not session:
|
|
217
|
+
raise ValueError("Session not found")
|
|
218
|
+
return session
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
@socket.on("connect")
|
|
222
|
+
async def connect(sid, environ):
|
|
223
|
+
user_env = environ.get("HTTP_USER_ENV")
|
|
224
|
+
authorization = environ.get("HTTP_AUTHORIZATION")
|
|
225
|
+
cloud_client = None
|
|
226
|
+
|
|
227
|
+
# Check decorated functions
|
|
228
|
+
if not config.lc_factory and not config.on_message and not config.on_chat_start:
|
|
229
|
+
logger.error(
|
|
230
|
+
"Module should at least expose one of @langchain_factory, @on_message or @on_chat_start function"
|
|
231
|
+
)
|
|
232
|
+
return False
|
|
233
|
+
|
|
234
|
+
# Check authorization
|
|
235
|
+
if not config.public and not authorization:
|
|
149
236
|
# Refuse connection if the app is private and no access token is provided
|
|
150
237
|
trace_event("no_access_token")
|
|
151
|
-
|
|
152
|
-
|
|
238
|
+
logger.error("No access token provided")
|
|
239
|
+
return False
|
|
240
|
+
elif authorization and config.project_id:
|
|
153
241
|
# Create the cloud client
|
|
154
|
-
|
|
242
|
+
cloud_client = CloudClient(
|
|
155
243
|
project_id=config.project_id,
|
|
156
|
-
session_id=
|
|
157
|
-
access_token=
|
|
158
|
-
url=config.chainlit_server,
|
|
244
|
+
session_id=sid,
|
|
245
|
+
access_token=authorization,
|
|
159
246
|
)
|
|
160
|
-
is_project_member =
|
|
247
|
+
is_project_member = await cloud_client.is_project_member()
|
|
161
248
|
if not is_project_member:
|
|
162
|
-
|
|
249
|
+
logger.error("You are not a member of this project")
|
|
250
|
+
return False
|
|
251
|
+
|
|
252
|
+
# Check user env
|
|
253
|
+
if config.user_env:
|
|
254
|
+
# Check if requested user environment variables are provided
|
|
255
|
+
if user_env:
|
|
256
|
+
user_env = json.loads(user_env)
|
|
257
|
+
for key in config.user_env:
|
|
258
|
+
if key not in user_env:
|
|
259
|
+
trace_event("missing_user_env")
|
|
260
|
+
logger.error("Missing user environment variable: " + key)
|
|
261
|
+
return False
|
|
262
|
+
else:
|
|
263
|
+
logger.error("Missing user environment variables")
|
|
264
|
+
return False
|
|
265
|
+
|
|
266
|
+
# Create the session
|
|
163
267
|
|
|
164
268
|
# Function to send a message to this particular session
|
|
165
|
-
def
|
|
166
|
-
|
|
269
|
+
def emit_fn(event, data):
|
|
270
|
+
if sid in sessions:
|
|
271
|
+
if sessions[sid]["should_stop"]:
|
|
272
|
+
sessions[sid]["should_stop"] = False
|
|
273
|
+
raise InterruptedError("Task stopped by user")
|
|
274
|
+
return socket.emit(event, data, to=sid)
|
|
167
275
|
|
|
168
276
|
# Function to ask the user a question
|
|
169
|
-
def
|
|
170
|
-
|
|
277
|
+
def ask_user_fn(data, timeout):
|
|
278
|
+
if sessions[sid]["should_stop"]:
|
|
279
|
+
sessions[sid]["should_stop"] = False
|
|
280
|
+
raise InterruptedError("Task stopped by user")
|
|
281
|
+
return socket.call("ask", data, timeout=timeout, to=sid)
|
|
171
282
|
|
|
172
283
|
session = {
|
|
173
|
-
"id":
|
|
174
|
-
"emit":
|
|
175
|
-
"ask_user":
|
|
176
|
-
"client":
|
|
284
|
+
"id": sid,
|
|
285
|
+
"emit": emit_fn,
|
|
286
|
+
"ask_user": ask_user_fn,
|
|
287
|
+
"client": cloud_client,
|
|
177
288
|
"user_env": user_env,
|
|
289
|
+
"running_sync": False,
|
|
290
|
+
"should_stop": False,
|
|
178
291
|
} # type: Session
|
|
179
|
-
sessions[session_id] = session
|
|
180
292
|
|
|
181
|
-
|
|
182
|
-
raise ValueError(
|
|
183
|
-
"Module should at least expose one of @langchain_factory, @on_message or @on_chat_start function"
|
|
184
|
-
)
|
|
293
|
+
sessions[sid] = session
|
|
185
294
|
|
|
186
|
-
|
|
295
|
+
trace_event("connection_successful")
|
|
296
|
+
return True
|
|
187
297
|
|
|
188
|
-
def instantiate_agent(session):
|
|
189
|
-
"""Instantiate the langchain agent and store it in the session."""
|
|
190
|
-
__chainlit_sdk__ = Chainlit(session)
|
|
191
|
-
agent = config.lc_factory()
|
|
192
|
-
session["agent"] = agent
|
|
193
298
|
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
299
|
+
@socket.on("connection_successful")
|
|
300
|
+
async def connection_successful(sid):
|
|
301
|
+
session = need_session(sid)
|
|
302
|
+
__chainlit_emitter__ = ChainlitEmitter(session)
|
|
303
|
+
if config.lc_factory:
|
|
304
|
+
"""Instantiate the langchain agent and store it in the session."""
|
|
305
|
+
agent = await config.lc_factory(__chainlit_emitter__=__chainlit_emitter__)
|
|
306
|
+
session["agent"] = agent
|
|
197
307
|
|
|
198
308
|
if config.on_chat_start:
|
|
309
|
+
"""Call the on_chat_start function provided by the developer."""
|
|
310
|
+
await config.on_chat_start(__chainlit_emitter__=__chainlit_emitter__)
|
|
199
311
|
|
|
200
|
-
def _on_chat_start(session):
|
|
201
|
-
"""Call the on_chat_start function provided by the developer."""
|
|
202
|
-
__chainlit_sdk__ = Chainlit(session)
|
|
203
|
-
config.on_chat_start()
|
|
204
312
|
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
trace_event("connection_successful")
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
@socketio.on("disconnect")
|
|
213
|
-
def disconnect():
|
|
214
|
-
"""Handle socket disconnection."""
|
|
215
|
-
|
|
216
|
-
if request.sid in sessions:
|
|
313
|
+
@socket.on("disconnect")
|
|
314
|
+
async def disconnect(sid):
|
|
315
|
+
if sid in sessions:
|
|
217
316
|
# Clean up the session
|
|
218
|
-
|
|
219
|
-
task = session.get("task")
|
|
220
|
-
if task:
|
|
221
|
-
# If a background task is running, kill it
|
|
222
|
-
task.kill()
|
|
317
|
+
sessions.pop(sid)
|
|
223
318
|
|
|
224
|
-
if
|
|
319
|
+
if sid in user_sessions:
|
|
225
320
|
# Clean up the user session
|
|
226
|
-
user_sessions.pop(
|
|
321
|
+
user_sessions.pop(sid)
|
|
227
322
|
|
|
228
323
|
|
|
229
|
-
@
|
|
230
|
-
def stop():
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
if not session:
|
|
235
|
-
return
|
|
324
|
+
@socket.on("stop")
|
|
325
|
+
async def stop(sid):
|
|
326
|
+
if sid in sessions:
|
|
327
|
+
trace_event("stop_task")
|
|
328
|
+
session = sessions[sid]
|
|
236
329
|
|
|
237
|
-
|
|
330
|
+
__chainlit_emitter__ = ChainlitEmitter(session)
|
|
238
331
|
|
|
239
|
-
|
|
240
|
-
task.kill()
|
|
241
|
-
session["task"] = None
|
|
332
|
+
await Message(author="System", content="Task stopped by the user.").send()
|
|
242
333
|
|
|
243
|
-
|
|
334
|
+
session["should_stop"] = True
|
|
244
335
|
|
|
245
336
|
if config.on_stop:
|
|
246
|
-
config.on_stop()
|
|
337
|
+
await config.on_stop()
|
|
247
338
|
|
|
248
|
-
Message(author="System", content="Conversation stopped by the user.").send()
|
|
249
339
|
|
|
250
|
-
|
|
251
|
-
def need_session(id: str):
|
|
252
|
-
"""Return the session with the given id."""
|
|
253
|
-
|
|
254
|
-
session = sessions.get(id)
|
|
255
|
-
if not session:
|
|
256
|
-
raise ValueError("Session not found")
|
|
257
|
-
return session
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
def process_message(session: Session, author: str, input_str: str):
|
|
340
|
+
async def process_message(session: Session, author: str, input_str: str):
|
|
261
341
|
"""Process a message from the user."""
|
|
262
342
|
|
|
263
|
-
__chainlit_sdk__ = Chainlit(session)
|
|
264
343
|
try:
|
|
265
|
-
|
|
344
|
+
__chainlit_emitter__ = ChainlitEmitter(session)
|
|
345
|
+
await __chainlit_emitter__.task_start()
|
|
266
346
|
|
|
267
347
|
if session["client"]:
|
|
268
348
|
# If cloud is enabled, persist the message
|
|
269
|
-
session["client"].create_message(
|
|
349
|
+
await session["client"].create_message(
|
|
270
350
|
{
|
|
271
351
|
"author": author,
|
|
272
352
|
"content": input_str,
|
|
@@ -276,18 +356,28 @@ def process_message(session: Session, author: str, input_str: str):
|
|
|
276
356
|
|
|
277
357
|
langchain_agent = session.get("agent")
|
|
278
358
|
if langchain_agent:
|
|
359
|
+
from chainlit.lc.agent import run_langchain_agent
|
|
360
|
+
|
|
279
361
|
# If a langchain agent is available, run it
|
|
280
362
|
if config.lc_run:
|
|
281
363
|
# If the developer provided a custom run function, use it
|
|
282
|
-
config.lc_run(
|
|
364
|
+
await config.lc_run(
|
|
365
|
+
langchain_agent,
|
|
366
|
+
input_str,
|
|
367
|
+
__chainlit_emitter__=__chainlit_emitter__,
|
|
368
|
+
)
|
|
283
369
|
return
|
|
284
370
|
else:
|
|
285
371
|
# Otherwise, use the default run function
|
|
286
|
-
raw_res, output_key = run_langchain_agent(
|
|
372
|
+
raw_res, output_key = await run_langchain_agent(
|
|
373
|
+
langchain_agent, input_str, use_async=config.lc_agent_is_async
|
|
374
|
+
)
|
|
287
375
|
|
|
288
376
|
if config.lc_postprocess:
|
|
289
377
|
# If the developer provided a custom postprocess function, use it
|
|
290
|
-
config.lc_postprocess(
|
|
378
|
+
await config.lc_postprocess(
|
|
379
|
+
raw_res, __chainlit_emitter__=__chainlit_emitter__
|
|
380
|
+
)
|
|
291
381
|
return
|
|
292
382
|
elif output_key is not None:
|
|
293
383
|
# Use the output key if provided
|
|
@@ -296,54 +386,49 @@ def process_message(session: Session, author: str, input_str: str):
|
|
|
296
386
|
# Otherwise, use the raw response
|
|
297
387
|
res = raw_res
|
|
298
388
|
# Finally, send the response to the user
|
|
299
|
-
Message(author=config.chatbot_name, content=res).send()
|
|
389
|
+
await Message(author=config.chatbot_name, content=res).send()
|
|
300
390
|
|
|
301
391
|
elif config.on_message:
|
|
302
392
|
# If no langchain agent is available, call the on_message function provided by the developer
|
|
303
|
-
config.on_message(
|
|
393
|
+
await config.on_message(
|
|
394
|
+
input_str, __chainlit_emitter__=__chainlit_emitter__
|
|
395
|
+
)
|
|
396
|
+
except InterruptedError:
|
|
397
|
+
pass
|
|
304
398
|
except Exception as e:
|
|
305
399
|
logger.exception(e)
|
|
306
|
-
ErrorMessage(author="Error", content=str(e)).send()
|
|
400
|
+
await ErrorMessage(author="Error", content=str(e)).send()
|
|
307
401
|
finally:
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
@socketio.on("message")
|
|
312
|
-
def on_message(body):
|
|
313
|
-
"""Handle a message from the UI."""
|
|
402
|
+
await __chainlit_emitter__.task_end()
|
|
314
403
|
|
|
315
|
-
session_id = request.sid
|
|
316
|
-
session = need_session(session_id)
|
|
317
404
|
|
|
318
|
-
|
|
319
|
-
|
|
405
|
+
@socket.on("ui_message")
|
|
406
|
+
async def message(sid, data):
|
|
407
|
+
"""Handle a message sent by the User."""
|
|
408
|
+
session = need_session(sid)
|
|
409
|
+
session["should_stop"] = False
|
|
320
410
|
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
task.join()
|
|
324
|
-
session["task"] = None
|
|
411
|
+
input_str = data["content"].strip()
|
|
412
|
+
author = data["author"]
|
|
325
413
|
|
|
326
|
-
|
|
414
|
+
await process_message(session, author, input_str)
|
|
327
415
|
|
|
328
416
|
|
|
329
|
-
def process_action(session: Session, action: Action):
|
|
330
|
-
|
|
417
|
+
async def process_action(session: Session, action: Action):
|
|
418
|
+
__chainlit_emitter__ = ChainlitEmitter(session)
|
|
331
419
|
callback = config.action_callbacks.get(action.name)
|
|
332
420
|
if callback:
|
|
333
|
-
callback(action)
|
|
421
|
+
await callback(action, __chainlit_emitter__=__chainlit_emitter__)
|
|
334
422
|
else:
|
|
335
423
|
logger.warning("No callback found for action %s", action.name)
|
|
336
424
|
|
|
337
425
|
|
|
338
|
-
@
|
|
339
|
-
def call_action(action):
|
|
426
|
+
@socket.on("action_call")
|
|
427
|
+
async def call_action(sid, action):
|
|
340
428
|
"""Handle an action call from the UI."""
|
|
341
|
-
|
|
342
|
-
session = need_session(session_id)
|
|
429
|
+
session = need_session(sid)
|
|
343
430
|
|
|
431
|
+
__chainlit_emitter__ = ChainlitEmitter(session)
|
|
344
432
|
action = Action(**action)
|
|
345
433
|
|
|
346
|
-
|
|
347
|
-
session["task"] = task
|
|
348
|
-
task.join()
|
|
349
|
-
session["task"] = None
|
|
434
|
+
await process_action(session, action)
|
chainlit/session.py
CHANGED
|
@@ -16,8 +16,10 @@ class Session(TypedDict):
|
|
|
16
16
|
user_env: Dict[str, str]
|
|
17
17
|
# Optional langchain agent
|
|
18
18
|
agent: Any
|
|
19
|
-
#
|
|
20
|
-
|
|
19
|
+
# If the session is currently running a sync task
|
|
20
|
+
running_sync: bool
|
|
21
|
+
# Whether the current task should be stopped
|
|
22
|
+
should_stop: bool
|
|
21
23
|
# Optional client to persist messages and files
|
|
22
24
|
client: Optional[BaseClient]
|
|
23
25
|
|
chainlit/sync.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
from typing import Any, Callable
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from syncer import sync
|
|
5
|
+
from asyncer import asyncify
|
|
6
|
+
|
|
7
|
+
from chainlit.emitter import get_emitter
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def make_async(function: Callable):
|
|
11
|
+
emitter = get_emitter()
|
|
12
|
+
if not emitter:
|
|
13
|
+
raise RuntimeError(
|
|
14
|
+
"Emitter not found, please call make_async in a Chainlit context."
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
def wrapper(*args, **kwargs):
|
|
18
|
+
emitter.session["running_sync"] = True
|
|
19
|
+
__chainlit_emitter__ = emitter
|
|
20
|
+
res = function(*args, **kwargs)
|
|
21
|
+
emitter.session["running_sync"] = False
|
|
22
|
+
return res
|
|
23
|
+
|
|
24
|
+
return asyncify(wrapper, cancellable=True)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def run_sync(co: Any):
|
|
28
|
+
try:
|
|
29
|
+
loop = asyncio.get_event_loop()
|
|
30
|
+
except RuntimeError as e:
|
|
31
|
+
if "There is no current event loop" in str(e):
|
|
32
|
+
loop = None
|
|
33
|
+
|
|
34
|
+
if loop is None or not loop.is_running():
|
|
35
|
+
loop = asyncio.new_event_loop()
|
|
36
|
+
asyncio.set_event_loop(loop)
|
|
37
|
+
return sync(co)
|