chainlit 2.0.0__py3-none-any.whl → 2.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic; the original registry page linked to further details about this release.

Files changed (98):
  1. chainlit/__init__.py +57 -56
  2. chainlit/action.py +10 -12
  3. chainlit/{auth/__init__.py → auth.py} +34 -26
  4. chainlit/cache.py +6 -4
  5. chainlit/callbacks.py +7 -52
  6. chainlit/chat_context.py +2 -2
  7. chainlit/chat_settings.py +1 -3
  8. chainlit/cli/__init__.py +2 -15
  9. chainlit/config.py +70 -41
  10. chainlit/context.py +9 -8
  11. chainlit/copilot/dist/index.js +874 -8533
  12. chainlit/data/__init__.py +8 -96
  13. chainlit/data/acl.py +2 -3
  14. chainlit/data/base.py +15 -1
  15. chainlit/data/dynamodb.py +4 -7
  16. chainlit/data/literalai.py +6 -4
  17. chainlit/data/sql_alchemy.py +9 -10
  18. chainlit/data/{storage_clients/azure.py → storage_clients.py} +33 -2
  19. chainlit/discord/__init__.py +4 -4
  20. chainlit/discord/app.py +1 -2
  21. chainlit/element.py +9 -41
  22. chainlit/emitter.py +21 -17
  23. chainlit/frontend/dist/assets/DailyMotion-b4b7af47.js +1 -0
  24. chainlit/frontend/dist/assets/Facebook-572972a0.js +1 -0
  25. chainlit/frontend/dist/assets/FilePlayer-85c69ca8.js +1 -0
  26. chainlit/frontend/dist/assets/Kaltura-dfc24672.js +1 -0
  27. chainlit/frontend/dist/assets/Mixcloud-705011f4.js +1 -0
  28. chainlit/frontend/dist/assets/Mux-4201a9e6.js +1 -0
  29. chainlit/frontend/dist/assets/Preview-23ba40a6.js +1 -0
  30. chainlit/frontend/dist/assets/SoundCloud-1a582d51.js +1 -0
  31. chainlit/frontend/dist/assets/Streamable-5017c4ba.js +1 -0
  32. chainlit/frontend/dist/assets/Twitch-bb2de2fa.js +1 -0
  33. chainlit/frontend/dist/assets/Vidyard-54e269b1.js +1 -0
  34. chainlit/frontend/dist/assets/Vimeo-d92c37dd.js +1 -0
  35. chainlit/frontend/dist/assets/Wistia-25a1363b.js +1 -0
  36. chainlit/frontend/dist/assets/YouTube-616e8cb7.js +1 -0
  37. chainlit/frontend/dist/assets/index-aaf974a9.css +1 -0
  38. chainlit/frontend/dist/assets/index-f5df2072.js +1027 -0
  39. chainlit/frontend/dist/assets/{react-plotly-BpxUS-ab.js → react-plotly-f0315f86.js} +94 -94
  40. chainlit/frontend/dist/index.html +3 -2
  41. chainlit/haystack/callbacks.py +4 -5
  42. chainlit/input_widget.py +4 -6
  43. chainlit/langchain/callbacks.py +47 -56
  44. chainlit/langflow/__init__.py +0 -1
  45. chainlit/llama_index/callbacks.py +7 -7
  46. chainlit/message.py +10 -8
  47. chainlit/mistralai/__init__.py +2 -3
  48. chainlit/oauth_providers.py +12 -113
  49. chainlit/openai/__init__.py +7 -6
  50. chainlit/secret.py +1 -1
  51. chainlit/server.py +181 -491
  52. chainlit/session.py +5 -7
  53. chainlit/slack/__init__.py +3 -3
  54. chainlit/slack/app.py +2 -3
  55. chainlit/socket.py +103 -78
  56. chainlit/step.py +29 -21
  57. chainlit/sync.py +1 -2
  58. chainlit/teams/__init__.py +3 -3
  59. chainlit/teams/app.py +0 -1
  60. chainlit/types.py +4 -20
  61. chainlit/user.py +1 -2
  62. chainlit/utils.py +2 -3
  63. chainlit/version.py +2 -3
  64. {chainlit-2.0.0.dist-info → chainlit-2.0.dev0.dist-info}/METADATA +39 -27
  65. chainlit-2.0.dev0.dist-info/RECORD +96 -0
  66. chainlit/auth/cookie.py +0 -123
  67. chainlit/auth/jwt.py +0 -37
  68. chainlit/data/chainlit_data_layer.py +0 -584
  69. chainlit/data/storage_clients/__init__.py +0 -0
  70. chainlit/data/storage_clients/azure_blob.py +0 -80
  71. chainlit/data/storage_clients/base.py +0 -22
  72. chainlit/data/storage_clients/gcs.py +0 -78
  73. chainlit/data/storage_clients/s3.py +0 -49
  74. chainlit/frontend/dist/assets/DailyMotion-DgRzV5GZ.js +0 -1
  75. chainlit/frontend/dist/assets/Dataframe-DVgwSMU2.js +0 -22
  76. chainlit/frontend/dist/assets/Facebook-C0vx6HWv.js +0 -1
  77. chainlit/frontend/dist/assets/FilePlayer-CdhzeHPP.js +0 -1
  78. chainlit/frontend/dist/assets/Kaltura-5iVmeUct.js +0 -1
  79. chainlit/frontend/dist/assets/Mixcloud-C2zi77Ex.js +0 -1
  80. chainlit/frontend/dist/assets/Mux-Vkebogdf.js +0 -1
  81. chainlit/frontend/dist/assets/Preview-DwY_sEIl.js +0 -1
  82. chainlit/frontend/dist/assets/SoundCloud-CREBXAWo.js +0 -1
  83. chainlit/frontend/dist/assets/Streamable-B5Lu25uy.js +0 -1
  84. chainlit/frontend/dist/assets/Twitch-y9iKCcM1.js +0 -1
  85. chainlit/frontend/dist/assets/Vidyard-ClYvcuEu.js +0 -1
  86. chainlit/frontend/dist/assets/Vimeo-D6HvM2jt.js +0 -1
  87. chainlit/frontend/dist/assets/Wistia-Cu4zZ2Ci.js +0 -1
  88. chainlit/frontend/dist/assets/YouTube-D10tR6CJ.js +0 -1
  89. chainlit/frontend/dist/assets/index-CI4qFOt5.js +0 -8665
  90. chainlit/frontend/dist/assets/index-CrrqM0nZ.css +0 -1
  91. chainlit/translations/nl-NL.json +0 -229
  92. chainlit-2.0.0.dist-info/RECORD +0 -106
  93. /chainlit/copilot/dist/assets/{logo_dark-IkGJ_IwC.svg → logo_dark-2a3cf740.svg} +0 -0
  94. /chainlit/copilot/dist/assets/{logo_light-Bb_IPh6r.svg → logo_light-b078e7bc.svg} +0 -0
  95. /chainlit/frontend/dist/assets/{logo_dark-IkGJ_IwC.svg → logo_dark-2a3cf740.svg} +0 -0
  96. /chainlit/frontend/dist/assets/{logo_light-Bb_IPh6r.svg → logo_light-b078e7bc.svg} +0 -0
  97. {chainlit-2.0.0.dist-info → chainlit-2.0.dev0.dist-info}/WHEEL +0 -0
  98. {chainlit-2.0.0.dist-info → chainlit-2.0.dev0.dist-info}/entry_points.txt +0 -0
