chainlit 1.1.300rc4__py3-none-any.whl → 1.1.301__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic. Click here for more details.

Files changed (45) hide show
  1. chainlit/__init__.py +3 -1
  2. chainlit/cli/__init__.py +53 -6
  3. chainlit/config.py +4 -0
  4. chainlit/context.py +9 -0
  5. chainlit/copilot/dist/index.js +180 -180
  6. chainlit/data/__init__.py +6 -3
  7. chainlit/data/sql_alchemy.py +3 -3
  8. chainlit/element.py +33 -9
  9. chainlit/emitter.py +4 -4
  10. chainlit/frontend/dist/assets/{DailyMotion-1a2b7d60.js → DailyMotion-578b63e6.js} +1 -1
  11. chainlit/frontend/dist/assets/{Facebook-8422f48c.js → Facebook-b825e5bb.js} +1 -1
  12. chainlit/frontend/dist/assets/{FilePlayer-a0a41349.js → FilePlayer-bcba3b4e.js} +1 -1
  13. chainlit/frontend/dist/assets/{Kaltura-aa7990f2.js → Kaltura-fc1c9497.js} +1 -1
  14. chainlit/frontend/dist/assets/{Mixcloud-1647b3e4.js → Mixcloud-4cfb2724.js} +1 -1
  15. chainlit/frontend/dist/assets/{Mux-7e57be81.js → Mux-aa92055c.js} +1 -1
  16. chainlit/frontend/dist/assets/{Preview-cb89b2c6.js → Preview-9f55905a.js} +1 -1
  17. chainlit/frontend/dist/assets/{SoundCloud-c0d86d55.js → SoundCloud-f991fe03.js} +1 -1
  18. chainlit/frontend/dist/assets/{Streamable-57c43c18.js → Streamable-53128f49.js} +1 -1
  19. chainlit/frontend/dist/assets/{Twitch-bed3f21d.js → Twitch-fce8b9f5.js} +1 -1
  20. chainlit/frontend/dist/assets/{Vidyard-4dd76e44.js → Vidyard-e35c6102.js} +1 -1
  21. chainlit/frontend/dist/assets/{Vimeo-93bb5ae2.js → Vimeo-fff35f8e.js} +1 -1
  22. chainlit/frontend/dist/assets/{Wistia-97199246.js → Wistia-ec07dc64.js} +1 -1
  23. chainlit/frontend/dist/assets/{YouTube-fe1a7afe.js → YouTube-ad068e2a.js} +1 -1
  24. chainlit/frontend/dist/assets/index-d40d41cc.js +727 -0
  25. chainlit/frontend/dist/assets/{react-plotly-b416b8f9.js → react-plotly-b2c6442b.js} +1 -1
  26. chainlit/frontend/dist/index.html +1 -2
  27. chainlit/message.py +13 -8
  28. chainlit/oauth_providers.py +96 -4
  29. chainlit/server.py +182 -57
  30. chainlit/slack/app.py +2 -2
  31. chainlit/socket.py +24 -21
  32. chainlit/step.py +12 -3
  33. chainlit/teams/__init__.py +6 -0
  34. chainlit/teams/app.py +332 -0
  35. chainlit/translations/en-US.json +1 -1
  36. chainlit/types.py +7 -17
  37. chainlit/user.py +9 -1
  38. chainlit/utils.py +43 -0
  39. {chainlit-1.1.300rc4.dist-info → chainlit-1.1.301.dist-info}/METADATA +2 -2
  40. chainlit-1.1.301.dist-info/RECORD +79 -0
  41. chainlit/cli/utils.py +0 -24
  42. chainlit/frontend/dist/assets/index-919bea8f.js +0 -727
  43. chainlit-1.1.300rc4.dist-info/RECORD +0 -78
  44. {chainlit-1.1.300rc4.dist-info → chainlit-1.1.301.dist-info}/WHEEL +0 -0
  45. {chainlit-1.1.300rc4.dist-info → chainlit-1.1.301.dist-info}/entry_points.txt +0 -0
