chainlit 2.0.0__py3-none-any.whl → 2.0.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic; see the registry's advisory for this release for more details.

Files changed (76)
  1. chainlit/__init__.py +57 -55
  2. chainlit/action.py +10 -12
  3. chainlit/{auth/__init__.py → auth.py} +34 -20
  4. chainlit/cache.py +1 -2
  5. chainlit/callbacks.py +7 -52
  6. chainlit/chat_context.py +2 -2
  7. chainlit/chat_settings.py +1 -3
  8. chainlit/cli/__init__.py +1 -14
  9. chainlit/config.py +69 -35
  10. chainlit/context.py +2 -3
  11. chainlit/copilot/dist/index.js +935 -8533
  12. chainlit/data/__init__.py +8 -96
  13. chainlit/data/acl.py +2 -3
  14. chainlit/data/base.py +1 -1
  15. chainlit/data/dynamodb.py +3 -5
  16. chainlit/data/literalai.py +6 -4
  17. chainlit/data/sql_alchemy.py +7 -8
  18. chainlit/data/storage_clients/azure.py +0 -1
  19. chainlit/data/storage_clients/base.py +0 -6
  20. chainlit/data/storage_clients/s3.py +3 -16
  21. chainlit/discord/app.py +1 -2
  22. chainlit/element.py +9 -13
  23. chainlit/emitter.py +21 -17
  24. chainlit/frontend/dist/assets/{DailyMotion-DgRzV5GZ.js → DailyMotion-D1ipkdPJ.js} +1 -1
  25. chainlit/frontend/dist/assets/{Facebook-C0vx6HWv.js → Facebook-d4TLeTik.js} +1 -1
  26. chainlit/frontend/dist/assets/{FilePlayer-CdhzeHPP.js → FilePlayer-BcU7tttX.js} +1 -1
  27. chainlit/frontend/dist/assets/{Kaltura-5iVmeUct.js → Kaltura-DdaRjZrh.js} +1 -1
  28. chainlit/frontend/dist/assets/{Mixcloud-C2zi77Ex.js → Mixcloud-BaJoMsaU.js} +1 -1
  29. chainlit/frontend/dist/assets/{Mux-Vkebogdf.js → Mux-DxPCM5d3.js} +1 -1
  30. chainlit/frontend/dist/assets/{Preview-DwY_sEIl.js → Preview-tUK_Z9pZ.js} +1 -1
  31. chainlit/frontend/dist/assets/{SoundCloud-CREBXAWo.js → SoundCloud-K8-lFZC6.js} +1 -1
  32. chainlit/frontend/dist/assets/{Streamable-B5Lu25uy.js → Streamable-hB-AQ54w.js} +1 -1
  33. chainlit/frontend/dist/assets/{Twitch-y9iKCcM1.js → Twitch-pmuNY0J5.js} +1 -1
  34. chainlit/frontend/dist/assets/{Vidyard-ClYvcuEu.js → Vidyard-BSUm6trV.js} +1 -1
  35. chainlit/frontend/dist/assets/{Vimeo-D6HvM2jt.js → Vimeo-JIPn71zS.js} +1 -1
  36. chainlit/frontend/dist/assets/Wistia-D75KkqOG.js +1 -0
  37. chainlit/frontend/dist/assets/{YouTube-D10tR6CJ.js → YouTube-CPlwqNm_.js} +1 -1
  38. chainlit/frontend/dist/assets/index-CuSbXjG5.js +1091 -0
  39. chainlit/frontend/dist/assets/index-CwmincdQ.css +1 -0
  40. chainlit/frontend/dist/assets/{react-plotly-BpxUS-ab.js → react-plotly-DALmanjC.js} +1 -1
  41. chainlit/frontend/dist/index.html +2 -2
  42. chainlit/haystack/callbacks.py +4 -5
  43. chainlit/input_widget.py +4 -6
  44. chainlit/langchain/callbacks.py +47 -56
  45. chainlit/langflow/__init__.py +0 -1
  46. chainlit/llama_index/callbacks.py +7 -7
  47. chainlit/message.py +7 -6
  48. chainlit/mistralai/__init__.py +2 -3
  49. chainlit/oauth_providers.py +3 -70
  50. chainlit/openai/__init__.py +2 -3
  51. chainlit/secret.py +1 -1
  52. chainlit/server.py +174 -474
  53. chainlit/session.py +5 -7
  54. chainlit/slack/app.py +2 -3
  55. chainlit/socket.py +103 -78
  56. chainlit/step.py +11 -11
  57. chainlit/sync.py +1 -2
  58. chainlit/teams/app.py +0 -1
  59. chainlit/types.py +4 -20
  60. chainlit/user.py +1 -2
  61. chainlit/utils.py +2 -3
  62. {chainlit-2.0.0.dist-info → chainlit-2.0.dev1.dist-info}/METADATA +38 -8
  63. chainlit-2.0.dev1.dist-info/RECORD +99 -0
  64. chainlit/auth/cookie.py +0 -123
  65. chainlit/auth/jwt.py +0 -37
  66. chainlit/data/chainlit_data_layer.py +0 -584
  67. chainlit/data/storage_clients/azure_blob.py +0 -80
  68. chainlit/data/storage_clients/gcs.py +0 -78
  69. chainlit/frontend/dist/assets/Dataframe-DVgwSMU2.js +0 -22
  70. chainlit/frontend/dist/assets/Wistia-Cu4zZ2Ci.js +0 -1
  71. chainlit/frontend/dist/assets/index-CI4qFOt5.js +0 -8665
  72. chainlit/frontend/dist/assets/index-CrrqM0nZ.css +0 -1
  73. chainlit/translations/nl-NL.json +0 -229
  74. chainlit-2.0.0.dist-info/RECORD +0 -106
  75. {chainlit-2.0.0.dist-info → chainlit-2.0.dev1.dist-info}/WHEEL +0 -0
  76. {chainlit-2.0.0.dist-info → chainlit-2.0.dev1.dist-info}/entry_points.txt +0 -0