chainlit/server.py CHANGED
@@ -1,5 +1,4 @@
1
1
  import asyncio
2
- import fnmatch
3
2
  import glob
4
3
  import json
5
4
  import mimetypes
@@ -10,36 +9,10 @@ import urllib.parse
10
9
  import webbrowser
11
10
  from contextlib import asynccontextmanager
12
11
  from pathlib import Path
13
- from typing import List, Optional, Union, cast
12
+ from typing import Any, Optional, Union
14
13
 
15
14
  import socketio
16
- from fastapi import (
17
- APIRouter,
18
- Depends,
19
- FastAPI,
20
- Form,
21
- HTTPException,
22
- Query,
23
- Request,
24
- Response,
25
- UploadFile,
26
- status,
27
- )
28
- from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
29
- from fastapi.security import OAuth2PasswordRequestForm
30
- from starlette.datastructures import URL
31
- from starlette.middleware.cors import CORSMiddleware
32
- from typing_extensions import Annotated
33
- from watchfiles import awatch
34
-
35
- from chainlit.auth import create_jwt, decode_jwt, get_configuration, get_current_user
36
- from chainlit.auth.cookie import (
37
- clear_auth_cookie,
38
- clear_oauth_state_cookie,
39
- set_auth_cookie,
40
- set_oauth_state_cookie,
41
- validate_oauth_state_cookie,
42
- )
15
+ from chainlit.auth import create_jwt, get_configuration, get_current_user
43
16
  from chainlit.config import (
44
17
  APP_ROOT,
45
18
  BACKEND_ROOT,
@@ -48,7 +21,6 @@ from chainlit.config import (
48
21
  PACKAGE_ROOT,
49
22
  config,
50
23
  load_module,
51
- public_dir,
52
24
  reload_config,
53
25
  )
54
26
  from chainlit.data import get_data_layer
@@ -58,16 +30,32 @@ from chainlit.markdown import get_markdown_str
58
30
  from chainlit.oauth_providers import get_oauth_provider
59
31
  from chainlit.secret import random_secret
60
32
  from chainlit.types import (
61
- CallActionRequest,
62
33
  DeleteFeedbackRequest,
63
34
  DeleteThreadRequest,
64
- ElementRequest,
65
35
  GetThreadsRequest,
66
36
  Theme,
67
37
  UpdateFeedbackRequest,
68
- UpdateThreadRequest,
69
38
  )
70
39
  from chainlit.user import PersistedUser, User
40
+ from fastapi import (
41
+ APIRouter,
42
+ Depends,
43
+ FastAPI,
44
+ Form,
45
+ HTTPException,
46
+ Query,
47
+ Request,
48
+ Response,
49
+ UploadFile,
50
+ status,
51
+ )
52
+ from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
53
+ from fastapi.security import OAuth2PasswordRequestForm
54
+ from fastapi.staticfiles import StaticFiles
55
+ from starlette.datastructures import URL
56
+ from starlette.middleware.cors import CORSMiddleware
57
+ from typing_extensions import Annotated
58
+ from watchfiles import awatch
71
59
 
72
60
  from ._utils import is_path_inside
73
61
 
@@ -216,59 +204,29 @@ app.add_middleware(
216
204
 
217
205
  router = APIRouter(prefix=PREFIX)
218
206
 
207
+ app.mount(
208
+ f"{PREFIX}/public",
209
+ StaticFiles(directory="public", check_dir=False),
210
+ name="public",
211
+ )
219
212
 
220
- @router.get("/public/{filename:path}")
221
- async def serve_public_file(
222
- filename: str,
223
- ):
224
- """Serve a file from public dir."""
225
-
226
- base_path = Path(public_dir)
227
- file_path = (base_path / filename).resolve()
228
-
229
- if not is_path_inside(file_path, base_path):
230
- raise HTTPException(status_code=400, detail="Invalid filename")
231
-
232
- if file_path.is_file():
233
- return FileResponse(file_path)
234
- else:
235
- raise HTTPException(status_code=404, detail="File not found")
236
-
237
-
238
- @router.get("/assets/{filename:path}")
239
- async def serve_asset_file(
240
- filename: str,
241
- ):
242
- """Serve a file from assets dir."""
243
-
244
- base_path = Path(os.path.join(build_dir, "assets"))
245
- file_path = (base_path / filename).resolve()
246
-
247
- if not is_path_inside(file_path, base_path):
248
- raise HTTPException(status_code=400, detail="Invalid filename")
249
-
250
- if file_path.is_file():
251
- return FileResponse(file_path)
252
- else:
253
- raise HTTPException(status_code=404, detail="File not found")
254
-
255
-
256
- @router.get("/copilot/{filename:path}")
257
- async def serve_copilot_file(
258
- filename: str,
259
- ):
260
- """Serve a file from assets dir."""
261
-
262
- base_path = Path(copilot_build_dir)
263
- file_path = (base_path / filename).resolve()
264
-
265
- if not is_path_inside(file_path, base_path):
266
- raise HTTPException(status_code=400, detail="Invalid filename")
213
+ app.mount(
214
+ f"{PREFIX}/assets",
215
+ StaticFiles(
216
+ packages=[("chainlit", os.path.join(build_dir, "assets"))],
217
+ follow_symlink=config.project.follow_symlink,
218
+ ),
219
+ name="assets",
220
+ )
267
221
 
268
- if file_path.is_file():
269
- return FileResponse(file_path)
270
- else:
271
- raise HTTPException(status_code=404, detail="File not found")
222
+ app.mount(
223
+ f"{PREFIX}/copilot",
224
+ StaticFiles(
225
+ packages=[("chainlit", copilot_build_dir)],
226
+ follow_symlink=config.project.follow_symlink,
227
+ ),
228
+ name="copilot",
229
+ )
272
230
 
273
231
 
274
232
  # -------------------------------------------------------------------------------
@@ -289,7 +247,6 @@ if os.environ.get("SLACK_BOT_TOKEN") and os.environ.get("SLACK_SIGNING_SECRET"):
289
247
 
290
248
  if os.environ.get("TEAMS_APP_ID") and os.environ.get("TEAMS_APP_PASSWORD"):
291
249
  from botbuilder.schema import Activity
292
-
293
250
  from chainlit.teams.app import adapter, bot
294
251
 
295
252
  @router.post("/teams/events")
@@ -319,16 +276,6 @@ def get_html_template():
319
276
  """
320
277
  Get HTML template for the index view.
321
278
  """
322
- ROOT_PATH = os.environ.get("CHAINLIT_ROOT_PATH", "")
323
-
324
- custom_theme = None
325
- custom_theme_file_path = Path(public_dir) / "theme.json"
326
- if (
327
- is_path_inside(custom_theme_file_path, Path(public_dir))
328
- and custom_theme_file_path.is_file()
329
- ):
330
- custom_theme = json.loads(custom_theme_file_path.read_text(encoding="utf-8"))
331
-
332
279
  PLACEHOLDER = "<!-- TAG INJECTION PLACEHOLDER -->"
333
280
  JS_PLACEHOLDER = "<!-- JS INJECTION PLACEHOLDER -->"
334
281
  CSS_PLACEHOLDER = "<!-- CSS INJECTION PLACEHOLDER -->"
@@ -351,10 +298,7 @@ def get_html_template():
351
298
  <meta property="og:url" content="{url}">
352
299
  <meta property="og:root_path" content="{ROOT_PATH}">"""
353
300
 
354
- js = f"""<script>
355
- {f"window.theme = {json.dumps(custom_theme.get('variables'))};" if custom_theme and custom_theme.get("variables") else "undefined"}
356
- {f"window.transports = {json.dumps(config.project.transports)};" if config.project.transports else "undefined"}
357
- </script>"""
301
+ js = f"""<script>{f"window.theme = {json.dumps(config.ui.theme.to_dict())}; " if config.ui.theme else ""}</script>"""
358
302
 
359
303
  css = None
360
304
  if config.ui.custom_css:
@@ -366,15 +310,12 @@ def get_html_template():
366
310
  js += f"""<script src="{config.ui.custom_js}" defer></script>"""
367
311
 
368
312
  font = None
369
- if custom_theme and custom_theme.get("custom_fonts"):
370
- font = "\n".join(
371
- f"""<link rel="stylesheet" href="{font}">"""
372
- for font in custom_theme.get("custom_fonts")
373
- )
313
+ if config.ui.custom_font:
314
+ font = f"""<link rel="stylesheet" href="{config.ui.custom_font}">"""
374
315
 
375
316
  index_html_file_path = os.path.join(build_dir, "index.html")
376
317
 
377
- with open(index_html_file_path, encoding="utf-8") as f:
318
+ with open(index_html_file_path, "r", encoding="utf-8") as f:
378
319
  content = f.read()
379
320
  content = content.replace(PLACEHOLDER, tags)
380
321
  if js:
@@ -419,132 +360,46 @@ async def auth(request: Request):
419
360
  return get_configuration()
420
361
 
421
362
 
422
- def _get_response_dict(access_token: str) -> dict:
423
- """Get the response dictionary for the auth response."""
424
-
425
- return {"success": True}
426
-
427
-
428
- def _get_auth_response(access_token: str, redirect_to_callback: bool) -> Response:
429
- """Get the redirect params for the OAuth callback."""
430
-
431
- response_dict = _get_response_dict(access_token)
432
-
433
- if redirect_to_callback:
434
- root_path = os.environ.get("CHAINLIT_ROOT_PATH", "")
435
- redirect_url = (
436
- f"{root_path}/login/callback?{urllib.parse.urlencode(response_dict)}"
437
- )
438
-
439
- return RedirectResponse(
440
- # FIXME: redirect to the right frontend base url to improve the dev environment
441
- url=redirect_url,
442
- status_code=302,
363
+ @router.post("/login")
364
+ async def login(form_data: OAuth2PasswordRequestForm = Depends()):
365
+ """
366
+ Login a user using the password auth callback.
367
+ """
368
+ if not config.code.password_auth_callback:
369
+ raise HTTPException(
370
+ status_code=status.HTTP_400_BAD_REQUEST, detail="No auth_callback defined"
443
371
  )
444
372
 
445
- return JSONResponse(response_dict)
446
-
447
-
448
- def _get_oauth_redirect_error(error: str) -> Response:
449
- """Get the redirect response for an OAuth error."""
450
- params = urllib.parse.urlencode(
451
- {
452
- "error": error,
453
- }
454
- )
455
- response = RedirectResponse(
456
- # FIXME: redirect to the right frontend base url to improve the dev environment
457
- url=f"/login?{params}", # Shouldn't there be {root_path} here?
373
+ user = await config.code.password_auth_callback(
374
+ form_data.username, form_data.password
458
375
  )
459
- return response
460
-
461
-
462
- async def _authenticate_user(
463
- user: Optional[User], redirect_to_callback: bool = False
464
- ) -> Response:
465
- """Authenticate a user and return the response."""
466
376
 
467
377
  if not user:
468
378
  raise HTTPException(
469
379
  status_code=status.HTTP_401_UNAUTHORIZED,
470
380
  detail="credentialssignin",
471
381
  )
472
-
473
- # If a data layer is defined, attempt to persist user.
382
+ access_token = create_jwt(user)
474
383
  if data_layer := get_data_layer():
475
384
  try:
476
385
  await data_layer.create_user(user)
477
386
  except Exception as e:
478
- # Catch and log exceptions during user creation.
479
- # TODO: Make this catch only specific errors and allow others to propagate.
480
387
  logger.error(f"Error creating user: {e}")
481
388
 
482
- access_token = create_jwt(user)
483
-
484
- response = _get_auth_response(access_token, redirect_to_callback)
485
-
486
- set_auth_cookie(response, access_token)
487
-
488
- return response
489
-
490
-
491
- @router.post("/login")
492
- async def login(response: Response, form_data: OAuth2PasswordRequestForm = Depends()):
493
- """
494
- Login a user using the password auth callback.
495
- """
496
- if not config.code.password_auth_callback:
497
- raise HTTPException(
498
- status_code=status.HTTP_400_BAD_REQUEST, detail="No auth_callback defined"
499
- )
500
-
501
- user = await config.code.password_auth_callback(
502
- form_data.username, form_data.password
503
- )
504
-
505
- return await _authenticate_user(user)
389
+ return {
390
+ "access_token": access_token,
391
+ "token_type": "bearer",
392
+ }
506
393
 
507
394
 
508
395
  @router.post("/logout")
509
396
  async def logout(request: Request, response: Response):
510
397
  """Logout the user by calling the on_logout callback."""
511
- clear_auth_cookie(response)
512
-
513
398
  if config.code.on_logout:
514
399
  return await config.code.on_logout(request, response)
515
-
516
400
  return {"success": True}
517
401
 
518
402
 
519
- @router.post("/auth/jwt")
520
- async def jwt_auth(request: Request):
521
- """Login a user using a valid jwt."""
522
- from jwt import InvalidTokenError
523
-
524
- auth_header: Optional[str] = request.headers.get("Authorization")
525
- if not auth_header:
526
- raise HTTPException(status_code=401, detail="Authorization header missing")
527
-
528
- # Check if it starts with "Bearer "
529
- try:
530
- scheme, token = auth_header.split()
531
- if scheme.lower() != "bearer":
532
- raise HTTPException(
533
- status_code=401,
534
- detail="Invalid authentication scheme. Please use Bearer",
535
- )
536
- except ValueError:
537
- raise HTTPException(
538
- status_code=401, detail="Invalid authorization header format"
539
- )
540
-
541
- try:
542
- user = decode_jwt(token)
543
- return await _authenticate_user(user)
544
- except InvalidTokenError:
545
- raise HTTPException(status_code=401, detail="Invalid token")
546
-
547
-
548
403
  @router.post("/auth/header")
549
404
  async def header_auth(request: Request):
550
405
  """Login a user using the header_auth_callback."""
@@ -556,7 +411,23 @@ async def header_auth(request: Request):
556
411
 
557
412
  user = await config.code.header_auth_callback(request.headers)
558
413
 
559
- return await _authenticate_user(user)
414
+ if not user:
415
+ raise HTTPException(
416
+ status_code=status.HTTP_401_UNAUTHORIZED,
417
+ detail="Unauthorized",
418
+ )
419
+
420
+ access_token = create_jwt(user)
421
+ if data_layer := get_data_layer():
422
+ try:
423
+ await data_layer.create_user(user)
424
+ except Exception as e:
425
+ logger.error(f"Error creating user: {e}")
426
+
427
+ return {
428
+ "access_token": access_token,
429
+ "token_type": "bearer",
430
+ }
560
431
 
561
432
 
562
433
  @router.get("/auth/oauth/{provider_id}")
@@ -588,9 +459,16 @@ async def oauth_login(provider_id: str, request: Request):
588
459
  response = RedirectResponse(
589
460
  url=f"{provider.authorize_url}?{params}",
590
461
  )
591
-
592
- set_oauth_state_cookie(response, random)
593
-
462
+ samesite: Any = os.environ.get("CHAINLIT_COOKIE_SAMESITE", "lax")
463
+ secure = samesite.lower() == "none"
464
+ response.set_cookie(
465
+ "oauth_state",
466
+ random,
467
+ httponly=True,
468
+ samesite=samesite,
469
+ secure=secure,
470
+ max_age=3 * 60,
471
+ )
594
472
  return response
595
473
 
596
474
 
@@ -618,7 +496,16 @@ async def oauth_callback(
618
496
  )
619
497
 
620
498
  if error:
621
- return _get_oauth_redirect_error(error)
499
+ params = urllib.parse.urlencode(
500
+ {
501
+ "error": error,
502
+ }
503
+ )
504
+ response = RedirectResponse(
505
+ # FIXME: redirect to the right frontend base url to improve the dev environment
506
+ url=f"/login?{params}",
507
+ )
508
+ return response
622
509
 
623
510
  if not code or not state:
624
511
  raise HTTPException(
@@ -626,11 +513,9 @@ async def oauth_callback(
626
513
  detail="Missing code or state",
627
514
  )
628
515
 
629
- try:
630
- validate_oauth_state_cookie(request, state)
631
- except Exception as e:
632
- logger.exception("Unable to validate oauth state: %1", e)
633
-
516
+ # Check the state from the oauth provider against the browser cookie
517
+ oauth_state = request.cookies.get("oauth_state")
518
+ if oauth_state != state:
634
519
  raise HTTPException(
635
520
  status_code=status.HTTP_401_UNAUTHORIZED,
636
521
  detail="Unauthorized",
@@ -645,10 +530,34 @@ async def oauth_callback(
645
530
  provider_id, token, raw_user_data, default_user
646
531
  )
647
532
 
648
- response = await _authenticate_user(user, redirect_to_callback=True)
533
+ if not user:
534
+ raise HTTPException(
535
+ status_code=status.HTTP_401_UNAUTHORIZED,
536
+ detail="Unauthorized",
537
+ )
649
538
 
650
- clear_oauth_state_cookie(response)
539
+ access_token = create_jwt(user)
651
540
 
541
+ if data_layer := get_data_layer():
542
+ try:
543
+ await data_layer.create_user(user)
544
+ except Exception as e:
545
+ logger.error(f"Error creating user: {e}")
546
+
547
+ params = urllib.parse.urlencode(
548
+ {
549
+ "access_token": access_token,
550
+ "token_type": "bearer",
551
+ }
552
+ )
553
+
554
+ root_path = os.environ.get("CHAINLIT_ROOT_PATH", "")
555
+
556
+ response = RedirectResponse(
557
+ # FIXME: redirect to the right frontend base url to improve the dev environment
558
+ url=f"{root_path}/login/callback?{params}",
559
+ )
560
+ response.delete_cookie("oauth_state")
652
561
  return response
653
562
 
654
563
 
@@ -677,7 +586,16 @@ async def oauth_azure_hf_callback(
677
586
  )
678
587
 
679
588
  if error:
680
- return _get_oauth_redirect_error(error)
589
+ params = urllib.parse.urlencode(
590
+ {
591
+ "error": error,
592
+ }
593
+ )
594
+ response = RedirectResponse(
595
+ # FIXME: redirect to the right frontend base url to improve the dev environment
596
+ url=f"/login?{params}",
597
+ )
598
+ return response
681
599
 
682
600
  if not code:
683
601
  raise HTTPException(
@@ -694,24 +612,40 @@ async def oauth_azure_hf_callback(
694
612
  provider_id, token, raw_user_data, default_user, id_token
695
613
  )
696
614
 
697
- response = await _authenticate_user(user, redirect_to_callback=True)
698
-
699
- clear_oauth_state_cookie(response)
615
+ if not user:
616
+ raise HTTPException(
617
+ status_code=status.HTTP_401_UNAUTHORIZED,
618
+ detail="Unauthorized",
619
+ )
700
620
 
701
- return response
621
+ access_token = create_jwt(user)
702
622
 
623
+ if data_layer := get_data_layer():
624
+ try:
625
+ await data_layer.create_user(user)
626
+ except Exception as e:
627
+ logger.error(f"Error creating user: {e}")
703
628
 
704
- GenericUser = Union[User, PersistedUser, None]
705
- UserParam = Annotated[GenericUser, Depends(get_current_user)]
629
+ params = urllib.parse.urlencode(
630
+ {
631
+ "access_token": access_token,
632
+ "token_type": "bearer",
633
+ }
634
+ )
706
635
 
636
+ root_path = os.environ.get("CHAINLIT_ROOT_PATH", "")
707
637
 
708
- @router.get("/user")
709
- async def get_user(current_user: UserParam) -> GenericUser:
710
- return current_user
638
+ response = RedirectResponse(
639
+ # FIXME: redirect to the right frontend base url to improve the dev environment
640
+ url=f"{root_path}/login/callback?{params}",
641
+ status_code=302,
642
+ )
643
+ response.delete_cookie("oauth_state")
644
+ return response
711
645
 
712
646
 
713
647
  _language_pattern = (
714
- "^[a-zA-Z]{2,3}(-[a-zA-Z0-9]{2,3})?(-[a-zA-Z0-9]{2,8})?(-x-[a-zA-Z0-9]{1,8})?$"
648
+ "^[a-zA-Z]{2,3}(-[a-zA-Z]{2,3})?(-[a-zA-Z]{2,8})?(-x-[a-zA-Z0-9]{1,8})?$"
715
649
  )
716
650
 
717
651
 
@@ -735,7 +669,7 @@ async def project_translations(
735
669
 
736
670
  @router.get("/project/settings")
737
671
  async def project_settings(
738
- current_user: UserParam,
672
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
739
673
  language: str = Query(
740
674
  default="en-US", description="Language code", pattern=_language_pattern
741
675
  ),
@@ -786,7 +720,7 @@ async def project_settings(
786
720
  async def update_feedback(
787
721
  request: Request,
788
722
  update: UpdateFeedbackRequest,
789
- current_user: UserParam,
723
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
790
724
  ):
791
725
  """Update the human feedback for a particular message."""
792
726
  data_layer = get_data_layer()
@@ -796,7 +730,7 @@ async def update_feedback(
796
730
  try:
797
731
  feedback_id = await data_layer.upsert_feedback(feedback=update.feedback)
798
732
  except Exception as e:
799
- raise HTTPException(detail=str(e), status_code=500) from e
733
+ raise HTTPException(detail=str(e), status_code=500)
800
734
 
801
735
  return JSONResponse(content={"success": True, "feedbackId": feedback_id})
802
736
 
@@ -805,7 +739,7 @@ async def update_feedback(
805
739
  async def delete_feedback(
806
740
  request: Request,
807
741
  payload: DeleteFeedbackRequest,
808
- current_user: UserParam,
742
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
809
743
  ):
810
744
  """Delete a feedback."""
811
745
 
@@ -824,7 +758,7 @@ async def delete_feedback(
824
758
  async def get_user_threads(
825
759
  request: Request,
826
760
  payload: GetThreadsRequest,
827
- current_user: UserParam,
761
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
828
762
  ):
829
763
  """Get the threads page by page."""
830
764
 
@@ -833,9 +767,6 @@ async def get_user_threads(
833
767
  if not data_layer:
834
768
  raise HTTPException(status_code=400, detail="Data persistence is not enabled")
835
769
 
836
- if not current_user:
837
- raise HTTPException(status_code=401, detail="Unauthorized")
838
-
839
770
  if not isinstance(current_user, PersistedUser):
840
771
  persisted_user = await data_layer.get_user(identifier=current_user.identifier)
841
772
  if not persisted_user:
@@ -852,7 +783,7 @@ async def get_user_threads(
852
783
  async def get_thread(
853
784
  request: Request,
854
785
  thread_id: str,
855
- current_user: UserParam,
786
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
856
787
  ):
857
788
  """Get a specific thread."""
858
789
  data_layer = get_data_layer()
@@ -860,9 +791,6 @@ async def get_thread(
860
791
  if not data_layer:
861
792
  raise HTTPException(status_code=400, detail="Data persistence is not enabled")
862
793
 
863
- if not current_user:
864
- raise HTTPException(status_code=401, detail="Unauthorized")
865
-
866
794
  await is_thread_author(current_user.identifier, thread_id)
867
795
 
868
796
  res = await data_layer.get_thread(thread_id)
@@ -874,7 +802,7 @@ async def get_thread_element(
874
802
  request: Request,
875
803
  thread_id: str,
876
804
  element_id: str,
877
- current_user: UserParam,
805
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
878
806
  ):
879
807
  """Get a specific thread element."""
880
808
  data_layer = get_data_layer()
@@ -882,135 +810,17 @@ async def get_thread_element(
882
810
  if not data_layer:
883
811
  raise HTTPException(status_code=400, detail="Data persistence is not enabled")
884
812
 
885
- if not current_user:
886
- raise HTTPException(status_code=401, detail="Unauthorized")
887
-
888
813
  await is_thread_author(current_user.identifier, thread_id)
889
814
 
890
815
  res = await data_layer.get_element(thread_id, element_id)
891
816
  return JSONResponse(content=res)
892
817
 
893
818
 
894
- @router.put("/project/element")
895
- async def update_thread_element(
896
- payload: ElementRequest,
897
- current_user: UserParam,
898
- ):
899
- """Update a specific thread element."""
900
-
901
- from chainlit.context import init_ws_context
902
- from chainlit.element import CustomElement, ElementDict
903
- from chainlit.session import WebsocketSession
904
-
905
- session = WebsocketSession.get_by_id(payload.sessionId)
906
- context = init_ws_context(session)
907
-
908
- element_dict = cast(ElementDict, payload.element)
909
-
910
- if element_dict["type"] != "custom":
911
- return {"success": False}
912
-
913
- element = CustomElement(
914
- id=element_dict["id"],
915
- object_key=element_dict["objectKey"],
916
- chainlit_key=element_dict["chainlitKey"],
917
- url=element_dict["url"],
918
- for_id=element_dict.get("forId") or "",
919
- thread_id=element_dict.get("threadId") or "",
920
- name=element_dict["name"],
921
- props=element_dict.get("props") or {},
922
- display=element_dict["display"],
923
- )
924
-
925
- if current_user:
926
- if (
927
- not context.session.user
928
- or context.session.user.identifier != current_user.identifier
929
- ):
930
- raise HTTPException(
931
- status_code=401,
932
- detail="You are not authorized to update elements for this session",
933
- )
934
-
935
- await element.send(for_id=element.for_id or "")
936
- return {"success": True}
937
-
938
-
939
- @router.delete("/project/element")
940
- async def delete_thread_element(
941
- payload: ElementRequest,
942
- current_user: UserParam,
943
- ):
944
- """Delete a specific thread element."""
945
-
946
- from chainlit.context import init_ws_context
947
- from chainlit.element import CustomElement, ElementDict
948
- from chainlit.session import WebsocketSession
949
-
950
- session = WebsocketSession.get_by_id(payload.sessionId)
951
- context = init_ws_context(session)
952
-
953
- element_dict = cast(ElementDict, payload.element)
954
-
955
- if element_dict["type"] != "custom":
956
- return {"success": False}
957
-
958
- element = CustomElement(
959
- id=element_dict["id"],
960
- object_key=element_dict["objectKey"],
961
- chainlit_key=element_dict["chainlitKey"],
962
- url=element_dict["url"],
963
- for_id=element_dict.get("forId") or "",
964
- thread_id=element_dict.get("threadId") or "",
965
- name=element_dict["name"],
966
- props=element_dict.get("props") or {},
967
- display=element_dict["display"],
968
- )
969
-
970
- if current_user:
971
- if (
972
- not context.session.user
973
- or context.session.user.identifier != current_user.identifier
974
- ):
975
- raise HTTPException(
976
- status_code=401,
977
- detail="You are not authorized to remove elements for this session",
978
- )
979
-
980
- await element.remove()
981
-
982
- return {"success": True}
983
-
984
-
985
- @router.put("/project/thread")
986
- async def rename_thread(
987
- request: Request,
988
- payload: UpdateThreadRequest,
989
- current_user: UserParam,
990
- ):
991
- """Rename a thread."""
992
-
993
- data_layer = get_data_layer()
994
-
995
- if not data_layer:
996
- raise HTTPException(status_code=400, detail="Data persistence is not enabled")
997
-
998
- if not current_user:
999
- raise HTTPException(status_code=401, detail="Unauthorized")
1000
-
1001
- thread_id = payload.threadId
1002
-
1003
- await is_thread_author(current_user.identifier, thread_id)
1004
-
1005
- await data_layer.update_thread(thread_id, name=payload.name)
1006
- return JSONResponse(content={"success": True})
1007
-
1008
-
1009
819
  @router.delete("/project/thread")
1010
820
  async def delete_thread(
1011
821
  request: Request,
1012
822
  payload: DeleteThreadRequest,
1013
- current_user: UserParam,
823
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
1014
824
  ):
1015
825
  """Delete a thread."""
1016
826
 
@@ -1019,9 +829,6 @@ async def delete_thread(
1019
829
  if not data_layer:
1020
830
  raise HTTPException(status_code=400, detail="Data persistence is not enabled")
1021
831
 
1022
- if not current_user:
1023
- raise HTTPException(status_code=401, detail="Unauthorized")
1024
-
1025
832
  thread_id = payload.threadId
1026
833
 
1027
834
  await is_thread_author(current_user.identifier, thread_id)
@@ -1030,49 +837,13 @@ async def delete_thread(
1030
837
  return JSONResponse(content={"success": True})
1031
838
 
1032
839
 
1033
- @router.post("/project/action")
1034
- async def call_action(
1035
- payload: CallActionRequest,
1036
- current_user: UserParam,
1037
- ):
1038
- """Run an action."""
1039
-
1040
- from chainlit.action import Action
1041
- from chainlit.context import init_ws_context
1042
- from chainlit.session import WebsocketSession
1043
-
1044
- session = WebsocketSession.get_by_id(payload.sessionId)
1045
- context = init_ws_context(session)
1046
-
1047
- action = Action(**payload.action)
1048
-
1049
- if current_user:
1050
- if (
1051
- not context.session.user
1052
- or context.session.user.identifier != current_user.identifier
1053
- ):
1054
- raise HTTPException(
1055
- status_code=401,
1056
- detail="You are not authorized to upload files for this session",
1057
- )
1058
-
1059
- callback = config.code.action_callbacks.get(action.name)
1060
- if callback:
1061
- await callback(action)
1062
- else:
1063
- raise HTTPException(
1064
- status_code=404,
1065
- detail=f"No callback found for action {action.name}",
1066
- )
1067
-
1068
- return JSONResponse(content={"success": True})
1069
-
1070
-
1071
840
  @router.post("/project/file")
1072
841
  async def upload_file(
1073
- current_user: UserParam,
1074
842
  session_id: str,
1075
843
  file: UploadFile,
844
+ current_user: Annotated[
845
+ Union[None, User, PersistedUser], Depends(get_current_user)
846
+ ],
1076
847
  ):
1077
848
  """Upload a file to the session files directory."""
1078
849
 
@@ -1097,111 +868,30 @@ async def upload_file(
1097
868
 
1098
869
  content = await file.read()
1099
870
 
1100
- assert file.filename, "No filename for uploaded file"
1101
- assert file.content_type, "No content type for uploaded file"
1102
-
1103
- try:
1104
- validate_file_upload(file)
1105
- except ValueError as e:
1106
- raise HTTPException(status_code=400, detail=str(e))
1107
-
1108
871
  file_response = await session.persist_file(
1109
872
  name=file.filename, content=content, mime=file.content_type
1110
873
  )
1111
874
 
1112
- return JSONResponse(content=file_response)
1113
-
1114
-
1115
- def validate_file_upload(file: UploadFile):
1116
- """Validate the file upload as configured in config.features.spontaneous_file_upload.
1117
- Args:
1118
- file (UploadFile): The file to validate.
1119
- Raises:
1120
- ValueError: If the file is not allowed.
1121
- """
1122
- if config.features.spontaneous_file_upload is None:
1123
- """Default for a missing config is to allow the fileupload without any restrictions"""
1124
- return
1125
- if config.features.spontaneous_file_upload.enabled is False:
1126
- raise ValueError("File upload is not enabled")
1127
-
1128
- validate_file_mime_type(file)
1129
- validate_file_size(file)
1130
-
1131
-
1132
- def validate_file_mime_type(file: UploadFile):
1133
- """Validate the file mime type as configured in config.features.spontaneous_file_upload.
1134
- Args:
1135
- file (UploadFile): The file to validate.
1136
- Raises:
1137
- ValueError: If the file type is not allowed.
1138
- """
1139
- accept = config.features.spontaneous_file_upload.accept
1140
- if accept is None:
1141
- "Accept is not configured, allowing all file types"
1142
- return
1143
-
1144
- assert (
1145
- isinstance(accept, List) or isinstance(accept, dict)
1146
- ), "Invalid configuration for spontaneous_file_upload, accept must be a list or a dict"
1147
-
1148
- if isinstance(accept, List):
1149
- for pattern in accept:
1150
- if fnmatch.fnmatch(file.content_type, pattern):
1151
- return
1152
- elif isinstance(accept, dict):
1153
- for pattern, extensions in accept.items():
1154
- if fnmatch.fnmatch(file.content_type, pattern):
1155
- if len(extensions) == 0:
1156
- return
1157
- for extension in extensions:
1158
- if file.filename is not None and file.filename.endswith(extension):
1159
- return
1160
- raise ValueError("File type not allowed")
1161
-
1162
-
1163
- def validate_file_size(file: UploadFile):
1164
- """Validate the file size as configured in config.features.spontaneous_file_upload.
1165
- Args:
1166
- file (UploadFile): The file to validate.
1167
- Raises:
1168
- ValueError: If the file size is too large.
1169
- """
1170
- if config.features.spontaneous_file_upload.max_size_mb is None:
1171
- return
1172
-
1173
- if (
1174
- file.size is not None
1175
- and file.size
1176
- > config.features.spontaneous_file_upload.max_size_mb * 1024 * 1024
1177
- ):
1178
- raise ValueError("File size too large")
875
+ return JSONResponse(file_response)
1179
876
 
1180
877
 
1181
878
  @router.get("/project/file/{file_id}")
1182
879
  async def get_file(
1183
880
  file_id: str,
1184
- session_id: str,
1185
- current_user: UserParam,
881
+ session_id: Optional[str] = None,
1186
882
  ):
1187
883
  """Get a file from the session files directory."""
884
+
1188
885
  from chainlit.session import WebsocketSession
1189
886
 
1190
887
  session = WebsocketSession.get_by_id(session_id) if session_id else None
1191
888
 
1192
889
  if not session:
1193
890
  raise HTTPException(
1194
- status_code=401,
1195
- detail="Unauthorized",
891
+ status_code=404,
892
+ detail="Session not found",
1196
893
  )
1197
894
 
1198
- if current_user:
1199
- if not session.user or session.user.identifier != current_user.identifier:
1200
- raise HTTPException(
1201
- status_code=401,
1202
- detail="You are not authorized to download files from this session",
1203
- )
1204
-
1205
895
  if file_id in session.files:
1206
896
  file = session.files[file_id]
1207
897
  return FileResponse(file["path"], media_type=file["type"])
@@ -1212,7 +902,7 @@ async def get_file(
1212
902
  @router.get("/files/{filename:path}")
1213
903
  async def serve_file(
1214
904
  filename: str,
1215
- current_user: UserParam,
905
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
1216
906
  ):
1217
907
  """Serve a file from the local filesystem."""
1218
908
 
@@ -1271,7 +961,7 @@ async def get_logo(theme: Optional[Theme] = Query(Theme.light)):
1271
961
  @router.get("/avatars/{avatar_id:str}")
1272
962
  async def get_avatar(avatar_id: str):
1273
963
  """Get the avatar for the user based on the avatar_id."""
1274
- if not re.match(r"^[a-zA-Z0-9_ -]+$", avatar_id):
964
+ if not re.match(r"^[a-zA-Z0-9_-]+$", avatar_id):
1275
965
  raise HTTPException(status_code=400, detail="Invalid avatar_id")
1276
966
 
1277
967
  if avatar_id == "default":