chainlit/server.py CHANGED
@@ -18,6 +18,7 @@ import webbrowser
18
18
  from contextlib import asynccontextmanager
19
19
  from pathlib import Path
20
20
 
21
+ import socketio
21
22
  from chainlit.auth import create_jwt, get_configuration, get_current_user
22
23
  from chainlit.config import (
23
24
  APP_ROOT,
@@ -33,19 +34,19 @@ from chainlit.data import get_data_layer
33
34
  from chainlit.data.acl import is_thread_author
34
35
  from chainlit.logger import logger
35
36
  from chainlit.markdown import get_markdown_str
36
- from chainlit.telemetry import trace_event
37
37
  from chainlit.types import (
38
38
  DeleteFeedbackRequest,
39
39
  DeleteThreadRequest,
40
- GenerationRequest,
41
40
  GetThreadsRequest,
42
41
  Theme,
43
42
  UpdateFeedbackRequest,
44
43
  )
45
44
  from chainlit.user import PersistedUser, User
46
45
  from fastapi import (
46
+ APIRouter,
47
47
  Depends,
48
48
  FastAPI,
49
+ Form,
49
50
  HTTPException,
50
51
  Query,
51
52
  Request,
@@ -56,12 +57,14 @@ from fastapi import (
56
57
  from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
57
58
  from fastapi.security import OAuth2PasswordRequestForm
58
59
  from fastapi.staticfiles import StaticFiles
59
- from fastapi_socketio import SocketManager
60
60
  from starlette.datastructures import URL
61
61
  from starlette.middleware.cors import CORSMiddleware
62
62
  from typing_extensions import Annotated
63
63
  from watchfiles import awatch
64
64
 
65
+ ROOT_PATH = os.environ.get("CHAINLIT_ROOT_PATH", "")
66
+ IS_SUBMOUNT = os.environ.get("CHAINLIT_SUBMOUNT", "") == "true"
67
+
65
68
 
66
69
  @asynccontextmanager
67
70
  async def lifespan(app: FastAPI):
@@ -69,9 +72,9 @@ async def lifespan(app: FastAPI):
69
72
  port = config.run.port
70
73
 
71
74
  if host == DEFAULT_HOST:
72
- url = f"http://localhost:{port}"
75
+ url = f"http://localhost:{port}{ROOT_PATH}"
73
76
  else:
74
- url = f"http://{host}:{port}"
77
+ url = f"http://{host}:{port}{ROOT_PATH}"
75
78
 
76
79
  logger.info(f"Your app is available at {url}")
77
80
 
@@ -112,7 +115,7 @@ async def lifespan(app: FastAPI):
112
115
  logger.error(f"Error reloading module: {e}")
113
116
 
114
117
  await asyncio.sleep(1)
115
- await socket.emit("reload", {})
118
+ await sio.emit("reload", {})
116
119
 
117
120
  break
118
121
 
@@ -166,12 +169,36 @@ def get_build_dir(local_target: str, packaged_target: str):
166
169
  build_dir = get_build_dir("frontend", "frontend")
167
170
  copilot_build_dir = get_build_dir(os.path.join("libs", "copilot"), "copilot")
168
171
 
169
-
170
172
  app = FastAPI(lifespan=lifespan)
171
173
 
172
- app.mount("/public", StaticFiles(directory="public", check_dir=False), name="public")
174
+ sio = socketio.AsyncServer(
175
+ cors_allowed_origins=[] if IS_SUBMOUNT else "*", async_mode="asgi"
176
+ )
177
+
178
+ combined_asgi_app = socketio.ASGIApp(
179
+ sio,
180
+ app,
181
+ socketio_path=f"{ROOT_PATH}/ws/socket.io" if ROOT_PATH else "/ws/socket.io",
182
+ )
183
+
184
+ app.add_middleware(
185
+ CORSMiddleware,
186
+ allow_origins=config.project.allow_origins,
187
+ allow_credentials=True,
188
+ allow_methods=["*"],
189
+ allow_headers=["*"],
190
+ )
191
+
192
+ router = APIRouter(prefix=ROOT_PATH)
193
+
173
194
  app.mount(
174
- "/assets",
195
+ f"{ROOT_PATH}/public",
196
+ StaticFiles(directory="public", check_dir=False),
197
+ name="public",
198
+ )
199
+
200
+ app.mount(
201
+ f"{ROOT_PATH}/assets",
175
202
  StaticFiles(
176
203
  packages=[("chainlit", os.path.join(build_dir, "assets"))],
177
204
  follow_symlink=config.project.follow_symlink,
@@ -180,7 +207,7 @@ app.mount(
180
207
  )
181
208
 
182
209
  app.mount(
183
- "/copilot",
210
+ f"{ROOT_PATH}/copilot",
184
211
  StaticFiles(
185
212
  packages=[("chainlit", copilot_build_dir)],
186
213
  follow_symlink=config.project.follow_symlink,
@@ -189,22 +216,6 @@ app.mount(
189
216
  )
190
217
 
191
218
 
192
- app.add_middleware(
193
- CORSMiddleware,
194
- allow_origins=config.project.allow_origins,
195
- allow_credentials=True,
196
- allow_methods=["*"],
197
- allow_headers=["*"],
198
- )
199
-
200
- socket = SocketManager(
201
- app,
202
- cors_allowed_origins=[],
203
- async_mode="asgi",
204
- socketio_path="/ws/socket.io",
205
- )
206
-
207
-
208
219
  # -------------------------------------------------------------------------------
209
220
  # SLACK HANDLER
210
221
  # -------------------------------------------------------------------------------
@@ -212,11 +223,28 @@ socket = SocketManager(
212
223
  if os.environ.get("SLACK_BOT_TOKEN") and os.environ.get("SLACK_SIGNING_SECRET"):
213
224
  from chainlit.slack.app import slack_app_handler
214
225
 
215
- @app.post("/slack/events")
216
- async def endpoint(req: Request):
226
+ @router.post("/slack/events")
227
+ async def slack_endpoint(req: Request):
217
228
  return await slack_app_handler.handle(req)
218
229
 
219
230
 
231
+ # -------------------------------------------------------------------------------
232
+ # TEAMS HANDLER
233
+ # -------------------------------------------------------------------------------
234
+
235
+ if os.environ.get("TEAMS_APP_ID") and os.environ.get("TEAMS_APP_PASSWORD"):
236
+ from botbuilder.schema import Activity
237
+ from chainlit.teams.app import adapter, bot
238
+
239
+ @router.post("/teams/events")
240
+ async def teams_endpoint(req: Request):
241
+ body = await req.json()
242
+ activity = Activity().deserialize(body)
243
+ auth_header = req.headers.get("Authorization", "")
244
+ response = await adapter.process_activity(activity, auth_header, bot.on_turn)
245
+ return response
246
+
247
+
220
248
  # -------------------------------------------------------------------------------
221
249
  # HTTP HANDLERS
222
250
  # -------------------------------------------------------------------------------
@@ -238,14 +266,17 @@ def get_html_template():
238
266
  )
239
267
  url = config.ui.github or default_url
240
268
  meta_image_url = config.ui.custom_meta_image_url or default_meta_image_url
269
+ favicon_path = ROOT_PATH + "/favicon" if ROOT_PATH else "/favicon"
241
270
 
242
271
  tags = f"""<title>{config.ui.name}</title>
272
+ <link rel="icon" href="{favicon_path}" />
243
273
  <meta name="description" content="{config.ui.description}">
244
274
  <meta property="og:type" content="website">
245
275
  <meta property="og:title" content="{config.ui.name}">
246
276
  <meta property="og:description" content="{config.ui.description}">
247
277
  <meta property="og:image" content="{meta_image_url}">
248
- <meta property="og:url" content="{url}">"""
278
+ <meta property="og:url" content="{url}">
279
+ <meta property="og:root_path" content="{ROOT_PATH}">"""
249
280
 
250
281
  js = f"""<script>{f"window.theme = {json.dumps(config.ui.theme.to_dict())}; " if config.ui.theme else ""}</script>"""
251
282
 
@@ -275,6 +306,9 @@ def get_html_template():
275
306
  content = replace_between_tags(
276
307
  content, "<!-- FONT START -->", "<!-- FONT END -->", font
277
308
  )
309
+ if ROOT_PATH:
310
+ content = content.replace('href="/', f'href="{ROOT_PATH}/')
311
+ content = content.replace('src="/', f'src="{ROOT_PATH}/')
278
312
  return content
279
313
 
280
314
 
@@ -284,6 +318,7 @@ def get_user_facing_url(url: URL):
284
318
  Handles deployment with proxies (like cloud run).
285
319
  """
286
320
 
321
+ ROOT_PATH = os.environ.get("CHAINLIT_ROOT_PATH", "")
287
322
  chainlit_url = os.environ.get("CHAINLIT_URL")
288
323
 
289
324
  # No config, we keep the URL as is
@@ -299,15 +334,26 @@ def get_user_facing_url(url: URL):
299
334
  if config_url.path.endswith("/"):
300
335
  config_url = config_url.replace(path=config_url.path[:-1])
301
336
 
337
+ # Add ROOT_PATH to the final URL if it exists
338
+ if ROOT_PATH:
339
+ # Ensure ROOT_PATH starts with a slash
340
+ if not ROOT_PATH.startswith("/"):
341
+ ROOT_PATH = "/" + ROOT_PATH
342
+ # Ensure ROOT_PATH does not end with a slash
343
+ if ROOT_PATH.endswith("/"):
344
+ ROOT_PATH = ROOT_PATH[:-1]
345
+
346
+ return config_url.__str__() + ROOT_PATH + url.path
347
+
302
348
  return config_url.__str__() + url.path
303
349
 
304
350
 
305
- @app.get("/auth/config")
351
+ @router.get("/auth/config")
306
352
  async def auth(request: Request):
307
353
  return get_configuration()
308
354
 
309
355
 
310
- @app.post("/login")
356
+ @router.post("/login")
311
357
  async def login(form_data: OAuth2PasswordRequestForm = Depends()):
312
358
  if not config.code.password_auth_callback:
313
359
  raise HTTPException(
@@ -336,14 +382,14 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends()):
336
382
  }
337
383
 
338
384
 
339
- @app.post("/logout")
385
+ @router.post("/logout")
340
386
  async def logout(request: Request, response: Response):
341
387
  if config.code.on_logout:
342
388
  return await config.code.on_logout(request, response)
343
389
  return {"success": True}
344
390
 
345
391
 
346
- @app.post("/auth/header")
392
+ @router.post("/auth/header")
347
393
  async def header_auth(request: Request):
348
394
  if not config.code.header_auth_callback:
349
395
  raise HTTPException(
@@ -372,7 +418,7 @@ async def header_auth(request: Request):
372
418
  }
373
419
 
374
420
 
375
- @app.get("/auth/oauth/{provider_id}")
421
+ @router.get("/auth/oauth/{provider_id}")
376
422
  async def oauth_login(provider_id: str, request: Request):
377
423
  if config.code.oauth_callback is None:
378
424
  raise HTTPException(
@@ -413,7 +459,7 @@ async def oauth_login(provider_id: str, request: Request):
413
459
  return response
414
460
 
415
461
 
416
- @app.get("/auth/oauth/{provider_id}/callback")
462
+ @router.get("/auth/oauth/{provider_id}/callback")
417
463
  async def oauth_callback(
418
464
  provider_id: str,
419
465
  request: Request,
@@ -497,7 +543,85 @@ async def oauth_callback(
497
543
  return response
498
544
 
499
545
 
500
- @app.get("/project/translations")
546
+ # specific route for azure ad hybrid flow
547
+ @router.post("/auth/oauth/azure-ad-hybrid/callback")
548
+ async def oauth_azure_hf_callback(
549
+ request: Request,
550
+ error: Optional[str] = None,
551
+ code: Annotated[Optional[str], Form()] = None,
552
+ id_token: Annotated[Optional[str], Form()] = None,
553
+ ):
554
+ provider_id = "azure-ad-hybrid"
555
+ if config.code.oauth_callback is None:
556
+ raise HTTPException(
557
+ status_code=status.HTTP_400_BAD_REQUEST,
558
+ detail="No oauth_callback defined",
559
+ )
560
+
561
+ provider = get_oauth_provider(provider_id)
562
+ if not provider:
563
+ raise HTTPException(
564
+ status_code=status.HTTP_404_NOT_FOUND,
565
+ detail=f"Provider {provider_id} not found",
566
+ )
567
+
568
+ if error:
569
+ params = urllib.parse.urlencode(
570
+ {
571
+ "error": error,
572
+ }
573
+ )
574
+ response = RedirectResponse(
575
+ # FIXME: redirect to the right frontend base url to improve the dev environment
576
+ url=f"/login?{params}",
577
+ )
578
+ return response
579
+
580
+ if not code:
581
+ raise HTTPException(
582
+ status_code=status.HTTP_400_BAD_REQUEST,
583
+ detail="Missing code",
584
+ )
585
+
586
+ url = get_user_facing_url(request.url)
587
+ token = await provider.get_token(code, url)
588
+
589
+ (raw_user_data, default_user) = await provider.get_user_info(token)
590
+
591
+ user = await config.code.oauth_callback(
592
+ provider_id, token, raw_user_data, default_user, id_token
593
+ )
594
+
595
+ if not user:
596
+ raise HTTPException(
597
+ status_code=status.HTTP_401_UNAUTHORIZED,
598
+ detail="Unauthorized",
599
+ )
600
+
601
+ access_token = create_jwt(user)
602
+
603
+ if data_layer := get_data_layer():
604
+ try:
605
+ await data_layer.create_user(user)
606
+ except Exception as e:
607
+ logger.error(f"Error creating user: {e}")
608
+
609
+ params = urllib.parse.urlencode(
610
+ {
611
+ "access_token": access_token,
612
+ "token_type": "bearer",
613
+ }
614
+ )
615
+ response = RedirectResponse(
616
+ # FIXME: redirect to the right frontend base url to improve the dev environment
617
+ url=f"/login/callback?{params}",
618
+ status_code=302,
619
+ )
620
+ response.delete_cookie("oauth_state")
621
+ return response
622
+
623
+
624
+ @router.get("/project/translations")
501
625
  async def project_translations(
502
626
  language: str = Query(default="en-US", description="Language code"),
503
627
  ):
@@ -513,7 +637,7 @@ async def project_translations(
513
637
  )
514
638
 
515
639
 
516
- @app.get("/project/settings")
640
+ @router.get("/project/settings")
517
641
  async def project_settings(
518
642
  current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
519
643
  language: str = Query(default="en-US", description="Language code"),
@@ -559,7 +683,7 @@ async def project_settings(
559
683
  )
560
684
 
561
685
 
562
- @app.put("/feedback")
686
+ @router.put("/feedback")
563
687
  async def update_feedback(
564
688
  request: Request,
565
689
  update: UpdateFeedbackRequest,
@@ -578,7 +702,7 @@ async def update_feedback(
578
702
  return JSONResponse(content={"success": True, "feedbackId": feedback_id})
579
703
 
580
704
 
581
- @app.delete("/feedback")
705
+ @router.delete("/feedback")
582
706
  async def delete_feedback(
583
707
  request: Request,
584
708
  payload: DeleteFeedbackRequest,
@@ -597,7 +721,7 @@ async def delete_feedback(
597
721
  return JSONResponse(content={"success": True})
598
722
 
599
723
 
600
- @app.post("/project/threads")
724
+ @router.post("/project/threads")
601
725
  async def get_user_threads(
602
726
  request: Request,
603
727
  payload: GetThreadsRequest,
@@ -622,7 +746,7 @@ async def get_user_threads(
622
746
  return JSONResponse(content=res.to_dict())
623
747
 
624
748
 
625
- @app.get("/project/thread/{thread_id}")
749
+ @router.get("/project/thread/{thread_id}")
626
750
  async def get_thread(
627
751
  request: Request,
628
752
  thread_id: str,
@@ -640,7 +764,7 @@ async def get_thread(
640
764
  return JSONResponse(content=res)
641
765
 
642
766
 
643
- @app.get("/project/thread/{thread_id}/element/{element_id}")
767
+ @router.get("/project/thread/{thread_id}/element/{element_id}")
644
768
  async def get_thread_element(
645
769
  request: Request,
646
770
  thread_id: str,
@@ -659,7 +783,7 @@ async def get_thread_element(
659
783
  return JSONResponse(content=res)
660
784
 
661
785
 
662
- @app.delete("/project/thread")
786
+ @router.delete("/project/thread")
663
787
  async def delete_thread(
664
788
  request: Request,
665
789
  payload: DeleteThreadRequest,
@@ -680,7 +804,7 @@ async def delete_thread(
680
804
  return JSONResponse(content={"success": True})
681
805
 
682
806
 
683
- @app.post("/project/file")
807
+ @router.post("/project/file")
684
808
  async def upload_file(
685
809
  session_id: str,
686
810
  file: UploadFile,
@@ -716,7 +840,7 @@ async def upload_file(
716
840
  return JSONResponse(file_response)
717
841
 
718
842
 
719
- @app.get("/project/file/{file_id}")
843
+ @router.get("/project/file/{file_id}")
720
844
  async def get_file(
721
845
  file_id: str,
722
846
  session_id: Optional[str] = None,
@@ -738,7 +862,7 @@ async def get_file(
738
862
  raise HTTPException(status_code=404, detail="File not found")
739
863
 
740
864
 
741
- @app.get("/files/{filename:path}")
865
+ @router.get("/files/{filename:path}")
742
866
  async def serve_file(
743
867
  filename: str,
744
868
  current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
@@ -756,7 +880,7 @@ async def serve_file(
756
880
  raise HTTPException(status_code=404, detail="File not found")
757
881
 
758
882
 
759
- @app.get("/favicon")
883
+ @router.get("/favicon")
760
884
  async def get_favicon():
761
885
  custom_favicon_path = os.path.join(APP_ROOT, "public", "favicon.*")
762
886
  files = glob.glob(custom_favicon_path)
@@ -771,7 +895,7 @@ async def get_favicon():
771
895
  return FileResponse(favicon_path, media_type=media_type)
772
896
 
773
897
 
774
- @app.get("/logo")
898
+ @router.get("/logo")
775
899
  async def get_logo(theme: Optional[Theme] = Query(Theme.light)):
776
900
  theme_value = theme.value if theme else Theme.light.value
777
901
  logo_path = None
@@ -793,7 +917,7 @@ async def get_logo(theme: Optional[Theme] = Query(Theme.light)):
793
917
  return FileResponse(logo_path, media_type=media_type)
794
918
 
795
919
 
796
- @app.get("/avatars/{avatar_id}")
920
+ @router.get("/avatars/{avatar_id}")
797
921
  async def get_avatar(avatar_id: str):
798
922
  if avatar_id == "default":
799
923
  avatar_id = config.ui.name
@@ -812,19 +936,20 @@ async def get_avatar(avatar_id: str):
812
936
  return await get_favicon()
813
937
 
814
938
 
815
- @app.head("/")
939
+ @router.head("/")
816
940
  def status_check():
817
941
  return {"message": "Site is operational"}
818
942
 
819
943
 
820
- def register_wildcard_route_handler():
821
- @app.get("/{path:path}")
822
- async def serve(request: Request, path: str):
823
- html_template = get_html_template()
824
- """Serve the UI files."""
825
- response = HTMLResponse(content=html_template, status_code=200)
944
+ @router.get("/{full_path:path}")
945
+ async def serve():
946
+ html_template = get_html_template()
947
+ """Serve the UI files."""
948
+ response = HTMLResponse(content=html_template, status_code=200)
949
+
950
+ return response
826
951
 
827
- return response
828
952
 
953
+ app.include_router(router)
829
954
 
830
955
  import chainlit.socket # noqa
chainlit/slack/app.py CHANGED
@@ -176,8 +176,8 @@ async def get_user(slack_user_id: str):
176
176
  slack_user = await slack_app.client.users_info(user=slack_user_id)
177
177
  slack_user_profile = slack_user["user"]["profile"]
178
178
 
179
- user_email = slack_user_profile.get("email")
180
- user = User(identifier=USER_PREFIX + user_email, metadata=slack_user_profile)
179
+ user_identifier = slack_user_profile.get("email") or slack_user_id
180
+ user = User(identifier=USER_PREFIX + user_identifier, metadata=slack_user_profile)
181
181
 
182
182
  users_by_slack_id[slack_user_id] = user
183
183
 
chainlit/socket.py CHANGED
@@ -2,8 +2,8 @@ import asyncio
2
2
  import json
3
3
  import time
4
4
  import uuid
5
- from urllib.parse import unquote
6
5
  from typing import Any, Dict, Literal
6
+ from urllib.parse import unquote
7
7
 
8
8
  from chainlit.action import Action
9
9
  from chainlit.auth import get_current_user, require_login
@@ -13,14 +13,15 @@ from chainlit.data import get_data_layer
13
13
  from chainlit.element import Element
14
14
  from chainlit.logger import logger
15
15
  from chainlit.message import ErrorMessage, Message
16
- from chainlit.server import socket
16
+ from chainlit.server import sio
17
17
  from chainlit.session import WebsocketSession
18
18
  from chainlit.telemetry import trace_event
19
19
  from chainlit.types import (
20
20
  AudioChunk,
21
21
  AudioChunkPayload,
22
22
  AudioEndPayload,
23
- UIMessagePayload,
23
+ MessagePayload,
24
+ SystemMessagePayload,
24
25
  )
25
26
  from chainlit.user_session import user_sessions
26
27
 
@@ -53,7 +54,7 @@ async def resume_thread(session: WebsocketSession):
53
54
  user_is_author = author == session.user.identifier
54
55
 
55
56
  if user_is_author:
56
- metadata = thread.get("metadata", {})
57
+ metadata = thread.get("metadata") or {}
57
58
  user_sessions[session.id] = metadata.copy()
58
59
  if chat_profile := metadata.get("chat_profile"):
59
60
  session.chat_profile = chat_profile
@@ -98,8 +99,8 @@ def build_anon_user_identifier(environ):
98
99
  return str(uuid.uuid5(uuid.NAMESPACE_DNS, ip))
99
100
 
100
101
 
101
- @socket.on("connect")
102
- async def connect(sid, environ, auth):
102
+ @sio.on("connect")
103
+ async def connect(sid, environ):
103
104
  if (
104
105
  not config.code.on_chat_start
105
106
  and not config.code.on_message
@@ -124,11 +125,11 @@ async def connect(sid, environ, auth):
124
125
 
125
126
  # Session scoped function to emit to the client
126
127
  def emit_fn(event, data):
127
- return socket.emit(event, data, to=sid)
128
+ return sio.emit(event, data, to=sid)
128
129
 
129
130
  # Session scoped function to emit to the client and wait for a response
130
131
  def emit_call_fn(event: Literal["ask", "call_fn"], data, timeout):
131
- return socket.call(event, data, timeout=timeout, to=sid)
132
+ return sio.call(event, data, timeout=timeout, to=sid)
132
133
 
133
134
  session_id = environ.get("HTTP_X_CHAINLIT_SESSION_ID")
134
135
  if restore_existing_session(sid, session_id, emit_fn, emit_call_fn):
@@ -140,7 +141,9 @@ async def connect(sid, environ, auth):
140
141
  client_type = environ.get("HTTP_X_CHAINLIT_CLIENT_TYPE")
141
142
  http_referer = environ.get("HTTP_REFERER")
142
143
  url_encoded_chat_profile = environ.get("HTTP_X_CHAINLIT_CHAT_PROFILE")
143
- chat_profile = unquote(url_encoded_chat_profile) if url_encoded_chat_profile else None
144
+ chat_profile = (
145
+ unquote(url_encoded_chat_profile) if url_encoded_chat_profile else None
146
+ )
144
147
 
145
148
  ws_session = WebsocketSession(
146
149
  id=session_id,
@@ -161,7 +164,7 @@ async def connect(sid, environ, auth):
161
164
  return True
162
165
 
163
166
 
164
- @socket.on("connection_successful")
167
+ @sio.on("connection_successful")
165
168
  async def connection_successful(sid):
166
169
  context = init_ws_context(sid)
167
170
 
@@ -189,14 +192,14 @@ async def connection_successful(sid):
189
192
  context.session.current_task = task
190
193
 
191
194
 
192
- @socket.on("clear_session")
195
+ @sio.on("clear_session")
193
196
  async def clean_session(sid):
194
197
  session = WebsocketSession.get(sid)
195
198
  if session:
196
199
  session.to_clear = True
197
200
 
198
201
 
199
- @socket.on("disconnect")
202
+ @sio.on("disconnect")
200
203
  async def disconnect(sid):
201
204
  session = WebsocketSession.get(sid)
202
205
 
@@ -230,7 +233,7 @@ async def disconnect(sid):
230
233
  asyncio.ensure_future(clear_on_timeout(sid))
231
234
 
232
235
 
233
- @socket.on("stop")
236
+ @sio.on("stop")
234
237
  async def stop(sid):
235
238
  if session := WebsocketSession.get(sid):
236
239
  trace_event("stop_task")
@@ -245,12 +248,12 @@ async def stop(sid):
245
248
  await config.code.on_stop()
246
249
 
247
250
 
248
- async def process_message(session: WebsocketSession, payload: UIMessagePayload):
251
+ async def process_message(session: WebsocketSession, payload: MessagePayload):
249
252
  """Process a message from the user."""
250
253
  try:
251
254
  context = init_ws_context(session)
252
255
  await context.emitter.task_start()
253
- message = await context.emitter.process_user_message(payload)
256
+ message = await context.emitter.process_message(payload)
254
257
 
255
258
  if config.code.on_message:
256
259
  # Sleep 1ms to make sure any children step starts after the message step start
@@ -267,8 +270,8 @@ async def process_message(session: WebsocketSession, payload: UIMessagePayload):
267
270
  await context.emitter.task_end()
268
271
 
269
272
 
270
- @socket.on("ui_message")
271
- async def message(sid, payload: UIMessagePayload):
273
+ @sio.on("client_message")
274
+ async def message(sid, payload: MessagePayload):
272
275
  """Handle a message sent by the User."""
273
276
  session = WebsocketSession.require(sid)
274
277
 
@@ -276,7 +279,7 @@ async def message(sid, payload: UIMessagePayload):
276
279
  session.current_task = task
277
280
 
278
281
 
279
- @socket.on("audio_chunk")
282
+ @sio.on("audio_chunk")
280
283
  async def audio_chunk(sid, payload: AudioChunkPayload):
281
284
  """Handle an audio chunk sent by the user."""
282
285
  session = WebsocketSession.require(sid)
@@ -287,7 +290,7 @@ async def audio_chunk(sid, payload: AudioChunkPayload):
287
290
  asyncio.create_task(config.code.on_audio_chunk(AudioChunk(**payload)))
288
291
 
289
292
 
290
- @socket.on("audio_end")
293
+ @sio.on("audio_end")
291
294
  async def audio_end(sid, payload: AudioEndPayload):
292
295
  """Handle the end of the audio stream."""
293
296
  session = WebsocketSession.require(sid)
@@ -331,7 +334,7 @@ async def process_action(action: Action):
331
334
  logger.warning("No callback found for action %s", action.name)
332
335
 
333
336
 
334
- @socket.on("action_call")
337
+ @sio.on("action_call")
335
338
  async def call_action(sid, action):
336
339
  """Handle an action call from the UI."""
337
340
  context = init_ws_context(sid)
@@ -358,7 +361,7 @@ async def call_action(sid, action):
358
361
  )
359
362
 
360
363
 
361
- @socket.on("chat_settings_change")
364
+ @sio.on("chat_settings_change")
362
365
  async def change_settings(sid, settings: Dict[str, Any]):
363
366
  """Handle change settings submit from the UI."""
364
367
  context = init_ws_context(sid)