chainlit 1.3.1__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic; consult the registry's advisory page for more details.

Files changed (82)
  1. chainlit/__init__.py +58 -56
  2. chainlit/action.py +12 -10
  3. chainlit/{auth.py → auth/__init__.py} +24 -34
  4. chainlit/auth/cookie.py +123 -0
  5. chainlit/auth/jwt.py +37 -0
  6. chainlit/cache.py +4 -6
  7. chainlit/callbacks.py +65 -11
  8. chainlit/chat_context.py +2 -2
  9. chainlit/chat_settings.py +3 -1
  10. chainlit/cli/__init__.py +15 -2
  11. chainlit/config.py +46 -90
  12. chainlit/context.py +4 -3
  13. chainlit/copilot/dist/index.js +8608 -642
  14. chainlit/data/__init__.py +96 -8
  15. chainlit/data/acl.py +3 -2
  16. chainlit/data/base.py +1 -15
  17. chainlit/data/chainlit_data_layer.py +584 -0
  18. chainlit/data/dynamodb.py +7 -4
  19. chainlit/data/literalai.py +4 -6
  20. chainlit/data/sql_alchemy.py +9 -8
  21. chainlit/data/storage_clients/__init__.py +0 -0
  22. chainlit/data/{storage_clients.py → storage_clients/azure.py} +2 -33
  23. chainlit/data/storage_clients/azure_blob.py +80 -0
  24. chainlit/data/storage_clients/base.py +22 -0
  25. chainlit/data/storage_clients/gcs.py +78 -0
  26. chainlit/data/storage_clients/s3.py +49 -0
  27. chainlit/discord/__init__.py +4 -4
  28. chainlit/discord/app.py +2 -1
  29. chainlit/element.py +41 -9
  30. chainlit/emitter.py +37 -16
  31. chainlit/frontend/dist/assets/{DailyMotion-CwoOhIL8.js → DailyMotion-DgRzV5GZ.js} +1 -1
  32. chainlit/frontend/dist/assets/Dataframe-DVgwSMU2.js +22 -0
  33. chainlit/frontend/dist/assets/{Facebook-BhnGXlzq.js → Facebook-C0vx6HWv.js} +1 -1
  34. chainlit/frontend/dist/assets/{FilePlayer-CPSVT6fz.js → FilePlayer-CdhzeHPP.js} +1 -1
  35. chainlit/frontend/dist/assets/{Kaltura-COYaLzsL.js → Kaltura-5iVmeUct.js} +1 -1
  36. chainlit/frontend/dist/assets/{Mixcloud-JdadNiQ5.js → Mixcloud-C2zi77Ex.js} +1 -1
  37. chainlit/frontend/dist/assets/{Mux-CBN7RO2u.js → Mux-Vkebogdf.js} +1 -1
  38. chainlit/frontend/dist/assets/{Preview-CxAFvvjV.js → Preview-DwY_sEIl.js} +1 -1
  39. chainlit/frontend/dist/assets/{SoundCloud-JlgmASWm.js → SoundCloud-CREBXAWo.js} +1 -1
  40. chainlit/frontend/dist/assets/{Streamable-CUWgr6Zw.js → Streamable-B5Lu25uy.js} +1 -1
  41. chainlit/frontend/dist/assets/{Twitch-BiN1HEDM.js → Twitch-y9iKCcM1.js} +1 -1
  42. chainlit/frontend/dist/assets/{Vidyard-qhPmrhDm.js → Vidyard-ClYvcuEu.js} +1 -1
  43. chainlit/frontend/dist/assets/{Vimeo-CrZVSCaT.js → Vimeo-D6HvM2jt.js} +1 -1
  44. chainlit/frontend/dist/assets/Wistia-Cu4zZ2Ci.js +1 -0
  45. chainlit/frontend/dist/assets/{YouTube-DKjw5Hbn.js → YouTube-D10tR6CJ.js} +1 -1
  46. chainlit/frontend/dist/assets/index-CI4qFOt5.js +8665 -0
  47. chainlit/frontend/dist/assets/index-CrrqM0nZ.css +1 -0
  48. chainlit/frontend/dist/assets/{react-plotly-Dpmqg5Sy.js → react-plotly-BpxUS-ab.js} +1 -1
  49. chainlit/frontend/dist/index.html +2 -2
  50. chainlit/haystack/callbacks.py +5 -4
  51. chainlit/input_widget.py +6 -4
  52. chainlit/langchain/callbacks.py +56 -47
  53. chainlit/langflow/__init__.py +1 -0
  54. chainlit/llama_index/callbacks.py +7 -7
  55. chainlit/message.py +8 -10
  56. chainlit/mistralai/__init__.py +3 -2
  57. chainlit/oauth_providers.py +70 -3
  58. chainlit/openai/__init__.py +3 -2
  59. chainlit/secret.py +1 -1
  60. chainlit/server.py +481 -182
  61. chainlit/session.py +7 -5
  62. chainlit/slack/__init__.py +3 -3
  63. chainlit/slack/app.py +3 -2
  64. chainlit/socket.py +89 -112
  65. chainlit/step.py +12 -12
  66. chainlit/sync.py +2 -1
  67. chainlit/teams/__init__.py +3 -3
  68. chainlit/teams/app.py +1 -0
  69. chainlit/translations/en-US.json +2 -1
  70. chainlit/translations/nl-NL.json +229 -0
  71. chainlit/types.py +24 -8
  72. chainlit/user.py +2 -1
  73. chainlit/utils.py +3 -2
  74. chainlit/version.py +3 -2
  75. {chainlit-1.3.1.dist-info → chainlit-2.0.0.dist-info}/METADATA +17 -37
  76. chainlit-2.0.0.dist-info/RECORD +106 -0
  77. chainlit/frontend/dist/assets/Wistia-C891KrBP.js +0 -1
  78. chainlit/frontend/dist/assets/index-CwmincdQ.css +0 -1
  79. chainlit/frontend/dist/assets/index-DLRdQOIx.js +0 -723
  80. chainlit-1.3.1.dist-info/RECORD +0 -96
  81. {chainlit-1.3.1.dist-info → chainlit-2.0.0.dist-info}/WHEEL +0 -0
  82. {chainlit-1.3.1.dist-info → chainlit-2.0.0.dist-info}/entry_points.txt +0 -0