chainlit/server.py CHANGED
@@ -1,5 +1,4 @@
1
1
  import asyncio
2
- import fnmatch
3
2
  import glob
4
3
  import json
5
4
  import mimetypes
@@ -10,36 +9,10 @@ import urllib.parse
10
9
  import webbrowser
11
10
  from contextlib import asynccontextmanager
12
11
  from pathlib import Path
13
- from typing import List, Optional, Union, cast
12
+ from typing import Any, Optional, Union
14
13
 
15
14
  import socketio
16
- from fastapi import (
17
- APIRouter,
18
- Depends,
19
- FastAPI,
20
- Form,
21
- HTTPException,
22
- Query,
23
- Request,
24
- Response,
25
- UploadFile,
26
- status,
27
- )
28
- from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
29
- from fastapi.security import OAuth2PasswordRequestForm
30
- from starlette.datastructures import URL
31
- from starlette.middleware.cors import CORSMiddleware
32
- from typing_extensions import Annotated
33
- from watchfiles import awatch
34
-
35
- from chainlit.auth import create_jwt, decode_jwt, get_configuration, get_current_user
36
- from chainlit.auth.cookie import (
37
- clear_auth_cookie,
38
- clear_oauth_state_cookie,
39
- set_auth_cookie,
40
- set_oauth_state_cookie,
41
- validate_oauth_state_cookie,
42
- )
15
+ from chainlit.auth import create_jwt, get_configuration, get_current_user
43
16
  from chainlit.config import (
44
17
  APP_ROOT,
45
18
  BACKEND_ROOT,
@@ -48,7 +21,6 @@ from chainlit.config import (
48
21
  PACKAGE_ROOT,
49
22
  config,
50
23
  load_module,
51
- public_dir,
52
24
  reload_config,
53
25
  )
54
26
  from chainlit.data import get_data_layer
@@ -58,16 +30,33 @@ from chainlit.markdown import get_markdown_str
58
30
  from chainlit.oauth_providers import get_oauth_provider
59
31
  from chainlit.secret import random_secret
60
32
  from chainlit.types import (
61
- CallActionRequest,
62
33
  DeleteFeedbackRequest,
63
34
  DeleteThreadRequest,
64
- ElementRequest,
65
35
  GetThreadsRequest,
66
36
  Theme,
67
37
  UpdateFeedbackRequest,
68
- UpdateThreadRequest,
69
38
  )
70
39
  from chainlit.user import PersistedUser, User
40
+ from fastapi import (
41
+ APIRouter,
42
+ Depends,
43
+ FastAPI,
44
+ File,
45
+ Form,
46
+ HTTPException,
47
+ Query,
48
+ Request,
49
+ Response,
50
+ UploadFile,
51
+ status,
52
+ )
53
+ from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
54
+ from fastapi.security import OAuth2PasswordRequestForm
55
+ from fastapi.staticfiles import StaticFiles
56
+ from starlette.datastructures import URL
57
+ from starlette.middleware.cors import CORSMiddleware
58
+ from typing_extensions import Annotated
59
+ from watchfiles import awatch
71
60
 
72
61
  from ._utils import is_path_inside
73
62
 
@@ -216,59 +205,29 @@ app.add_middleware(
216
205
 
217
206
  router = APIRouter(prefix=PREFIX)
218
207
 
208
+ app.mount(
209
+ f"{PREFIX}/public",
210
+ StaticFiles(directory="public", check_dir=False),
211
+ name="public",
212
+ )
219
213
 
220
- @router.get("/public/{filename:path}")
221
- async def serve_public_file(
222
- filename: str,
223
- ):
224
- """Serve a file from public dir."""
225
-
226
- base_path = Path(public_dir)
227
- file_path = (base_path / filename).resolve()
228
-
229
- if not is_path_inside(file_path, base_path):
230
- raise HTTPException(status_code=400, detail="Invalid filename")
231
-
232
- if file_path.is_file():
233
- return FileResponse(file_path)
234
- else:
235
- raise HTTPException(status_code=404, detail="File not found")
236
-
237
-
238
- @router.get("/assets/{filename:path}")
239
- async def serve_asset_file(
240
- filename: str,
241
- ):
242
- """Serve a file from assets dir."""
243
-
244
- base_path = Path(os.path.join(build_dir, "assets"))
245
- file_path = (base_path / filename).resolve()
246
-
247
- if not is_path_inside(file_path, base_path):
248
- raise HTTPException(status_code=400, detail="Invalid filename")
249
-
250
- if file_path.is_file():
251
- return FileResponse(file_path)
252
- else:
253
- raise HTTPException(status_code=404, detail="File not found")
254
-
255
-
256
- @router.get("/copilot/{filename:path}")
257
- async def serve_copilot_file(
258
- filename: str,
259
- ):
260
- """Serve a file from assets dir."""
261
-
262
- base_path = Path(copilot_build_dir)
263
- file_path = (base_path / filename).resolve()
264
-
265
- if not is_path_inside(file_path, base_path):
266
- raise HTTPException(status_code=400, detail="Invalid filename")
214
+ app.mount(
215
+ f"{PREFIX}/assets",
216
+ StaticFiles(
217
+ packages=[("chainlit", os.path.join(build_dir, "assets"))],
218
+ follow_symlink=config.project.follow_symlink,
219
+ ),
220
+ name="assets",
221
+ )
267
222
 
268
- if file_path.is_file():
269
- return FileResponse(file_path)
270
- else:
271
- raise HTTPException(status_code=404, detail="File not found")
223
+ app.mount(
224
+ f"{PREFIX}/copilot",
225
+ StaticFiles(
226
+ packages=[("chainlit", copilot_build_dir)],
227
+ follow_symlink=config.project.follow_symlink,
228
+ ),
229
+ name="copilot",
230
+ )
272
231
 
273
232
 
274
233
  # -------------------------------------------------------------------------------
@@ -289,7 +248,6 @@ if os.environ.get("SLACK_BOT_TOKEN") and os.environ.get("SLACK_SIGNING_SECRET"):
289
248
 
290
249
  if os.environ.get("TEAMS_APP_ID") and os.environ.get("TEAMS_APP_PASSWORD"):
291
250
  from botbuilder.schema import Activity
292
-
293
251
  from chainlit.teams.app import adapter, bot
294
252
 
295
253
  @router.post("/teams/events")
@@ -319,16 +277,6 @@ def get_html_template():
319
277
  """
320
278
  Get HTML template for the index view.
321
279
  """
322
- ROOT_PATH = os.environ.get("CHAINLIT_ROOT_PATH", "")
323
-
324
- custom_theme = None
325
- custom_theme_file_path = Path(public_dir) / "theme.json"
326
- if (
327
- is_path_inside(custom_theme_file_path, Path(public_dir))
328
- and custom_theme_file_path.is_file()
329
- ):
330
- custom_theme = json.loads(custom_theme_file_path.read_text(encoding="utf-8"))
331
-
332
280
  PLACEHOLDER = "<!-- TAG INJECTION PLACEHOLDER -->"
333
281
  JS_PLACEHOLDER = "<!-- JS INJECTION PLACEHOLDER -->"
334
282
  CSS_PLACEHOLDER = "<!-- CSS INJECTION PLACEHOLDER -->"
@@ -351,10 +299,7 @@ def get_html_template():
351
299
  <meta property="og:url" content="{url}">
352
300
  <meta property="og:root_path" content="{ROOT_PATH}">"""
353
301
 
354
- js = f"""<script>
355
- {f"window.theme = {json.dumps(custom_theme.get('variables'))};" if custom_theme and custom_theme.get("variables") else "undefined"}
356
- {f"window.transports = {json.dumps(config.project.transports)};" if config.project.transports else "undefined"}
357
- </script>"""
302
+ js = f"""<script>{f"window.theme = {json.dumps(config.ui.theme.to_dict())}; " if config.ui.theme else ""}</script>"""
358
303
 
359
304
  css = None
360
305
  if config.ui.custom_css:
@@ -366,15 +311,12 @@ def get_html_template():
366
311
  js += f"""<script src="{config.ui.custom_js}" defer></script>"""
367
312
 
368
313
  font = None
369
- if custom_theme and custom_theme.get("custom_fonts"):
370
- font = "\n".join(
371
- f"""<link rel="stylesheet" href="{font}">"""
372
- for font in custom_theme.get("custom_fonts")
373
- )
314
+ if config.ui.custom_font:
315
+ font = f"""<link rel="stylesheet" href="{config.ui.custom_font}">"""
374
316
 
375
317
  index_html_file_path = os.path.join(build_dir, "index.html")
376
318
 
377
- with open(index_html_file_path, encoding="utf-8") as f:
319
+ with open(index_html_file_path, "r", encoding="utf-8") as f:
378
320
  content = f.read()
379
321
  content = content.replace(PLACEHOLDER, tags)
380
322
  if js:
@@ -419,132 +361,46 @@ async def auth(request: Request):
419
361
  return get_configuration()
420
362
 
421
363
 
422
- def _get_response_dict(access_token: str) -> dict:
423
- """Get the response dictionary for the auth response."""
424
-
425
- return {"success": True}
426
-
427
-
428
- def _get_auth_response(access_token: str, redirect_to_callback: bool) -> Response:
429
- """Get the redirect params for the OAuth callback."""
430
-
431
- response_dict = _get_response_dict(access_token)
432
-
433
- if redirect_to_callback:
434
- root_path = os.environ.get("CHAINLIT_ROOT_PATH", "")
435
- redirect_url = (
436
- f"{root_path}/login/callback?{urllib.parse.urlencode(response_dict)}"
437
- )
438
-
439
- return RedirectResponse(
440
- # FIXME: redirect to the right frontend base url to improve the dev environment
441
- url=redirect_url,
442
- status_code=302,
364
+ @router.post("/login")
365
+ async def login(form_data: OAuth2PasswordRequestForm = Depends()):
366
+ """
367
+ Login a user using the password auth callback.
368
+ """
369
+ if not config.code.password_auth_callback:
370
+ raise HTTPException(
371
+ status_code=status.HTTP_400_BAD_REQUEST, detail="No auth_callback defined"
443
372
  )
444
373
 
445
- return JSONResponse(response_dict)
446
-
447
-
448
- def _get_oauth_redirect_error(error: str) -> Response:
449
- """Get the redirect response for an OAuth error."""
450
- params = urllib.parse.urlencode(
451
- {
452
- "error": error,
453
- }
454
- )
455
- response = RedirectResponse(
456
- # FIXME: redirect to the right frontend base url to improve the dev environment
457
- url=f"/login?{params}", # Shouldn't there be {root_path} here?
374
+ user = await config.code.password_auth_callback(
375
+ form_data.username, form_data.password
458
376
  )
459
- return response
460
-
461
-
462
- async def _authenticate_user(
463
- user: Optional[User], redirect_to_callback: bool = False
464
- ) -> Response:
465
- """Authenticate a user and return the response."""
466
377
 
467
378
  if not user:
468
379
  raise HTTPException(
469
380
  status_code=status.HTTP_401_UNAUTHORIZED,
470
381
  detail="credentialssignin",
471
382
  )
472
-
473
- # If a data layer is defined, attempt to persist user.
383
+ access_token = create_jwt(user)
474
384
  if data_layer := get_data_layer():
475
385
  try:
476
386
  await data_layer.create_user(user)
477
387
  except Exception as e:
478
- # Catch and log exceptions during user creation.
479
- # TODO: Make this catch only specific errors and allow others to propagate.
480
388
  logger.error(f"Error creating user: {e}")
481
389
 
482
- access_token = create_jwt(user)
483
-
484
- response = _get_auth_response(access_token, redirect_to_callback)
485
-
486
- set_auth_cookie(response, access_token)
487
-
488
- return response
489
-
490
-
491
- @router.post("/login")
492
- async def login(response: Response, form_data: OAuth2PasswordRequestForm = Depends()):
493
- """
494
- Login a user using the password auth callback.
495
- """
496
- if not config.code.password_auth_callback:
497
- raise HTTPException(
498
- status_code=status.HTTP_400_BAD_REQUEST, detail="No auth_callback defined"
499
- )
500
-
501
- user = await config.code.password_auth_callback(
502
- form_data.username, form_data.password
503
- )
504
-
505
- return await _authenticate_user(user)
390
+ return {
391
+ "access_token": access_token,
392
+ "token_type": "bearer",
393
+ }
506
394
 
507
395
 
508
396
  @router.post("/logout")
509
397
  async def logout(request: Request, response: Response):
510
398
  """Logout the user by calling the on_logout callback."""
511
- clear_auth_cookie(response)
512
-
513
399
  if config.code.on_logout:
514
400
  return await config.code.on_logout(request, response)
515
-
516
401
  return {"success": True}
517
402
 
518
403
 
519
- @router.post("/auth/jwt")
520
- async def jwt_auth(request: Request):
521
- """Login a user using a valid jwt."""
522
- from jwt import InvalidTokenError
523
-
524
- auth_header: Optional[str] = request.headers.get("Authorization")
525
- if not auth_header:
526
- raise HTTPException(status_code=401, detail="Authorization header missing")
527
-
528
- # Check if it starts with "Bearer "
529
- try:
530
- scheme, token = auth_header.split()
531
- if scheme.lower() != "bearer":
532
- raise HTTPException(
533
- status_code=401,
534
- detail="Invalid authentication scheme. Please use Bearer",
535
- )
536
- except ValueError:
537
- raise HTTPException(
538
- status_code=401, detail="Invalid authorization header format"
539
- )
540
-
541
- try:
542
- user = decode_jwt(token)
543
- return await _authenticate_user(user)
544
- except InvalidTokenError:
545
- raise HTTPException(status_code=401, detail="Invalid token")
546
-
547
-
548
404
  @router.post("/auth/header")
549
405
  async def header_auth(request: Request):
550
406
  """Login a user using the header_auth_callback."""
@@ -556,7 +412,23 @@ async def header_auth(request: Request):
556
412
 
557
413
  user = await config.code.header_auth_callback(request.headers)
558
414
 
559
- return await _authenticate_user(user)
415
+ if not user:
416
+ raise HTTPException(
417
+ status_code=status.HTTP_401_UNAUTHORIZED,
418
+ detail="Unauthorized",
419
+ )
420
+
421
+ access_token = create_jwt(user)
422
+ if data_layer := get_data_layer():
423
+ try:
424
+ await data_layer.create_user(user)
425
+ except Exception as e:
426
+ logger.error(f"Error creating user: {e}")
427
+
428
+ return {
429
+ "access_token": access_token,
430
+ "token_type": "bearer",
431
+ }
560
432
 
561
433
 
562
434
  @router.get("/auth/oauth/{provider_id}")
@@ -588,9 +460,16 @@ async def oauth_login(provider_id: str, request: Request):
588
460
  response = RedirectResponse(
589
461
  url=f"{provider.authorize_url}?{params}",
590
462
  )
591
-
592
- set_oauth_state_cookie(response, random)
593
-
463
+ samesite: Any = os.environ.get("CHAINLIT_COOKIE_SAMESITE", "lax")
464
+ secure = samesite.lower() == "none"
465
+ response.set_cookie(
466
+ "oauth_state",
467
+ random,
468
+ httponly=True,
469
+ samesite=samesite,
470
+ secure=secure,
471
+ max_age=3 * 60,
472
+ )
594
473
  return response
595
474
 
596
475
 
@@ -618,7 +497,16 @@ async def oauth_callback(
618
497
  )
619
498
 
620
499
  if error:
621
- return _get_oauth_redirect_error(error)
500
+ params = urllib.parse.urlencode(
501
+ {
502
+ "error": error,
503
+ }
504
+ )
505
+ response = RedirectResponse(
506
+ # FIXME: redirect to the right frontend base url to improve the dev environment
507
+ url=f"/login?{params}",
508
+ )
509
+ return response
622
510
 
623
511
  if not code or not state:
624
512
  raise HTTPException(
@@ -626,11 +514,9 @@ async def oauth_callback(
626
514
  detail="Missing code or state",
627
515
  )
628
516
 
629
- try:
630
- validate_oauth_state_cookie(request, state)
631
- except Exception as e:
632
- logger.exception("Unable to validate oauth state: %1", e)
633
-
517
+ # Check the state from the oauth provider against the browser cookie
518
+ oauth_state = request.cookies.get("oauth_state")
519
+ if oauth_state != state:
634
520
  raise HTTPException(
635
521
  status_code=status.HTTP_401_UNAUTHORIZED,
636
522
  detail="Unauthorized",
@@ -645,10 +531,34 @@ async def oauth_callback(
645
531
  provider_id, token, raw_user_data, default_user
646
532
  )
647
533
 
648
- response = await _authenticate_user(user, redirect_to_callback=True)
534
+ if not user:
535
+ raise HTTPException(
536
+ status_code=status.HTTP_401_UNAUTHORIZED,
537
+ detail="Unauthorized",
538
+ )
649
539
 
650
- clear_oauth_state_cookie(response)
540
+ access_token = create_jwt(user)
651
541
 
542
+ if data_layer := get_data_layer():
543
+ try:
544
+ await data_layer.create_user(user)
545
+ except Exception as e:
546
+ logger.error(f"Error creating user: {e}")
547
+
548
+ params = urllib.parse.urlencode(
549
+ {
550
+ "access_token": access_token,
551
+ "token_type": "bearer",
552
+ }
553
+ )
554
+
555
+ root_path = os.environ.get("CHAINLIT_ROOT_PATH", "")
556
+
557
+ response = RedirectResponse(
558
+ # FIXME: redirect to the right frontend base url to improve the dev environment
559
+ url=f"{root_path}/login/callback?{params}",
560
+ )
561
+ response.delete_cookie("oauth_state")
652
562
  return response
653
563
 
654
564
 
@@ -677,7 +587,16 @@ async def oauth_azure_hf_callback(
677
587
  )
678
588
 
679
589
  if error:
680
- return _get_oauth_redirect_error(error)
590
+ params = urllib.parse.urlencode(
591
+ {
592
+ "error": error,
593
+ }
594
+ )
595
+ response = RedirectResponse(
596
+ # FIXME: redirect to the right frontend base url to improve the dev environment
597
+ url=f"/login?{params}",
598
+ )
599
+ return response
681
600
 
682
601
  if not code:
683
602
  raise HTTPException(
@@ -694,20 +613,36 @@ async def oauth_azure_hf_callback(
694
613
  provider_id, token, raw_user_data, default_user, id_token
695
614
  )
696
615
 
697
- response = await _authenticate_user(user, redirect_to_callback=True)
698
-
699
- clear_oauth_state_cookie(response)
616
+ if not user:
617
+ raise HTTPException(
618
+ status_code=status.HTTP_401_UNAUTHORIZED,
619
+ detail="Unauthorized",
620
+ )
700
621
 
701
- return response
622
+ access_token = create_jwt(user)
702
623
 
624
+ if data_layer := get_data_layer():
625
+ try:
626
+ await data_layer.create_user(user)
627
+ except Exception as e:
628
+ logger.error(f"Error creating user: {e}")
703
629
 
704
- GenericUser = Union[User, PersistedUser, None]
705
- UserParam = Annotated[GenericUser, Depends(get_current_user)]
630
+ params = urllib.parse.urlencode(
631
+ {
632
+ "access_token": access_token,
633
+ "token_type": "bearer",
634
+ }
635
+ )
706
636
 
637
+ root_path = os.environ.get("CHAINLIT_ROOT_PATH", "")
707
638
 
708
- @router.get("/user")
709
- async def get_user(current_user: UserParam) -> GenericUser:
710
- return current_user
639
+ response = RedirectResponse(
640
+ # FIXME: redirect to the right frontend base url to improve the dev environment
641
+ url=f"{root_path}/login/callback?{params}",
642
+ status_code=302,
643
+ )
644
+ response.delete_cookie("oauth_state")
645
+ return response
711
646
 
712
647
 
713
648
  _language_pattern = (
@@ -735,7 +670,7 @@ async def project_translations(
735
670
 
736
671
  @router.get("/project/settings")
737
672
  async def project_settings(
738
- current_user: UserParam,
673
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
739
674
  language: str = Query(
740
675
  default="en-US", description="Language code", pattern=_language_pattern
741
676
  ),
@@ -786,7 +721,7 @@ async def project_settings(
786
721
  async def update_feedback(
787
722
  request: Request,
788
723
  update: UpdateFeedbackRequest,
789
- current_user: UserParam,
724
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
790
725
  ):
791
726
  """Update the human feedback for a particular message."""
792
727
  data_layer = get_data_layer()
@@ -805,7 +740,7 @@ async def update_feedback(
805
740
  async def delete_feedback(
806
741
  request: Request,
807
742
  payload: DeleteFeedbackRequest,
808
- current_user: UserParam,
743
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
809
744
  ):
810
745
  """Delete a feedback."""
811
746
 
@@ -824,7 +759,7 @@ async def delete_feedback(
824
759
  async def get_user_threads(
825
760
  request: Request,
826
761
  payload: GetThreadsRequest,
827
- current_user: UserParam,
762
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
828
763
  ):
829
764
  """Get the threads page by page."""
830
765
 
@@ -833,9 +768,6 @@ async def get_user_threads(
833
768
  if not data_layer:
834
769
  raise HTTPException(status_code=400, detail="Data persistence is not enabled")
835
770
 
836
- if not current_user:
837
- raise HTTPException(status_code=401, detail="Unauthorized")
838
-
839
771
  if not isinstance(current_user, PersistedUser):
840
772
  persisted_user = await data_layer.get_user(identifier=current_user.identifier)
841
773
  if not persisted_user:
@@ -852,7 +784,7 @@ async def get_user_threads(
852
784
  async def get_thread(
853
785
  request: Request,
854
786
  thread_id: str,
855
- current_user: UserParam,
787
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
856
788
  ):
857
789
  """Get a specific thread."""
858
790
  data_layer = get_data_layer()
@@ -860,9 +792,6 @@ async def get_thread(
860
792
  if not data_layer:
861
793
  raise HTTPException(status_code=400, detail="Data persistence is not enabled")
862
794
 
863
- if not current_user:
864
- raise HTTPException(status_code=401, detail="Unauthorized")
865
-
866
795
  await is_thread_author(current_user.identifier, thread_id)
867
796
 
868
797
  res = await data_layer.get_thread(thread_id)
@@ -874,7 +803,7 @@ async def get_thread_element(
874
803
  request: Request,
875
804
  thread_id: str,
876
805
  element_id: str,
877
- current_user: UserParam,
806
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
878
807
  ):
879
808
  """Get a specific thread element."""
880
809
  data_layer = get_data_layer()
@@ -882,135 +811,17 @@ async def get_thread_element(
882
811
  if not data_layer:
883
812
  raise HTTPException(status_code=400, detail="Data persistence is not enabled")
884
813
 
885
- if not current_user:
886
- raise HTTPException(status_code=401, detail="Unauthorized")
887
-
888
814
  await is_thread_author(current_user.identifier, thread_id)
889
815
 
890
816
  res = await data_layer.get_element(thread_id, element_id)
891
817
  return JSONResponse(content=res)
892
818
 
893
819
 
894
- @router.put("/project/element")
895
- async def update_thread_element(
896
- payload: ElementRequest,
897
- current_user: UserParam,
898
- ):
899
- """Update a specific thread element."""
900
-
901
- from chainlit.context import init_ws_context
902
- from chainlit.element import CustomElement, ElementDict
903
- from chainlit.session import WebsocketSession
904
-
905
- session = WebsocketSession.get_by_id(payload.sessionId)
906
- context = init_ws_context(session)
907
-
908
- element_dict = cast(ElementDict, payload.element)
909
-
910
- if element_dict["type"] != "custom":
911
- return {"success": False}
912
-
913
- element = CustomElement(
914
- id=element_dict["id"],
915
- object_key=element_dict["objectKey"],
916
- chainlit_key=element_dict["chainlitKey"],
917
- url=element_dict["url"],
918
- for_id=element_dict.get("forId") or "",
919
- thread_id=element_dict.get("threadId") or "",
920
- name=element_dict["name"],
921
- props=element_dict.get("props") or {},
922
- display=element_dict["display"],
923
- )
924
-
925
- if current_user:
926
- if (
927
- not context.session.user
928
- or context.session.user.identifier != current_user.identifier
929
- ):
930
- raise HTTPException(
931
- status_code=401,
932
- detail="You are not authorized to update elements for this session",
933
- )
934
-
935
- await element.send(for_id=element.for_id or "")
936
- return {"success": True}
937
-
938
-
939
- @router.delete("/project/element")
940
- async def delete_thread_element(
941
- payload: ElementRequest,
942
- current_user: UserParam,
943
- ):
944
- """Delete a specific thread element."""
945
-
946
- from chainlit.context import init_ws_context
947
- from chainlit.element import CustomElement, ElementDict
948
- from chainlit.session import WebsocketSession
949
-
950
- session = WebsocketSession.get_by_id(payload.sessionId)
951
- context = init_ws_context(session)
952
-
953
- element_dict = cast(ElementDict, payload.element)
954
-
955
- if element_dict["type"] != "custom":
956
- return {"success": False}
957
-
958
- element = CustomElement(
959
- id=element_dict["id"],
960
- object_key=element_dict["objectKey"],
961
- chainlit_key=element_dict["chainlitKey"],
962
- url=element_dict["url"],
963
- for_id=element_dict.get("forId") or "",
964
- thread_id=element_dict.get("threadId") or "",
965
- name=element_dict["name"],
966
- props=element_dict.get("props") or {},
967
- display=element_dict["display"],
968
- )
969
-
970
- if current_user:
971
- if (
972
- not context.session.user
973
- or context.session.user.identifier != current_user.identifier
974
- ):
975
- raise HTTPException(
976
- status_code=401,
977
- detail="You are not authorized to remove elements for this session",
978
- )
979
-
980
- await element.remove()
981
-
982
- return {"success": True}
983
-
984
-
985
- @router.put("/project/thread")
986
- async def rename_thread(
987
- request: Request,
988
- payload: UpdateThreadRequest,
989
- current_user: UserParam,
990
- ):
991
- """Rename a thread."""
992
-
993
- data_layer = get_data_layer()
994
-
995
- if not data_layer:
996
- raise HTTPException(status_code=400, detail="Data persistence is not enabled")
997
-
998
- if not current_user:
999
- raise HTTPException(status_code=401, detail="Unauthorized")
1000
-
1001
- thread_id = payload.threadId
1002
-
1003
- await is_thread_author(current_user.identifier, thread_id)
1004
-
1005
- await data_layer.update_thread(thread_id, name=payload.name)
1006
- return JSONResponse(content={"success": True})
1007
-
1008
-
1009
820
  @router.delete("/project/thread")
1010
821
  async def delete_thread(
1011
822
  request: Request,
1012
823
  payload: DeleteThreadRequest,
1013
- current_user: UserParam,
824
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
1014
825
  ):
1015
826
  """Delete a thread."""
1016
827
 
@@ -1019,9 +830,6 @@ async def delete_thread(
1019
830
  if not data_layer:
1020
831
  raise HTTPException(status_code=400, detail="Data persistence is not enabled")
1021
832
 
1022
- if not current_user:
1023
- raise HTTPException(status_code=401, detail="Unauthorized")
1024
-
1025
833
  thread_id = payload.threadId
1026
834
 
1027
835
  await is_thread_author(current_user.identifier, thread_id)
@@ -1030,47 +838,9 @@ async def delete_thread(
1030
838
  return JSONResponse(content={"success": True})
1031
839
 
1032
840
 
1033
- @router.post("/project/action")
1034
- async def call_action(
1035
- payload: CallActionRequest,
1036
- current_user: UserParam,
1037
- ):
1038
- """Run an action."""
1039
-
1040
- from chainlit.action import Action
1041
- from chainlit.context import init_ws_context
1042
- from chainlit.session import WebsocketSession
1043
-
1044
- session = WebsocketSession.get_by_id(payload.sessionId)
1045
- context = init_ws_context(session)
1046
-
1047
- action = Action(**payload.action)
1048
-
1049
- if current_user:
1050
- if (
1051
- not context.session.user
1052
- or context.session.user.identifier != current_user.identifier
1053
- ):
1054
- raise HTTPException(
1055
- status_code=401,
1056
- detail="You are not authorized to upload files for this session",
1057
- )
1058
-
1059
- callback = config.code.action_callbacks.get(action.name)
1060
- if callback:
1061
- await callback(action)
1062
- else:
1063
- raise HTTPException(
1064
- status_code=404,
1065
- detail=f"No callback found for action {action.name}",
1066
- )
1067
-
1068
- return JSONResponse(content={"success": True})
1069
-
1070
-
1071
841
  @router.post("/project/file")
1072
842
  async def upload_file(
1073
- current_user: UserParam,
843
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
1074
844
  session_id: str,
1075
845
  file: UploadFile,
1076
846
  ):
@@ -1100,11 +870,6 @@ async def upload_file(
1100
870
  assert file.filename, "No filename for uploaded file"
1101
871
  assert file.content_type, "No content type for uploaded file"
1102
872
 
1103
- try:
1104
- validate_file_upload(file)
1105
- except ValueError as e:
1106
- raise HTTPException(status_code=400, detail=str(e))
1107
-
1108
873
  file_response = await session.persist_file(
1109
874
  name=file.filename, content=content, mime=file.content_type
1110
875
  )
@@ -1112,79 +877,14 @@ async def upload_file(
1112
877
  return JSONResponse(content=file_response)
1113
878
 
1114
879
 
1115
- def validate_file_upload(file: UploadFile):
1116
- """Validate the file upload as configured in config.features.spontaneous_file_upload.
1117
- Args:
1118
- file (UploadFile): The file to validate.
1119
- Raises:
1120
- ValueError: If the file is not allowed.
1121
- """
1122
- if config.features.spontaneous_file_upload is None:
1123
- """Default for a missing config is to allow the fileupload without any restrictions"""
1124
- return
1125
- if config.features.spontaneous_file_upload.enabled is False:
1126
- raise ValueError("File upload is not enabled")
1127
-
1128
- validate_file_mime_type(file)
1129
- validate_file_size(file)
1130
-
1131
-
1132
- def validate_file_mime_type(file: UploadFile):
1133
- """Validate the file mime type as configured in config.features.spontaneous_file_upload.
1134
- Args:
1135
- file (UploadFile): The file to validate.
1136
- Raises:
1137
- ValueError: If the file type is not allowed.
1138
- """
1139
- accept = config.features.spontaneous_file_upload.accept
1140
- if accept is None:
1141
- "Accept is not configured, allowing all file types"
1142
- return
1143
-
1144
- assert (
1145
- isinstance(accept, List) or isinstance(accept, dict)
1146
- ), "Invalid configuration for spontaneous_file_upload, accept must be a list or a dict"
1147
-
1148
- if isinstance(accept, List):
1149
- for pattern in accept:
1150
- if fnmatch.fnmatch(file.content_type, pattern):
1151
- return
1152
- elif isinstance(accept, dict):
1153
- for pattern, extensions in accept.items():
1154
- if fnmatch.fnmatch(file.content_type, pattern):
1155
- if len(extensions) == 0:
1156
- return
1157
- for extension in extensions:
1158
- if file.filename is not None and file.filename.endswith(extension):
1159
- return
1160
- raise ValueError("File type not allowed")
1161
-
1162
-
1163
- def validate_file_size(file: UploadFile):
1164
- """Validate the file size as configured in config.features.spontaneous_file_upload.
1165
- Args:
1166
- file (UploadFile): The file to validate.
1167
- Raises:
1168
- ValueError: If the file size is too large.
1169
- """
1170
- if config.features.spontaneous_file_upload.max_size_mb is None:
1171
- return
1172
-
1173
- if (
1174
- file.size is not None
1175
- and file.size
1176
- > config.features.spontaneous_file_upload.max_size_mb * 1024 * 1024
1177
- ):
1178
- raise ValueError("File size too large")
1179
-
1180
-
1181
880
  @router.get("/project/file/{file_id}")
1182
881
  async def get_file(
1183
882
  file_id: str,
1184
883
  session_id: str,
1185
- current_user: UserParam,
884
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
1186
885
  ):
1187
886
  """Get a file from the session files directory."""
887
+
1188
888
  from chainlit.session import WebsocketSession
1189
889
 
1190
890
  session = WebsocketSession.get_by_id(session_id) if session_id else None
@@ -1212,7 +912,7 @@ async def get_file(
1212
912
  @router.get("/files/{filename:path}")
1213
913
  async def serve_file(
1214
914
  filename: str,
1215
- current_user: UserParam,
915
+ current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
1216
916
  ):
1217
917
  """Serve a file from the local filesystem."""
1218
918