chainlit 0.2.110__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic. Click here for more details.

chainlit/server.py CHANGED
@@ -5,35 +5,115 @@ mimetypes.add_type("text/css", ".css")
5
5
 
6
6
  import os
7
7
  import json
8
- from flask_cors import CORS
9
- from flask import Flask, request, send_from_directory
10
- from flask_socketio import SocketIO, ConnectionRefusedError
11
- from chainlit.config import config
12
- from chainlit.lc.utils import run_langchain_agent
8
+ import webbrowser
9
+
10
+ from contextlib import asynccontextmanager
11
+ from watchfiles import awatch
12
+
13
+ from fastapi import FastAPI
14
+ from fastapi.responses import (
15
+ HTMLResponse,
16
+ JSONResponse,
17
+ FileResponse,
18
+ PlainTextResponse,
19
+ )
20
+ from fastapi.staticfiles import StaticFiles
21
+ from fastapi_socketio import SocketManager
22
+ from starlette.middleware.cors import CORSMiddleware
23
+ import asyncio
24
+
25
+ from chainlit.config import config, load_module, DEFAULT_HOST
13
26
  from chainlit.session import Session, sessions
14
27
  from chainlit.user_session import user_sessions
15
28
  from chainlit.client import CloudClient
16
- from chainlit.sdk import Chainlit
29
+ from chainlit.emitter import ChainlitEmitter
17
30
  from chainlit.markdown import get_markdown_str
18
31
  from chainlit.action import Action
19
32
  from chainlit.message import Message, ErrorMessage
20
- from chainlit.telemetry import trace, trace_event
33
+ from chainlit.telemetry import trace_event
21
34
  from chainlit.logger import logger
35
+ from chainlit.types import CompletionRequest
36
+
37
+
38
+ @asynccontextmanager
39
+ async def lifespan(app: FastAPI):
40
+ host = config.run_settings.host
41
+ port = config.run_settings.port
42
+
43
+ if not config.run_settings.headless:
44
+ if host == DEFAULT_HOST:
45
+ url = f"http://localhost:{port}"
46
+ else:
47
+ url = f"http://{host}:{port}"
48
+
49
+ logger.info(f"Your app is available at {url}")
50
+ webbrowser.open(url)
51
+
52
+ watch_task = None
53
+ stop_event = asyncio.Event()
54
+
55
+ if config.run_settings.watch:
56
+
57
+ async def watch_files_for_changes():
58
+ async for changes in awatch(config.root, stop_event=stop_event):
59
+ for change_type, file_path in changes:
60
+ file_name = os.path.basename(file_path)
61
+ file_ext = os.path.splitext(file_name)[1]
62
+
63
+ if file_ext.lower() == ".py" or file_name.lower() == "chainlit.md":
64
+ logger.info(f"File {change_type.name}: {file_name}")
65
+
66
+ # Reload the module if the module name is specified in the config
67
+ if config.module_name:
68
+ load_module(config.module_name)
69
+
70
+ await socket.emit("reload", {})
71
+
72
+ break
73
+
74
+ watch_task = asyncio.create_task(watch_files_for_changes())
75
+
76
+ try:
77
+ yield
78
+ except KeyboardInterrupt:
79
+ logger.error("KeyboardInterrupt received, stopping the watch task...")
80
+ finally:
81
+ if watch_task:
82
+ stop_event.set()
83
+ await watch_task
84
+
22
85
 
23
86
  root_dir = os.path.dirname(os.path.abspath(__file__))
24
87
  build_dir = os.path.join(root_dir, "frontend/dist")
25
88
 