chainlit/server.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import asyncio
2
+ import fnmatch
2
3
  import glob
3
4
  import json
4
5
  import mimetypes
@@ -9,10 +10,36 @@ import urllib.parse
9
10
  import webbrowser
10
11
  from contextlib import asynccontextmanager
11
12
  from pathlib import Path
12
- from typing import Any, Optional, Union
13
+ from typing import List, Optional, Union, cast
13
14
 
14
15
  import socketio
15
- from chainlit.auth import create_jwt, get_configuration, get_current_user
16
+ from fastapi import (
17
+ APIRouter,
18
+ Depends,
19
+ FastAPI,
20
+ Form,
21
+ HTTPException,
22
+ Query,
23
+ Request,
24
+ Response,
25
+ UploadFile,
26
+ status,
27
+ )
28
+ from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
29
+ from fastapi.security import OAuth2PasswordRequestForm
30
+ from starlette.datastructures import URL
31
+ from starlette.middleware.cors import CORSMiddleware
32
+ from typing_extensions import Annotated
33
+ from watchfiles import awatch
34
+
35
+ from chainlit.auth import create_jwt, decode_jwt, get_configuration, get_current_user
36
+ from chainlit.auth.cookie import (
37
+ clear_auth_cookie,
38
+ clear_oauth_state_cookie,
39
+ set_auth_cookie,
40
+ set_oauth_state_cookie,
41
+ validate_oauth_state_cookie,
42
+ )
16
43
  from chainlit.config import (
17
44
  APP_ROOT,
18
45
  BACKEND_ROOT,
@@ -21,6 +48,7 @@ from chainlit.config import (
21
48
  PACKAGE_ROOT,
22
49
  config,
23
50
  load_module,
51
+ public_dir,
24
52
  reload_config,
25
53
  )
26
54
  from chainlit.data import get_data_layer
@@ -30,33 +58,16 @@ from chainlit.markdown import get_markdown_str
30
58
  from chainlit.oauth_providers import get_oauth_provider
31
59
  from chainlit.secret import random_secret
32
60
  from chainlit.types import (
61
+ CallActionRequest,
33
62
  DeleteFeedbackRequest,
34
63
  DeleteThreadRequest,
64
+ ElementRequest,
35
65
  GetThreadsRequest,
36
66
  Theme,
37
67
  UpdateFeedbackRequest,
68
+ UpdateThreadRequest,
38
69
  )
39
70
  from chainlit.user import PersistedUser, User
40
- from fastapi import (
41
- APIRouter,
42
- Depends,
43
- FastAPI,
44
- File,
45
- Form,
46
- HTTPException,
47
- Query,
48
- Request,
49
- Response,
50
- UploadFile,
51
- status,
52
- )
53
- from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
54
- from fastapi.security import OAuth2PasswordRequestForm
55
- from fastapi.staticfiles import StaticFiles
56
- from starlette.datastructures import URL
57
- from starlette.middleware.cors import CORSMiddleware
58
- from typing_extensions import Annotated
59
- from watchfiles import awatch
60
71
 
61
72
  from ._utils import is_path_inside
62
73
 
@@ -205,29 +216,59 @@ app.add_middleware(
205
216
 
206
217
  router = APIRouter(prefix=PREFIX)
207
218
 
208
- app.mount(
209
- f"{PREFIX}/public",
210
- StaticFiles(directory="public", check_dir=False),
211
- name="public",
212
- )
213
219
 
214
- app.mount(
215
- f"{PREFIX}/assets",
216
- StaticFiles(
217
- packages=[("chainlit", os.path.join(build_dir, "assets"))],
218
- follow_symlink=config.project.follow_symlink,
219
- ),
220
- name="assets",
221
- )
220
@router.get("/public/{filename:path}")
async def serve_public_file(
    filename: str,
):
    """Serve ``filename`` from the project's public directory (``public_dir``).

    Responds 400 when ``is_path_inside`` rejects the resolved path, and 404
    when the resolved path is not an existing regular file.
    """
    public_root = Path(public_dir)
    requested = (public_root / filename).resolve()

    # Reject traversal attempts (e.g. "../") before touching the filesystem.
    if not is_path_inside(requested, public_root):
        raise HTTPException(status_code=400, detail="Invalid filename")

    if not requested.is_file():
        raise HTTPException(status_code=404, detail="File not found")

    return FileResponse(requested)
236
+
237
+
238
@router.get("/assets/{filename:path}")
async def serve_asset_file(
    filename: str,
):
    """Serve ``filename`` from the ``assets`` folder of the frontend build dir.

    Responds 400 when ``is_path_inside`` rejects the resolved path, and 404
    when the resolved path is not an existing regular file.
    """
    assets_root = Path(build_dir) / "assets"
    requested = (assets_root / filename).resolve()

    # Reject traversal attempts (e.g. "../") before touching the filesystem.
    if not is_path_inside(requested, assets_root):
        raise HTTPException(status_code=400, detail="Invalid filename")

    if not requested.is_file():
        raise HTTPException(status_code=404, detail="File not found")

    return FileResponse(requested)
254
+
255
+
256
@router.get("/copilot/{filename:path}")
async def serve_copilot_file(
    filename: str,
):
    """Serve a file from the copilot build dir (``copilot_build_dir``).

    Responds 400 when ``is_path_inside`` rejects the resolved path, and 404
    when the resolved path is not an existing regular file.
    """
    base_path = Path(copilot_build_dir)
    file_path = (base_path / filename).resolve()

    # Reject traversal attempts (e.g. "../") before touching the filesystem.
    if not is_path_inside(file_path, base_path):
        raise HTTPException(status_code=400, detail="Invalid filename")

    if file_path.is_file():
        return FileResponse(file_path)
    raise HTTPException(status_code=404, detail="File not found")
231
272
 
232
273
 
233
274
  # -------------------------------------------------------------------------------
@@ -248,6 +289,7 @@ if os.environ.get("SLACK_BOT_TOKEN") and os.environ.get("SLACK_SIGNING_SECRET"):
248
289
 
249
290
  if os.environ.get("TEAMS_APP_ID") and os.environ.get("TEAMS_APP_PASSWORD"):
250
291
  from botbuilder.schema import Activity
292
+
251
293
  from chainlit.teams.app import adapter, bot
252
294
 
253
295
  @router.post("/teams/events")
@@ -277,6 +319,16 @@ def get_html_template():
277
319
  """
278
320
  Get HTML template for the index view.
279
321
  """
322
+ ROOT_PATH = os.environ.get("CHAINLIT_ROOT_PATH", "")
323
+
324
+ custom_theme = None
325
+ custom_theme_file_path = Path(public_dir) / "theme.json"
326
+ if (
327
+ is_path_inside(custom_theme_file_path, Path(public_dir))
328
+ and custom_theme_file_path.is_file()
329
+ ):
330
+ custom_theme = json.loads(custom_theme_file_path.read_text(encoding="utf-8"))
331
+
280
332
  PLACEHOLDER = "<!-- TAG INJECTION PLACEHOLDER -->"
281
333
  JS_PLACEHOLDER = "<!-- JS INJECTION PLACEHOLDER -->"
282
334
  CSS_PLACEHOLDER = "<!-- CSS INJECTION PLACEHOLDER -->"
@@ -299,7 +351,10 @@ def get_html_template():
299
351
  <meta property="og:url" content="{url}">
300
352
  <meta property="og:root_path" content="{ROOT_PATH}">"""
301
353
 
302
- js = f"""<script>{f"window.theme = {json.dumps(config.ui.theme.to_dict())}; " if config.ui.theme else ""}</script>"""
354
+ js = f"""<script>
355
+ {f"window.theme = {json.dumps(custom_theme.get('variables'))};" if custom_theme and custom_theme.get("variables") else "undefined"}
356
+ {f"window.transports = {json.dumps(config.project.transports)};" if config.project.transports else "undefined"}
357
+ </script>"""
303
358
 
304
359
  css = None
305
360
  if config.ui.custom_css:
@@ -311,12 +366,15 @@ def get_html_template():
311
366
  js += f"""<script src="{config.ui.custom_js}" defer></script>"""
312
367
 
313
368
  font = None
314
- if config.ui.custom_font:
315
- font = f"""<link rel="stylesheet" href="{config.ui.custom_font}">"""
369
+ if custom_theme and custom_theme.get("custom_fonts"):
370
+ font = "\n".join(
371
+ f"""<link rel="stylesheet" href="{font}">"""
372
+ for font in custom_theme.get("custom_fonts")
373
+ )
316
374
 
317
375
  index_html_file_path = os.path.join(build_dir, "index.html")
318
376
 
319
- with open(index_html_file_path, "r", encoding="utf-8") as f:
377
+ with open(index_html_file_path, encoding="utf-8") as f:
320
378
  content = f.read()
321
379
  content = content.replace(PLACEHOLDER, tags)
322
380
  if js:
@@ -361,46 +419,132 @@ async def auth(request: Request):
361
419
  return get_configuration()
362
420
 
363
421
 
364
- @router.post("/login")
365
- async def login(form_data: OAuth2PasswordRequestForm = Depends()):
366
- """
367
- Login a user using the password auth callback.
368
- """
369
- if not config.code.password_auth_callback:
370
- raise HTTPException(
371
- status_code=status.HTTP_400_BAD_REQUEST, detail="No auth_callback defined"
422
+ def _get_response_dict(access_token: str) -> dict:
423
+ """Get the response dictionary for the auth response."""
424
+
425
+ return {"success": True}
426
+
427
+
428
def _get_auth_response(access_token: str, redirect_to_callback: bool) -> Response:
    """Build the HTTP response returned after a successful login.

    Args:
        access_token: JWT of the authenticated user, forwarded to
            ``_get_response_dict``.
        redirect_to_callback: When True (used by the OAuth callbacks), reply
            with a 302 redirect to ``{CHAINLIT_ROOT_PATH}/login/callback``
            carrying the response dict as query parameters. Otherwise reply
            with that dict as JSON.
    """
    payload = _get_response_dict(access_token)

    if not redirect_to_callback:
        return JSONResponse(payload)

    root_path = os.environ.get("CHAINLIT_ROOT_PATH", "")
    query = urllib.parse.urlencode(payload)
    return RedirectResponse(
        # FIXME: redirect to the right frontend base url to improve the dev environment
        url=f"{root_path}/login/callback?{query}",
        status_code=302,
    )
446
+
447
+
448
def _get_oauth_redirect_error(error: str) -> Response:
    """Redirect the browser back to the login page with an OAuth error.

    Args:
        error: Error string reported by the OAuth provider. It is forwarded as
            the ``error`` query parameter.

    Returns:
        A redirect to ``{CHAINLIT_ROOT_PATH}/login?error=...``.
    """
    # Prefix with the root path, as the success redirect in _get_auth_response
    # does, so deployments served under a sub-path land on their own login page.
    root_path = os.environ.get("CHAINLIT_ROOT_PATH", "")
    params = urllib.parse.urlencode({"error": error})
    return RedirectResponse(
        # FIXME: redirect to the right frontend base url to improve the dev environment
        url=f"{root_path}/login?{params}",
    )
460
+
461
+
462
async def _authenticate_user(
    user: Optional[User], redirect_to_callback: bool = False
) -> Response:
    """Complete a login for ``user`` and return a response carrying the auth cookie.

    Args:
        user: User returned by an auth callback, or a falsy value on failure.
        redirect_to_callback: Forwarded to ``_get_auth_response``; the OAuth
            callbacks pass True to get a redirect instead of JSON.

    Raises:
        HTTPException: 401 with detail ``credentialssignin`` when ``user`` is falsy.
    """
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="credentialssignin",
        )

    data_layer = get_data_layer()
    if data_layer:
        # Persisting the user is best effort: failures are logged, login proceeds.
        # TODO: Make this catch only specific errors and allow others to propagate.
        try:
            await data_layer.create_user(user)
        except Exception as e:
            logger.error(f"Error creating user: {e}")

    token = create_jwt(user)
    response = _get_auth_response(token, redirect_to_callback)
    set_auth_cookie(response, token)
    return response
489
+
490
+
491
@router.post("/login")
async def login(response: Response, form_data: OAuth2PasswordRequestForm = Depends()):
    """Login a user with username/password via ``config.code.password_auth_callback``.

    ``response`` is injected by FastAPI but not used here. The response
    returned by ``_authenticate_user`` is the one that carries the auth cookie.

    Raises:
        HTTPException: 400 when no password auth callback is configured; 401
            (from ``_authenticate_user``) when the callback returns no user.
    """
    password_callback = config.code.password_auth_callback
    if not password_callback:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="No auth_callback defined"
        )

    authenticated = await password_callback(form_data.username, form_data.password)
    return await _authenticate_user(authenticated)
394
506
 
395
507
 
396
508
@router.post("/logout")
async def logout(request: Request, response: Response):
    """Logout the user: clear the auth cookie, then run ``on_logout`` if configured.

    Returns whatever ``config.code.on_logout`` returns when it is defined,
    otherwise ``{"success": True}``.
    """
    # NOTE(review): the cookie is cleared on the injected ``response``. If
    # on_logout returns its own Response object, FastAPI sends that object as-is
    # and this cookie header may not be applied — confirm.
    clear_auth_cookie(response)

    on_logout = config.code.on_logout
    if not on_logout:
        return {"success": True}
    return await on_logout(request, response)
402
517
 
403
518
 
519
@router.post("/auth/jwt")
async def jwt_auth(request: Request):
    """Login a user from an ``Authorization: Bearer <jwt>`` header.

    Raises:
        HTTPException: 401 when the header is missing, is not exactly
            ``<scheme> <token>``, uses a scheme other than Bearer, or carries a
            token that ``decode_jwt`` rejects with ``InvalidTokenError``.
    """
    from jwt import InvalidTokenError

    auth_header: Optional[str] = request.headers.get("Authorization")
    if not auth_header:
        raise HTTPException(status_code=401, detail="Authorization header missing")

    # Expect exactly two whitespace-separated parts: "<scheme> <token>".
    try:
        scheme, token = auth_header.split()
    except ValueError as e:
        raise HTTPException(
            status_code=401, detail="Invalid authorization header format"
        ) from e

    if scheme.lower() != "bearer":
        raise HTTPException(
            status_code=401,
            detail="Invalid authentication scheme. Please use Bearer",
        )

    # Keep only decoding in the try: errors while persisting the user or
    # building the response must not be reported as an invalid token.
    try:
        user = decode_jwt(token)
    except InvalidTokenError as e:
        raise HTTPException(status_code=401, detail="Invalid token") from e

    return await _authenticate_user(user)
546
+
547
+
404
548
  @router.post("/auth/header")
405
549
  async def header_auth(request: Request):
406
550
  """Login a user using the header_auth_callback."""
@@ -412,23 +556,7 @@ async def header_auth(request: Request):
412
556
 
413
557
  user = await config.code.header_auth_callback(request.headers)
414
558
 
415
- if not user:
416
- raise HTTPException(
417
- status_code=status.HTTP_401_UNAUTHORIZED,
418
- detail="Unauthorized",
419
- )
420
-
421
- access_token = create_jwt(user)
422
- if data_layer := get_data_layer():
423
- try:
424
- await data_layer.create_user(user)
425
- except Exception as e:
426
- logger.error(f"Error creating user: {e}")
427
-
428
- return {
429
- "access_token": access_token,
430
- "token_type": "bearer",
431
- }
559
+ return await _authenticate_user(user)
432
560
 
433
561
 
434
562
  @router.get("/auth/oauth/{provider_id}")
@@ -460,16 +588,9 @@ async def oauth_login(provider_id: str, request: Request):
460
588
  response = RedirectResponse(
461
589
  url=f"{provider.authorize_url}?{params}",
462
590
  )
463
- samesite: Any = os.environ.get("CHAINLIT_COOKIE_SAMESITE", "lax")
464
- secure = samesite.lower() == "none"
465
- response.set_cookie(
466
- "oauth_state",
467
- random,
468
- httponly=True,
469
- samesite=samesite,
470
- secure=secure,
471
- max_age=3 * 60,
472
- )
591
+
592
+ set_oauth_state_cookie(response, random)
593
+
473
594
  return response
474
595
 
475
596
 
@@ -497,16 +618,7 @@ async def oauth_callback(
497
618
  )
498
619
 
499
620
  if error:
500
- params = urllib.parse.urlencode(
501
- {
502
- "error": error,
503
- }
504
- )
505
- response = RedirectResponse(
506
- # FIXME: redirect to the right frontend base url to improve the dev environment
507
- url=f"/login?{params}",
508
- )
509
- return response
621
+ return _get_oauth_redirect_error(error)
510
622
 
511
623
  if not code or not state:
512
624
  raise HTTPException(
@@ -514,9 +626,11 @@ async def oauth_callback(
514
626
  detail="Missing code or state",
515
627
  )
516
628
 
517
- # Check the state from the oauth provider against the browser cookie
518
- oauth_state = request.cookies.get("oauth_state")
519
- if oauth_state != state:
629
+ try:
630
+ validate_oauth_state_cookie(request, state)
631
+ except Exception as e:
632
+ logger.exception("Unable to validate oauth state: %1", e)
633
+
520
634
  raise HTTPException(
521
635
  status_code=status.HTTP_401_UNAUTHORIZED,
522
636
  detail="Unauthorized",
@@ -531,34 +645,10 @@ async def oauth_callback(
531
645
  provider_id, token, raw_user_data, default_user
532
646
  )
533
647
 
534
- if not user:
535
- raise HTTPException(
536
- status_code=status.HTTP_401_UNAUTHORIZED,
537
- detail="Unauthorized",
538
- )
539
-
540
- access_token = create_jwt(user)
541
-
542
- if data_layer := get_data_layer():
543
- try:
544
- await data_layer.create_user(user)
545
- except Exception as e:
546
- logger.error(f"Error creating user: {e}")
547
-
548
- params = urllib.parse.urlencode(
549
- {
550
- "access_token": access_token,
551
- "token_type": "bearer",
552
- }
553
- )
648
+ response = await _authenticate_user(user, redirect_to_callback=True)
554
649
 
555
- root_path = os.environ.get("CHAINLIT_ROOT_PATH", "")
650
+ clear_oauth_state_cookie(response)
556
651
 
557
- response = RedirectResponse(
558
- # FIXME: redirect to the right frontend base url to improve the dev environment
559
- url=f"{root_path}/login/callback?{params}",
560
- )
561
- response.delete_cookie("oauth_state")
562
652
  return response
563
653
 
564
654
 
@@ -587,16 +677,7 @@ async def oauth_azure_hf_callback(
587
677
  )
588
678
 
589
679
  if error:
590
- params = urllib.parse.urlencode(
591
- {
592
- "error": error,
593
- }
594
- )
595
- response = RedirectResponse(
596
- # FIXME: redirect to the right frontend base url to improve the dev environment
597
- url=f"/login?{params}",
598
- )
599
- return response
680
+ return _get_oauth_redirect_error(error)
600
681
 
601
682
  if not code:
602
683
  raise HTTPException(
@@ -613,36 +694,20 @@ async def oauth_azure_hf_callback(
613
694
  provider_id, token, raw_user_data, default_user, id_token
614
695
  )
615
696
 
616
- if not user:
617
- raise HTTPException(
618
- status_code=status.HTTP_401_UNAUTHORIZED,
619
- detail="Unauthorized",
620
- )
697
+ response = await _authenticate_user(user, redirect_to_callback=True)
621
698
 
622
- access_token = create_jwt(user)
699
+ clear_oauth_state_cookie(response)
623
700
 
624
- if data_layer := get_data_layer():
625
- try:
626
- await data_layer.create_user(user)
627
- except Exception as e:
628
- logger.error(f"Error creating user: {e}")
701
+ return response
629
702
 
630
- params = urllib.parse.urlencode(
631
- {
632
- "access_token": access_token,
633
- "token_type": "bearer",
634
- }
635
- )
636
703
 
637
- root_path = os.environ.get("CHAINLIT_ROOT_PATH", "")
704
# User as resolved for a request: None when get_current_user yields no user.
GenericUser = Union[User, PersistedUser, None]
# Endpoint parameter type that makes FastAPI inject the current user.
UserParam = Annotated[GenericUser, Depends(get_current_user)]


@router.get("/user")
async def get_user(current_user: UserParam) -> GenericUser:
    """Return the user resolved by ``get_current_user`` for this request (may be None)."""
    return current_user
646
711
 
647
712
 
648
713
  _language_pattern = (
@@ -670,7 +735,7 @@ async def project_translations(
670
735
 
671
736
  @router.get("/project/settings")
672
737
  async def project_settings(
673
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
738
+ current_user: UserParam,
674
739
  language: str = Query(
675
740
  default="en-US", description="Language code", pattern=_language_pattern
676
741
  ),
@@ -721,7 +786,7 @@ async def project_settings(
721
786
  async def update_feedback(
722
787
  request: Request,
723
788
  update: UpdateFeedbackRequest,
724
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
789
+ current_user: UserParam,
725
790
  ):
726
791
  """Update the human feedback for a particular message."""
727
792
  data_layer = get_data_layer()
@@ -731,7 +796,7 @@ async def update_feedback(
731
796
  try:
732
797
  feedback_id = await data_layer.upsert_feedback(feedback=update.feedback)
733
798
  except Exception as e:
734
- raise HTTPException(detail=str(e), status_code=500)
799
+ raise HTTPException(detail=str(e), status_code=500) from e
735
800
 
736
801
  return JSONResponse(content={"success": True, "feedbackId": feedback_id})
737
802
 
@@ -740,7 +805,7 @@ async def update_feedback(
740
805
  async def delete_feedback(
741
806
  request: Request,
742
807
  payload: DeleteFeedbackRequest,
743
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
808
+ current_user: UserParam,
744
809
  ):
745
810
  """Delete a feedback."""
746
811
 
@@ -759,7 +824,7 @@ async def delete_feedback(
759
824
  async def get_user_threads(
760
825
  request: Request,
761
826
  payload: GetThreadsRequest,
762
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
827
+ current_user: UserParam,
763
828
  ):
764
829
  """Get the threads page by page."""
765
830
 
@@ -768,6 +833,9 @@ async def get_user_threads(
768
833
  if not data_layer:
769
834
  raise HTTPException(status_code=400, detail="Data persistence is not enabled")
770
835
 
836
+ if not current_user:
837
+ raise HTTPException(status_code=401, detail="Unauthorized")
838
+
771
839
  if not isinstance(current_user, PersistedUser):
772
840
  persisted_user = await data_layer.get_user(identifier=current_user.identifier)
773
841
  if not persisted_user:
@@ -784,7 +852,7 @@ async def get_user_threads(
784
852
  async def get_thread(
785
853
  request: Request,
786
854
  thread_id: str,
787
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
855
+ current_user: UserParam,
788
856
  ):
789
857
  """Get a specific thread."""
790
858
  data_layer = get_data_layer()
@@ -792,6 +860,9 @@ async def get_thread(
792
860
  if not data_layer:
793
861
  raise HTTPException(status_code=400, detail="Data persistence is not enabled")
794
862
 
863
+ if not current_user:
864
+ raise HTTPException(status_code=401, detail="Unauthorized")
865
+
795
866
  await is_thread_author(current_user.identifier, thread_id)
796
867
 
797
868
  res = await data_layer.get_thread(thread_id)
@@ -803,7 +874,7 @@ async def get_thread_element(
803
874
  request: Request,
804
875
  thread_id: str,
805
876
  element_id: str,
806
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
877
+ current_user: UserParam,
807
878
  ):
808
879
  """Get a specific thread element."""
809
880
  data_layer = get_data_layer()
@@ -811,17 +882,135 @@ async def get_thread_element(
811
882
  if not data_layer:
812
883
  raise HTTPException(status_code=400, detail="Data persistence is not enabled")
813
884
 
885
+ if not current_user:
886
+ raise HTTPException(status_code=401, detail="Unauthorized")
887
+
814
888
  await is_thread_author(current_user.identifier, thread_id)
815
889
 
816
890
  res = await data_layer.get_element(thread_id, element_id)
817
891
  return JSONResponse(content=res)
818
892
 
819
893
 
894
@router.put("/project/element")
async def update_thread_element(
    payload: ElementRequest,
    current_user: UserParam,
):
    """Update a custom element in a live websocket session and send it to the client.

    Only elements whose ``type`` is ``"custom"`` are handled; other types get
    ``{"success": False}``.

    Raises:
        HTTPException: 404 if ``payload.sessionId`` matches no websocket
            session; 401 if an authenticated user is not the session's user.
    """
    from chainlit.context import init_ws_context
    from chainlit.element import CustomElement, ElementDict
    from chainlit.session import WebsocketSession

    session = WebsocketSession.get_by_id(payload.sessionId)
    # NOTE(review): assumes get_by_id returns a falsy value for unknown ids;
    # without this guard that value would be handed to init_ws_context.
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    context = init_ws_context(session)

    # Authorize before inspecting the payload. When no user is resolved
    # (current_user is None) the ownership check is skipped, as before.
    if current_user and (
        not context.session.user
        or context.session.user.identifier != current_user.identifier
    ):
        raise HTTPException(
            status_code=401,
            detail="You are not authorized to update elements for this session",
        )

    element_dict = cast(ElementDict, payload.element)
    if element_dict["type"] != "custom":
        return {"success": False}

    element = CustomElement(
        id=element_dict["id"],
        object_key=element_dict["objectKey"],
        chainlit_key=element_dict["chainlitKey"],
        url=element_dict["url"],
        for_id=element_dict.get("forId") or "",
        thread_id=element_dict.get("threadId") or "",
        name=element_dict["name"],
        props=element_dict.get("props") or {},
        display=element_dict["display"],
    )

    await element.send(for_id=element.for_id or "")
    return {"success": True}
937
+
938
+
939
@router.delete("/project/element")
async def delete_thread_element(
    payload: ElementRequest,
    current_user: UserParam,
):
    """Remove a custom element from a live websocket session.

    Only elements whose ``type`` is ``"custom"`` are handled; other types get
    ``{"success": False}``.

    Raises:
        HTTPException: 404 if ``payload.sessionId`` matches no websocket
            session; 401 if an authenticated user is not the session's user.
    """
    from chainlit.context import init_ws_context
    from chainlit.element import CustomElement, ElementDict
    from chainlit.session import WebsocketSession

    session = WebsocketSession.get_by_id(payload.sessionId)
    # NOTE(review): assumes get_by_id returns a falsy value for unknown ids;
    # without this guard that value would be handed to init_ws_context.
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    context = init_ws_context(session)

    # Authorize before inspecting the payload. When no user is resolved
    # (current_user is None) the ownership check is skipped, as before.
    if current_user and (
        not context.session.user
        or context.session.user.identifier != current_user.identifier
    ):
        raise HTTPException(
            status_code=401,
            detail="You are not authorized to remove elements for this session",
        )

    element_dict = cast(ElementDict, payload.element)
    if element_dict["type"] != "custom":
        return {"success": False}

    element = CustomElement(
        id=element_dict["id"],
        object_key=element_dict["objectKey"],
        chainlit_key=element_dict["chainlitKey"],
        url=element_dict["url"],
        for_id=element_dict.get("forId") or "",
        thread_id=element_dict.get("threadId") or "",
        name=element_dict["name"],
        props=element_dict.get("props") or {},
        display=element_dict["display"],
    )

    await element.remove()
    return {"success": True}
983
+
984
+
985
@router.put("/project/thread")
async def rename_thread(
    request: Request,
    payload: UpdateThreadRequest,
    current_user: UserParam,
):
    """Rename a thread; ownership is checked by ``is_thread_author``.

    Raises:
        HTTPException: 400 when no data layer is configured, 401 when no user
            is authenticated.
    """
    data_layer = get_data_layer()
    if not data_layer:
        raise HTTPException(status_code=400, detail="Data persistence is not enabled")
    if not current_user:
        raise HTTPException(status_code=401, detail="Unauthorized")

    target_thread = payload.threadId
    await is_thread_author(current_user.identifier, target_thread)
    await data_layer.update_thread(target_thread, name=payload.name)
    return JSONResponse(content={"success": True})
1007
+
1008
+
820
1009
  @router.delete("/project/thread")
821
1010
  async def delete_thread(
822
1011
  request: Request,
823
1012
  payload: DeleteThreadRequest,
824
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
1013
+ current_user: UserParam,
825
1014
  ):
826
1015
  """Delete a thread."""
827
1016
 
@@ -830,6 +1019,9 @@ async def delete_thread(
830
1019
  if not data_layer:
831
1020
  raise HTTPException(status_code=400, detail="Data persistence is not enabled")
832
1021
 
1022
+ if not current_user:
1023
+ raise HTTPException(status_code=401, detail="Unauthorized")
1024
+
833
1025
  thread_id = payload.threadId
834
1026
 
835
1027
  await is_thread_author(current_user.identifier, thread_id)
@@ -838,9 +1030,47 @@ async def delete_thread(
838
1030
  return JSONResponse(content={"success": True})
839
1031
 
840
1032
 
1033
@router.post("/project/action")
async def call_action(
    payload: CallActionRequest,
    current_user: UserParam,
):
    """Run the registered callback for an action in a live websocket session.

    Raises:
        HTTPException: 404 if the session is unknown or no callback is
            registered for the action's name; 401 if an authenticated user is
            not the session's user.
    """
    from chainlit.action import Action
    from chainlit.context import init_ws_context
    from chainlit.session import WebsocketSession

    session = WebsocketSession.get_by_id(payload.sessionId)
    # NOTE(review): assumes get_by_id returns a falsy value for unknown ids;
    # without this guard that value would be handed to init_ws_context.
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    context = init_ws_context(session)

    # Authorize before building the Action from the untrusted payload. When no
    # user is resolved (current_user is None) the check is skipped, as before.
    if current_user and (
        not context.session.user
        or context.session.user.identifier != current_user.identifier
    ):
        raise HTTPException(
            status_code=401,
            detail="You are not authorized to call actions for this session",
        )

    action = Action(**payload.action)

    callback = config.code.action_callbacks.get(action.name)
    if not callback:
        raise HTTPException(
            status_code=404,
            detail=f"No callback found for action {action.name}",
        )

    await callback(action)
    return JSONResponse(content={"success": True})
1069
+
1070
+
841
1071
  @router.post("/project/file")
842
1072
  async def upload_file(
843
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
1073
+ current_user: UserParam,
844
1074
  session_id: str,
845
1075
  file: UploadFile,
846
1076
  ):
@@ -870,6 +1100,11 @@ async def upload_file(
870
1100
  assert file.filename, "No filename for uploaded file"
871
1101
  assert file.content_type, "No content type for uploaded file"
872
1102
 
1103
+ try:
1104
+ validate_file_upload(file)
1105
+ except ValueError as e:
1106
+ raise HTTPException(status_code=400, detail=str(e))
1107
+
873
1108
  file_response = await session.persist_file(
874
1109
  name=file.filename, content=content, mime=file.content_type
875
1110
  )
@@ -877,14 +1112,79 @@ async def upload_file(
877
1112
  return JSONResponse(content=file_response)
878
1113
 
879
1114
 
1115
def validate_file_upload(file: UploadFile):
    """Check an uploaded file against config.features.spontaneous_file_upload.

    When the feature is not configured at all, every upload is accepted.
    Otherwise uploads must not be explicitly disabled, and the file must pass
    both the MIME-type check and the size check.

    Args:
        file (UploadFile): The uploaded file to check.
    Raises:
        ValueError: If uploads are disabled or the file fails one of the checks.
    """
    upload_settings = config.features.spontaneous_file_upload
    # A missing configuration means uploads are allowed without restrictions.
    if upload_settings is not None:
        # Only an explicit ``False`` disables uploads.
        if upload_settings.enabled is False:
            raise ValueError("File upload is not enabled")
        validate_file_mime_type(file)
        validate_file_size(file)
1130
+
1131
+
1132
def validate_file_mime_type(file: UploadFile):
    """Validate the file type against config.features.spontaneous_file_upload.accept.

    ``accept`` may be:

    * ``None`` -- every file type is allowed;
    * a list of MIME-type glob patterns, e.g. ``["image/*", "text/plain"]``;
    * a dict mapping MIME-type glob patterns to allowed filename extensions,
      e.g. ``{"image/*": [".png", ".jpg"]}``; an empty extension list allows
      any extension for that pattern.

    MIME types and extensions are compared case-insensitively, and matching
    does not depend on the host operating system.

    Args:
        file (UploadFile): The file to validate.
    Raises:
        ValueError: If the file type is not allowed.
        TypeError: If ``accept`` is configured with an unsupported type.
    """
    upload_settings = config.features.spontaneous_file_upload
    accept = upload_settings.accept if upload_settings is not None else None
    if accept is None:
        # Accept is not configured: allow all file types.
        return

    # Explicit check rather than ``assert`` so it is not stripped by ``-O``.
    if not isinstance(accept, (list, dict)):
        raise TypeError(
            "Invalid configuration for spontaneous_file_upload, accept must be a list or a dict"
        )

    # MIME types are case-insensitive. A missing content type is matched as ""
    # so that only catch-all patterns such as "*" accept it (fnmatch would
    # otherwise raise on None).
    content_type = (file.content_type or "").lower()

    def _type_matches(pattern: str) -> bool:
        # fnmatchcase skips os.path.normcase, whose case folding is OS-specific.
        return fnmatch.fnmatchcase(content_type, pattern.lower())

    if isinstance(accept, list):
        if any(_type_matches(pattern) for pattern in accept):
            return
    else:
        for pattern, extensions in accept.items():
            if not _type_matches(pattern):
                continue
            if not extensions:
                # The pattern matched and no extension restriction applies.
                return
            allowed_suffixes = tuple(ext.lower() for ext in extensions)
            if file.filename is not None and file.filename.lower().endswith(
                allowed_suffixes
            ):
                return
    raise ValueError("File type not allowed")
1161
+
1162
+
1163
def validate_file_size(file: UploadFile):
    """Reject files larger than spontaneous_file_upload.max_size_mb.

    The limit is interpreted in mebibytes (``max_size_mb * 1024 * 1024``
    bytes). No check is performed when no limit is configured or when the
    upload does not report its size.

    Args:
        file (UploadFile): The file to validate.
    Raises:
        ValueError: If the file size exceeds the configured limit.
    """
    limit_mb = config.features.spontaneous_file_upload.max_size_mb
    if limit_mb is not None and file.size is not None:
        max_bytes = limit_mb * 1024 * 1024
        if file.size > max_bytes:
            raise ValueError("File size too large")
1179
+
1180
+
880
1181
  @router.get("/project/file/{file_id}")
881
1182
  async def get_file(
882
1183
  file_id: str,
883
1184
  session_id: str,
884
- # current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)], #TODO: Causes 401 error. See https://github.com/Chainlit/chainlit/issues/1472
1185
+ current_user: UserParam,
885
1186
  ):
886
1187
  """Get a file from the session files directory."""
887
-
888
1188
  from chainlit.session import WebsocketSession
889
1189
 
890
1190
  session = WebsocketSession.get_by_id(session_id) if session_id else None
@@ -895,13 +1195,12 @@ async def get_file(
895
1195
  detail="Unauthorized",
896
1196
  )
897
1197
 
898
- #TODO: Causes 401 error. See https://github.com/Chainlit/chainlit/issues/1472
899
- # if current_user:
900
- # if not session.user or session.user.identifier != current_user.identifier:
901
- # raise HTTPException(
902
- # status_code=401,
903
- # detail="You are not authorized to download files from this session",
904
- # )
1198
+ if current_user:
1199
+ if not session.user or session.user.identifier != current_user.identifier:
1200
+ raise HTTPException(
1201
+ status_code=401,
1202
+ detail="You are not authorized to download files from this session",
1203
+ )
905
1204
 
906
1205
  if file_id in session.files:
907
1206
  file = session.files[file_id]
@@ -913,7 +1212,7 @@ async def get_file(
913
1212
  @router.get("/files/{filename:path}")
914
1213
  async def serve_file(
915
1214
  filename: str,
916
- current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
1215
+ current_user: UserParam,
917
1216
  ):
918
1217
  """Serve a file from the local filesystem."""
919
1218