26
- app = Flask(__name__, static_folder=build_dir)
27
- CORS(app)
28
- socketio = SocketIO(
89
+ app = FastAPI(lifespan=lifespan)
90
+ app.mount("/static", StaticFiles(directory=build_dir), name="static")
91
+ app.add_middleware(
92
+ CORSMiddleware,
93
+ allow_origins=["*"],
94
+ allow_credentials=True,
95
+ allow_methods=["*"],
96
+ allow_headers=["*"],
97
+ )
98
+
99
+ # Define max HTTP data size to 100 MB
100
+ max_http_data_size = 100 * 1024 * 1024
101
+
102
+ socket = SocketManager(
29
103
  app,
30
- cors_allowed_origins="*",
31
- async_mode="gevent",
32
- max_http_buffer_size=1000000 * 100,
104
+ cors_allowed_origins=[],
105
+ async_mode="asgi",
106
+ max_http_buffer_size=max_http_data_size,
33
107
  )
34
108
 
109
+ """
110
+ -------------------------------------------------------------------------------
111
+ HTTP HANDLERS
112
+ -------------------------------------------------------------------------------
113
+ """
114
+
35
115
 
36
- def inject_html_tags():
116
+ def get_html_template():
37
117
  PLACEHOLDER = "<!-- TAG INJECTION PLACEHOLDER -->"
38
118
 
39
119
  default_url = "https://github.com/Chainlit/chainlit"
@@ -47,226 +127,226 @@ def inject_html_tags():
47
127
  <meta property="og:image" content="https://chainlit-cloud.s3.eu-west-3.amazonaws.com/logo/chainlit_banner.png">
48
128
  <meta property="og:url" content="{url}">"""
49
129
 
50
- orig_index_html_file_path = os.path.join(app.static_folder, "index.html")
51
- injected_index_html_file_path = os.path.join(app.static_folder, "_index.html")
130
+ index_html_file_path = os.path.join(build_dir, "index.html")
52
131
 
53
- with open(orig_index_html_file_path, "r", encoding="utf-8") as f:
132
+ with open(index_html_file_path, "r", encoding="utf-8") as f:
54
133
  content = f.read()
55
- content = content.replace(PLACEHOLDER, tags)
56
-
57
- with open(injected_index_html_file_path, "w", encoding="utf-8") as f:
58
- f.write(content)
134
+ content = content.replace(PLACEHOLDER, tags)
135
+ return content
59
136
 
60
137
 
61
- inject_html_tags()
138
+ html_template = get_html_template()
62
139
 
63
140
 
64
- @app.route("/", defaults={"path": ""})
65
- @app.route("/<path:path>")
66
- def serve(path):
67
- """Serve the UI."""
68
- if path != "" and os.path.exists(app.static_folder + "/" + path):
69
- return send_from_directory(app.static_folder, path)
70
- else:
71
- return send_from_directory(app.static_folder, "_index.html")
72
-
73
-
74
- @app.route("/completion", methods=["POST"])
75
- @trace
76
- def completion():
141
+ @app.post("/completion")
142
+ async def completion(completion: CompletionRequest):
77
143
  """Handle a completion request from the prompt playground."""
78
144
 
79
145
  import openai
80
146
 
81
- data = request.json
82
- llm_settings = data["settings"]
83
- user_env = data.get("userEnv", {})
147
+ trace_event("completion")
84
148
 
85
- api_key = user_env.get("OPENAI_API_KEY", os.environ.get("OPENAI_API_KEY"))
149
+ api_key = completion.userEnv.get("OPENAI_API_KEY", os.environ.get("OPENAI_API_KEY"))
86
150
 
87
- model_name = llm_settings.pop("model_name", None)
88
- stop = llm_settings.pop("stop", None)
151
+ model_name = completion.settings.model_name
152
+ stop = completion.settings.stop
89
153
  # OpenAI doesn't support an empty stop array, clear it
90
154
  if isinstance(stop, list) and len(stop) == 0:
91
155
  stop = None
92
156
 
93
157
  if model_name in ["gpt-3.5-turbo", "gpt-4"]:
94
- response = openai.ChatCompletion.create(
158
+ response = await openai.ChatCompletion.acreate(
95
159
  api_key=api_key,
96
160
  model=model_name,
97
- messages=[{"role": "user", "content": data["prompt"]}],
161
+ messages=[{"role": "user", "content": completion.prompt}],
98
162
  stop=stop,
99
- **llm_settings,
163
+ **completion.settings.to_settings_dict(),
100
164
  )
101
- return response["choices"][0]["message"]["content"]
165
+ return PlainTextResponse(content=response["choices"][0]["message"]["content"])
102
166
  else:
103
- response = openai.Completion.create(
167
+ response = await openai.Completion.acreate(
104
168
  api_key=api_key,
105
169
  model=model_name,
106
- prompt=data["prompt"],
170
+ prompt=completion.prompt,
107
171
  stop=stop,
108
- **llm_settings,
172
+ **completion.settings.to_settings_dict(),
109
173
  )
110
- return response["choices"][0]["text"]
174
+ return PlainTextResponse(content=response["choices"][0]["text"])
111
175
 
112
176
 
113
- @app.route("/project/settings", methods=["GET"])
114
- def project_settings():
177
+ @app.get("/project/settings")
178
+ async def project_settings():
115
179
  """Return project settings. This is called by the UI before the establishing the websocket connection."""
116
- return {
117
- "public": config.public,
118
- "projectId": config.project_id,
119
- "chainlitServer": config.chainlit_server,
120
- "userEnv": config.user_env,
121
- "hideCot": config.hide_cot,
122
- "chainlitMd": get_markdown_str(config.root),
123
- "prod": bool(config.chainlit_prod_url),
124
- "appTitle": config.chatbot_name,
125
- "github": config.github,
126
- }
127
-
128
-
129
- @socketio.on("connect")
130
- def connect():
131
- """Handle socket connection."""
132
- session_id = request.sid
133
- client = None
134
- user_env = {}
180
+ return JSONResponse(
181
+ content={
182
+ "public": config.public,
183
+ "projectId": config.project_id,
184
+ "chainlitServer": config.chainlit_server,
185
+ "userEnv": config.user_env,
186
+ "hideCot": config.hide_cot,
187
+ "chainlitMd": get_markdown_str(config.root),
188
+ "prod": bool(config.chainlit_prod_url),
189
+ "appTitle": config.chatbot_name,
190
+ "github": config.github,
191
+ }
192
+ )
193
+
194
+
195
+ @app.get("/{path:path}")
196
+ async def serve(path: str):
197
+ """Serve the UI."""
198
+ path_to_file = os.path.join(build_dir, path)
199
+ if path != "" and os.path.exists(path_to_file):
200
+ return FileResponse(path_to_file)
201
+ else:
202
+ return HTMLResponse(content=html_template, status_code=200)
203
+
204
+
205
+ """
206
+ -------------------------------------------------------------------------------
207
+ WEBSOCKET HANDLERS
208
+ -------------------------------------------------------------------------------
209
+ """
135
210
 
136
- if config.user_env:
137
- # Check if requested user environment variables are provided
138
- if request.headers.get("user-env"):
139
- user_env = json.loads(request.headers.get("user-env"))
140
- for key in config.user_env:
141
- if key not in user_env:
142
- trace_event("missing_user_env")
143
- raise ConnectionRefusedError(
144
- "Missing user environment variable: " + key
145
- )
146
211
 
147
- access_token = request.headers.get("Authorization")
148
- if not config.public and not access_token:
212
+ def need_session(id: str):
213
+ """Return the session with the given id."""
214
+
215
+ session = sessions.get(id)
216
+ if not session:
217
+ raise ValueError("Session not found")
218
+ return session
219
+
220
+
221
+ @socket.on("connect")
222
+ async def connect(sid, environ):
223
+ user_env = environ.get("HTTP_USER_ENV")
224
+ authorization = environ.get("HTTP_AUTHORIZATION")
225
+ cloud_client = None
226
+
227
+ # Check decorated functions
228
+ if not config.lc_factory and not config.on_message and not config.on_chat_start:
229
+ logger.error(
230
+ "Module should at least expose one of @langchain_factory, @on_message or @on_chat_start function"
231
+ )
232
+ return False
233
+
234
+ # Check authorization
235
+ if not config.public and not authorization:
149
236
  # Refuse connection if the app is private and no access token is provided
150
237
  trace_event("no_access_token")
151
- raise ConnectionRefusedError("No access token provided")
152
- elif access_token and config.project_id:
238
+ logger.error("No access token provided")
239
+ return False
240
+ elif authorization and config.project_id:
153
241
  # Create the cloud client
154
- client = CloudClient(
242
+ cloud_client = CloudClient(
155
243
  project_id=config.project_id,
156
- session_id=session_id,
157
- access_token=access_token,
158
- url=config.chainlit_server,
244
+ session_id=sid,
245
+ access_token=authorization,
159
246
  )
160
- is_project_member = client.is_project_member()
247
+ is_project_member = await cloud_client.is_project_member()
161
248
  if not is_project_member:
162
- raise ConnectionRefusedError("You are not a member of this project")
249
+ logger.error("You are not a member of this project")
250
+ return False
251
+
252
+ # Check user env
253
+ if config.user_env:
254
+ # Check if requested user environment variables are provided
255
+ if user_env:
256
+ user_env = json.loads(user_env)
257
+ for key in config.user_env:
258
+ if key not in user_env:
259
+ trace_event("missing_user_env")
260
+ logger.error("Missing user environment variable: " + key)
261
+ return False
262
+ else:
263
+ logger.error("Missing user environment variables")
264
+ return False
265
+
266
+ # Create the session
163
267
 
164
268
  # Function to send a message to this particular session
165
- def _emit(event, data):
166
- socketio.emit(event, data, to=session_id)
269
+ def emit_fn(event, data):
270
+ if sid in sessions:
271
+ if sessions[sid]["should_stop"]:
272
+ sessions[sid]["should_stop"] = False
273
+ raise InterruptedError("Task stopped by user")
274
+ return socket.emit(event, data, to=sid)
167
275
 
168
276
  # Function to ask the user a question
169
- def _ask_user(data, timeout):
170
- return socketio.call("ask", data, timeout=timeout, to=session_id)
277
+ def ask_user_fn(data, timeout):
278
+ if sessions[sid]["should_stop"]:
279
+ sessions[sid]["should_stop"] = False
280
+ raise InterruptedError("Task stopped by user")
281
+ return socket.call("ask", data, timeout=timeout, to=sid)
171
282
 
172
283
  session = {
173
- "id": session_id,
174
- "emit": _emit,
175
- "ask_user": _ask_user,
176
- "client": client,
284
+ "id": sid,
285
+ "emit": emit_fn,
286
+ "ask_user": ask_user_fn,
287
+ "client": cloud_client,
177
288
  "user_env": user_env,
289
+ "running_sync": False,
290
+ "should_stop": False,
178
291
  } # type: Session
179
- sessions[session_id] = session
180
292
 
181
- if not config.lc_factory and not config.on_message and not config.on_chat_start:
182
- raise ValueError(
183
- "Module should at least expose one of @langchain_factory, @on_message or @on_chat_start function"
184
- )
293
+ sessions[sid] = session
185
294
 
186
- if config.lc_factory:
295
+ trace_event("connection_successful")
296
+ return True
187
297
 
188
- def instantiate_agent(session):
189
- """Instantiate the langchain agent and store it in the session."""
190
- __chainlit_sdk__ = Chainlit(session)
191
- agent = config.lc_factory()
192
- session["agent"] = agent
193
298
 
194
- # Instantiate the agent in a background task since the connection is not yet accepted
195
- task = socketio.start_background_task(instantiate_agent, session)
196
- session["task"] = task
299
+ @socket.on("connection_successful")
300
+ async def connection_successful(sid):
301
+ session = need_session(sid)
302
+ __chainlit_emitter__ = ChainlitEmitter(session)
303
+ if config.lc_factory:
304
+ """Instantiate the langchain agent and store it in the session."""
305
+ agent = await config.lc_factory(__chainlit_emitter__=__chainlit_emitter__)
306
+ session["agent"] = agent
197
307
 
198
308
  if config.on_chat_start:
309
+ """Call the on_chat_start function provided by the developer."""
310
+ await config.on_chat_start(__chainlit_emitter__=__chainlit_emitter__)
199
311
 
200
- def _on_chat_start(session):
201
- """Call the on_chat_start function provided by the developer."""
202
- __chainlit_sdk__ = Chainlit(session)
203
- config.on_chat_start()
204
312
 
205
- # Send the ask in a backgroudn task since the connection is not yet accepted
206
- task = socketio.start_background_task(_on_chat_start, session)
207
- session["task"] = task
208
-
209
- trace_event("connection_successful")
210
-
211
-
212
- @socketio.on("disconnect")
213
- def disconnect():
214
- """Handle socket disconnection."""
215
-
216
- if request.sid in sessions:
313
+ @socket.on("disconnect")
314
+ async def disconnect(sid):
315
+ if sid in sessions:
217
316
  # Clean up the session
218
- session = sessions.pop(request.sid)
219
- task = session.get("task")
220
- if task:
221
- # If a background task is running, kill it
222
- task.kill()
317
+ sessions.pop(sid)
223
318
 
224
- if request.sid in user_sessions:
319
+ if sid in user_sessions:
225
320
  # Clean up the user session
226
- user_sessions.pop(request.sid)
321
+ user_sessions.pop(sid)
227
322
 
228
323
 
229
- @socketio.on("stop")
230
- def stop():
231
- """Handle a stop request from the client."""
232
- trace_event("stop_task")
233
- session = sessions.get(request.sid)
234
- if not session:
235
- return
324
+ @socket.on("stop")
325
+ async def stop(sid):
326
+ if sid in sessions:
327
+ trace_event("stop_task")
328
+ session = sessions[sid]
236
329
 
237
- task = session.get("task")
330
+ __chainlit_emitter__ = ChainlitEmitter(session)
238
331
 
239
- if task:
240
- task.kill()
241
- session["task"] = None
332
+ await Message(author="System", content="Task stopped by the user.").send()
242
333
 
243
- __chainlit_sdk__ = Chainlit(session)
334
+ session["should_stop"] = True
244
335
 
245
336
  if config.on_stop:
246
- config.on_stop()
337
+ await config.on_stop()
247
338
 
248
- Message(author="System", content="Conversation stopped by the user.").send()
249
339
 
250
-
251
- def need_session(id: str):
252
- """Return the session with the given id."""
253
-
254
- session = sessions.get(id)
255
- if not session:
256
- raise ValueError("Session not found")
257
- return session
258
-
259
-
260
- def process_message(session: Session, author: str, input_str: str):
340
+ async def process_message(session: Session, author: str, input_str: str):
261
341
  """Process a message from the user."""
262
342
 
263
- __chainlit_sdk__ = Chainlit(session)
264
343
  try:
265
- __chainlit_sdk__.task_start()
344
+ __chainlit_emitter__ = ChainlitEmitter(session)
345
+ await __chainlit_emitter__.task_start()
266
346
 
267
347
  if session["client"]:
268
348
  # If cloud is enabled, persist the message
269
- session["client"].create_message(
349
+ await session["client"].create_message(
270
350
  {
271
351
  "author": author,
272
352
  "content": input_str,
@@ -276,18 +356,28 @@ def process_message(session: Session, author: str, input_str: str):
276
356
 
277
357
  langchain_agent = session.get("agent")
278
358
  if langchain_agent:
359
+ from chainlit.lc.agent import run_langchain_agent
360
+
279
361
  # If a langchain agent is available, run it
280
362
  if config.lc_run:
281
363
  # If the developer provided a custom run function, use it
282
- config.lc_run(langchain_agent, input_str)
364
+ await config.lc_run(
365
+ langchain_agent,
366
+ input_str,
367
+ __chainlit_emitter__=__chainlit_emitter__,
368
+ )
283
369
  return
284
370
  else:
285
371
  # Otherwise, use the default run function
286
- raw_res, output_key = run_langchain_agent(langchain_agent, input_str)
372
+ raw_res, output_key = await run_langchain_agent(
373
+ langchain_agent, input_str, use_async=config.lc_agent_is_async
374
+ )
287
375
 
288
376
  if config.lc_postprocess:
289
377
  # If the developer provided a custom postprocess function, use it
290
- config.lc_postprocess(raw_res)
378
+ await config.lc_postprocess(
379
+ raw_res, __chainlit_emitter__=__chainlit_emitter__
380
+ )
291
381
  return
292
382
  elif output_key is not None:
293
383
  # Use the output key if provided
@@ -296,54 +386,49 @@ def process_message(session: Session, author: str, input_str: str):
296
386
  # Otherwise, use the raw response
297
387
  res = raw_res
298
388
  # Finally, send the response to the user
299
- Message(author=config.chatbot_name, content=res).send()
389
+ await Message(author=config.chatbot_name, content=res).send()
300
390
 
301
391
  elif config.on_message:
302
392
  # If no langchain agent is available, call the on_message function provided by the developer
303
- config.on_message(input_str)
393
+ await config.on_message(
394
+ input_str, __chainlit_emitter__=__chainlit_emitter__
395
+ )
396
+ except InterruptedError:
397
+ pass
304
398
  except Exception as e:
305
399
  logger.exception(e)
306
- ErrorMessage(author="Error", content=str(e)).send()
400
+ await ErrorMessage(author="Error", content=str(e)).send()
307
401
  finally:
308
- __chainlit_sdk__.task_end()
309
-
310
-
311
- @socketio.on("message")
312
- def on_message(body):
313
- """Handle a message from the UI."""
402
+ await __chainlit_emitter__.task_end()
314
403
 
315
- session_id = request.sid
316
- session = need_session(session_id)
317
404
 
318
- input_str = body["content"].strip()
319
- author = body["author"]
405
+ @socket.on("ui_message")
406
+ async def message(sid, data):
407
+ """Handle a message sent by the User."""
408
+ session = need_session(sid)
409
+ session["should_stop"] = False
320
410
 
321
- task = socketio.start_background_task(process_message, session, author, input_str)
322
- session["task"] = task
323
- task.join()
324
- session["task"] = None
411
+ input_str = data["content"].strip()
412
+ author = data["author"]
325
413
 
326
- return {"success": True}
414
+ await process_message(session, author, input_str)
327
415
 
328
416
 
329
- def process_action(session: Session, action: Action):
330
- __chainlit_sdk__ = Chainlit(session)
417
+ async def process_action(session: Session, action: Action):
418
+ __chainlit_emitter__ = ChainlitEmitter(session)
331
419
  callback = config.action_callbacks.get(action.name)
332
420
  if callback:
333
- callback(action)
421
+ await callback(action, __chainlit_emitter__=__chainlit_emitter__)
334
422
  else:
335
423
  logger.warning("No callback found for action %s", action.name)
336
424
 
337
425
 
338
- @socketio.on("call_action")
339
- def call_action(action):
426
+ @socket.on("action_call")
427
+ async def call_action(sid, action):
340
428
  """Handle an action call from the UI."""
341
- session_id = request.sid
342
- session = need_session(session_id)
429
+ session = need_session(sid)
343
430
 
431
+ __chainlit_emitter__ = ChainlitEmitter(session)
344
432
  action = Action(**action)
345
433
 
346
- task = socketio.start_background_task(process_action, session, action)
347
- session["task"] = task
348
- task.join()
349
- session["task"] = None
434
+ await process_action(session, action)
chainlit/session.py CHANGED
@@ -16,8 +16,10 @@ class Session(TypedDict):
16
16
  user_env: Dict[str, str]
17
17
  # Optional langchain agent
18
18
  agent: Any
19
- # Potential background task running
20
- task: Optional[Any]
19
+ # If the session is currently running a sync task
20
+ running_sync: bool
21
+ # Whether the current task should be stopped
22
+ should_stop: bool
21
23
  # Optional client to persist messages and files
22
24
  client: Optional[BaseClient]
23
25
 
chainlit/sync.py ADDED
@@ -0,0 +1,37 @@
1
+ from typing import Any, Callable
2
+
3
+ import asyncio
4
+ from syncer import sync
5
+ from asyncer import asyncify
6
+
7
+ from chainlit.emitter import get_emitter
8
+
9
+
10
+ def make_async(function: Callable):
11
+ emitter = get_emitter()
12
+ if not emitter:
13
+ raise RuntimeError(
14
+ "Emitter not found, please call make_async in a Chainlit context."
15
+ )
16
+
17
+ def wrapper(*args, **kwargs):
18
+ emitter.session["running_sync"] = True
19
+ __chainlit_emitter__ = emitter
20
+ res = function(*args, **kwargs)
21
+ emitter.session["running_sync"] = False
22
+ return res
23
+
24
+ return asyncify(wrapper, cancellable=True)
25
+
26
+
27
+ def run_sync(co: Any):
28
+ try:
29
+ loop = asyncio.get_event_loop()
30
+ except RuntimeError as e:
31
+ if "There is no current event loop" in str(e):
32
+ loop = None
33
+
34
+ if loop is None or not loop.is_running():
35
+ loop = asyncio.new_event_loop()
36
+ asyncio.set_event_loop(loop)
37
+ return sync(co)