chainlit 1.0.401__py3-none-any.whl → 2.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic; consult the registry's advisory for this release for more details.

Files changed (113)
  1. chainlit/__init__.py +98 -279
  2. chainlit/_utils.py +8 -0
  3. chainlit/action.py +12 -10
  4. chainlit/{auth.py → auth/__init__.py} +28 -36
  5. chainlit/auth/cookie.py +123 -0
  6. chainlit/auth/jwt.py +39 -0
  7. chainlit/cache.py +4 -6
  8. chainlit/callbacks.py +362 -0
  9. chainlit/chat_context.py +64 -0
  10. chainlit/chat_settings.py +3 -1
  11. chainlit/cli/__init__.py +77 -8
  12. chainlit/config.py +191 -102
  13. chainlit/context.py +42 -13
  14. chainlit/copilot/dist/index.js +8750 -903
  15. chainlit/data/__init__.py +101 -416
  16. chainlit/data/acl.py +6 -2
  17. chainlit/data/base.py +107 -0
  18. chainlit/data/chainlit_data_layer.py +614 -0
  19. chainlit/data/dynamodb.py +590 -0
  20. chainlit/data/literalai.py +500 -0
  21. chainlit/data/sql_alchemy.py +721 -0
  22. chainlit/data/storage_clients/__init__.py +0 -0
  23. chainlit/data/storage_clients/azure.py +81 -0
  24. chainlit/data/storage_clients/azure_blob.py +89 -0
  25. chainlit/data/storage_clients/base.py +26 -0
  26. chainlit/data/storage_clients/gcs.py +88 -0
  27. chainlit/data/storage_clients/s3.py +75 -0
  28. chainlit/data/utils.py +29 -0
  29. chainlit/discord/__init__.py +6 -0
  30. chainlit/discord/app.py +354 -0
  31. chainlit/element.py +91 -33
  32. chainlit/emitter.py +81 -29
  33. chainlit/frontend/dist/assets/DailyMotion-Ce9dQoqZ.js +1 -0
  34. chainlit/frontend/dist/assets/Dataframe-C1XonMcV.js +22 -0
  35. chainlit/frontend/dist/assets/Facebook-DVVt6lrr.js +1 -0
  36. chainlit/frontend/dist/assets/FilePlayer-c7stW4vz.js +1 -0
  37. chainlit/frontend/dist/assets/Kaltura-BmMmgorA.js +1 -0
  38. chainlit/frontend/dist/assets/Mixcloud-Cw8hDmiO.js +1 -0
  39. chainlit/frontend/dist/assets/Mux-DiRZfeUf.js +1 -0
  40. chainlit/frontend/dist/assets/Preview-6Jt2mRHx.js +1 -0
  41. chainlit/frontend/dist/assets/SoundCloud-DKwcT58_.js +1 -0
  42. chainlit/frontend/dist/assets/Streamable-BVdxrEeX.js +1 -0
  43. chainlit/frontend/dist/assets/Twitch-DFqZR7Gu.js +1 -0
  44. chainlit/frontend/dist/assets/Vidyard-0BQAAtVk.js +1 -0
  45. chainlit/frontend/dist/assets/Vimeo-CRFSH0Vu.js +1 -0
  46. chainlit/frontend/dist/assets/Wistia-CKrmdQaG.js +1 -0
  47. chainlit/frontend/dist/assets/YouTube-CQpL-rvU.js +1 -0
  48. chainlit/frontend/dist/assets/index-DQmLRKyv.css +1 -0
  49. chainlit/frontend/dist/assets/index-QdmxtIMQ.js +8665 -0
  50. chainlit/frontend/dist/assets/react-plotly-B9hvVpUG.js +3484 -0
  51. chainlit/frontend/dist/index.html +2 -4
  52. chainlit/haystack/callbacks.py +4 -7
  53. chainlit/input_widget.py +8 -4
  54. chainlit/langchain/callbacks.py +103 -68
  55. chainlit/langflow/__init__.py +1 -0
  56. chainlit/llama_index/callbacks.py +65 -40
  57. chainlit/markdown.py +22 -6
  58. chainlit/message.py +54 -56
  59. chainlit/mistralai/__init__.py +50 -0
  60. chainlit/oauth_providers.py +266 -8
  61. chainlit/openai/__init__.py +10 -18
  62. chainlit/secret.py +1 -1
  63. chainlit/server.py +789 -228
  64. chainlit/session.py +108 -90
  65. chainlit/slack/__init__.py +6 -0
  66. chainlit/slack/app.py +397 -0
  67. chainlit/socket.py +199 -116
  68. chainlit/step.py +141 -89
  69. chainlit/sync.py +2 -1
  70. chainlit/teams/__init__.py +6 -0
  71. chainlit/teams/app.py +338 -0
  72. chainlit/translations/bn.json +244 -0
  73. chainlit/translations/en-US.json +122 -8
  74. chainlit/translations/gu.json +244 -0
  75. chainlit/translations/he-IL.json +244 -0
  76. chainlit/translations/hi.json +244 -0
  77. chainlit/translations/ja.json +242 -0
  78. chainlit/translations/kn.json +244 -0
  79. chainlit/translations/ml.json +244 -0
  80. chainlit/translations/mr.json +244 -0
  81. chainlit/translations/nl-NL.json +242 -0
  82. chainlit/translations/ta.json +244 -0
  83. chainlit/translations/te.json +244 -0
  84. chainlit/translations/zh-CN.json +243 -0
  85. chainlit/translations.py +60 -0
  86. chainlit/types.py +133 -28
  87. chainlit/user.py +14 -3
  88. chainlit/user_session.py +6 -3
  89. chainlit/utils.py +52 -5
  90. chainlit/version.py +3 -2
  91. {chainlit-1.0.401.dist-info → chainlit-2.0.4.dist-info}/METADATA +48 -50
  92. chainlit-2.0.4.dist-info/RECORD +107 -0
  93. chainlit/cli/utils.py +0 -24
  94. chainlit/frontend/dist/assets/index-9711593e.js +0 -723
  95. chainlit/frontend/dist/assets/index-d088547c.css +0 -1
  96. chainlit/frontend/dist/assets/react-plotly-d8762cc2.js +0 -3602
  97. chainlit/playground/__init__.py +0 -2
  98. chainlit/playground/config.py +0 -40
  99. chainlit/playground/provider.py +0 -108
  100. chainlit/playground/providers/__init__.py +0 -13
  101. chainlit/playground/providers/anthropic.py +0 -118
  102. chainlit/playground/providers/huggingface.py +0 -75
  103. chainlit/playground/providers/langchain.py +0 -89
  104. chainlit/playground/providers/openai.py +0 -408
  105. chainlit/playground/providers/vertexai.py +0 -171
  106. chainlit/translations/pt-BR.json +0 -155
  107. chainlit-1.0.401.dist-info/RECORD +0 -66
  108. /chainlit/copilot/dist/assets/{logo_dark-2a3cf740.svg → logo_dark-IkGJ_IwC.svg} +0 -0
  109. /chainlit/copilot/dist/assets/{logo_light-b078e7bc.svg → logo_light-Bb_IPh6r.svg} +0 -0
  110. /chainlit/frontend/dist/assets/{logo_dark-2a3cf740.svg → logo_dark-IkGJ_IwC.svg} +0 -0
  111. /chainlit/frontend/dist/assets/{logo_light-b078e7bc.svg → logo_light-Bb_IPh6r.svg} +0 -0
  112. {chainlit-1.0.401.dist-info → chainlit-2.0.4.dist-info}/WHEEL +0 -0
  113. {chainlit-1.0.401.dist-info → chainlit-2.0.4.dist-info}/entry_points.txt +0 -0
chainlit/__init__.py CHANGED
@@ -2,33 +2,31 @@ import os
2
2
 
3
3
  from dotenv import load_dotenv
4
4
 
5
+ # ruff: noqa: E402
6
+ # Keep this here to ensure imports have environment available.
5
7
  env_found = load_dotenv(dotenv_path=os.path.join(os.getcwd(), ".env"))
6
8
 
9
+ from chainlit.logger import logger
10
+
11
+ if env_found:
12
+ logger.info("Loaded .env file")
13
+
7
14
  import asyncio
8
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
15
+ from typing import TYPE_CHECKING, Any, Dict
9
16
 
10
- from fastapi import Request, Response
17
+ from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
11
18
  from pydantic.dataclasses import dataclass
12
- from starlette.datastructures import Headers
13
-
14
- if TYPE_CHECKING:
15
- from chainlit.haystack.callbacks import HaystackAgentCallbackHandler
16
- from chainlit.langchain.callbacks import (
17
- LangchainCallbackHandler,
18
- AsyncLangchainCallbackHandler,
19
- )
20
- from chainlit.llama_index.callbacks import LlamaIndexCallbackHandler
21
- from chainlit.openai import instrument_openai
22
19
 
23
20
  import chainlit.input_widget as input_widget
24
21
  from chainlit.action import Action
25
22
  from chainlit.cache import cache
23
+ from chainlit.chat_context import chat_context
26
24
  from chainlit.chat_settings import ChatSettings
27
- from chainlit.config import config
28
25
  from chainlit.context import context
29
26
  from chainlit.element import (
30
27
  Audio,
31
- Avatar,
28
+ CustomElement,
29
+ Dataframe,
32
30
  File,
33
31
  Image,
34
32
  Pdf,
@@ -40,7 +38,6 @@ from chainlit.element import (
40
38
  Text,
41
39
  Video,
42
40
  )
43
- from chainlit.logger import logger
44
41
  from chainlit.message import (
45
42
  AskActionMessage,
46
43
  AskFileMessage,
@@ -48,243 +45,46 @@ from chainlit.message import (
48
45
  ErrorMessage,
49
46
  Message,
50
47
  )
51
- from chainlit.oauth_providers import get_configured_oauth_providers
52
48
  from chainlit.step import Step, step
53
49
  from chainlit.sync import make_async, run_sync
54
- from chainlit.telemetry import trace
55
- from chainlit.types import ChatProfile, ThreadDict
50
+ from chainlit.types import ChatProfile, InputAudioChunk, OutputAudioChunk, Starter
56
51
  from chainlit.user import PersistedUser, User
57
52
  from chainlit.user_session import user_session
58
- from chainlit.utils import make_module_getattr, wrap_user_function
53
+ from chainlit.utils import make_module_getattr
59
54
  from chainlit.version import __version__
60
- from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
61
-
62
- if env_found:
63
- logger.info("Loaded .env file")
64
-
65
-
66
- @trace
67
- def password_auth_callback(func: Callable[[str, str], Optional[User]]) -> Callable:
68
- """
69
- Framework agnostic decorator to authenticate the user.
70
-
71
- Args:
72
- func (Callable[[str, str], Optional[User]]): The authentication callback to execute. Takes the email and password as parameters.
73
-
74
- Example:
75
- @cl.password_auth_callback
76
- async def password_auth_callback(username: str, password: str) -> Optional[User]:
77
-
78
- Returns:
79
- Callable[[str, str], Optional[User]]: The decorated authentication callback.
80
- """
81
-
82
- config.code.password_auth_callback = wrap_user_function(func)
83
- return func
84
-
85
-
86
- @trace
87
- def header_auth_callback(func: Callable[[Headers], Optional[User]]) -> Callable:
88
- """
89
- Framework agnostic decorator to authenticate the user via a header
90
-
91
- Args:
92
- func (Callable[[Headers], Optional[User]]): The authentication callback to execute.
93
-
94
- Example:
95
- @cl.header_auth_callback
96
- async def header_auth_callback(headers: Headers) -> Optional[User]:
97
-
98
- Returns:
99
- Callable[[Headers], Optional[User]]: The decorated authentication callback.
100
- """
101
-
102
- config.code.header_auth_callback = wrap_user_function(func)
103
- return func
104
-
105
-
106
- @trace
107
- def oauth_callback(
108
- func: Callable[[str, str, Dict[str, str], User], Optional[User]]
109
- ) -> Callable:
110
- """
111
- Framework agnostic decorator to authenticate the user via oauth
112
-
113
- Args:
114
- func (Callable[[str, str, Dict[str, str], User], Optional[User]]): The authentication callback to execute.
115
-
116
- Example:
117
- @cl.oauth_callback
118
- async def oauth_callback(provider_id: str, token: str, raw_user_data: Dict[str, str], default_app_user: User) -> Optional[User]:
119
-
120
- Returns:
121
- Callable[[str, str, Dict[str, str], User], Optional[User]]: The decorated authentication callback.
122
- """
123
-
124
- if len(get_configured_oauth_providers()) == 0:
125
- raise ValueError(
126
- "You must set the environment variable for at least one oauth provider to use oauth authentication."
127
- )
128
-
129
- config.code.oauth_callback = wrap_user_function(func)
130
- return func
131
-
132
-
133
- @trace
134
- def on_logout(func: Callable[[Request, Response], Any]) -> Callable:
135
- """
136
- Function called when the user logs out.
137
- Takes the FastAPI request and response as parameters.
138
- """
139
-
140
- config.code.on_logout = wrap_user_function(func)
141
- return func
142
-
143
-
144
- @trace
145
- def on_message(func: Callable) -> Callable:
146
- """
147
- Framework agnostic decorator to react to messages coming from the UI.
148
- The decorated function is called every time a new message is received.
149
-
150
- Args:
151
- func (Callable[[Message], Any]): The function to be called when a new message is received. Takes a cl.Message.
152
-
153
- Returns:
154
- Callable[[str], Any]: The decorated on_message function.
155
- """
156
-
157
- config.code.on_message = wrap_user_function(func)
158
- return func
159
-
160
-
161
- @trace
162
- def on_chat_start(func: Callable) -> Callable:
163
- """
164
- Hook to react to the user websocket connection event.
165
-
166
- Args:
167
- func (Callable[], Any]): The connection hook to execute.
168
-
169
- Returns:
170
- Callable[], Any]: The decorated hook.
171
- """
172
-
173
- config.code.on_chat_start = wrap_user_function(func, with_task=True)
174
- return func
175
-
176
-
177
- @trace
178
- def on_chat_resume(func: Callable[[ThreadDict], Any]) -> Callable:
179
- """
180
- Hook to react to resume websocket connection event.
181
-
182
- Args:
183
- func (Callable[], Any]): The connection hook to execute.
184
-
185
- Returns:
186
- Callable[], Any]: The decorated hook.
187
- """
188
-
189
- config.code.on_chat_resume = wrap_user_function(func, with_task=True)
190
- return func
191
-
192
-
193
- @trace
194
- def set_chat_profiles(
195
- func: Callable[[Optional["User"]], List["ChatProfile"]]
196
- ) -> Callable:
197
- """
198
- Programmatic declaration of the available chat profiles (can depend on the User from the session if authentication is setup).
199
-
200
- Args:
201
- func (Callable[[Optional["User"]], List["ChatProfile"]]): The function declaring the chat profiles.
202
-
203
- Returns:
204
- Callable[[Optional["User"]], List["ChatProfile"]]: The decorated function.
205
- """
206
-
207
- config.code.set_chat_profiles = wrap_user_function(func)
208
- return func
209
-
210
-
211
- @trace
212
- def on_chat_end(func: Callable) -> Callable:
213
- """
214
- Hook to react to the user websocket disconnect event.
215
-
216
- Args:
217
- func (Callable[], Any]): The disconnect hook to execute.
218
-
219
- Returns:
220
- Callable[], Any]: The decorated hook.
221
- """
222
-
223
- config.code.on_chat_end = wrap_user_function(func, with_task=True)
224
- return func
225
-
226
-
227
- @trace
228
- def author_rename(func: Callable[[str], str]) -> Callable[[str], str]:
229
- """
230
- Useful to rename the author of message to display more friendly author names in the UI.
231
- Args:
232
- func (Callable[[str], str]): The function to be called to rename an author. Takes the original author name as parameter.
233
-
234
- Returns:
235
- Callable[[Any, str], Any]: The decorated function.
236
- """
237
55
 
238
- config.code.author_rename = wrap_user_function(func)
239
- return func
240
-
241
-
242
- @trace
243
- def on_stop(func: Callable) -> Callable:
244
- """
245
- Hook to react to the user stopping a thread.
246
-
247
- Args:
248
- func (Callable[[], Any]): The stop hook to execute.
249
-
250
- Returns:
251
- Callable[[], Any]: The decorated stop hook.
252
- """
253
-
254
- config.code.on_stop = wrap_user_function(func)
255
- return func
256
-
257
-
258
- def action_callback(name: str) -> Callable:
259
- """
260
- Callback to call when an action is clicked in the UI.
261
-
262
- Args:
263
- func (Callable[[Action], Any]): The action callback to execute. First parameter is the action.
264
- """
265
-
266
- def decorator(func: Callable[[Action], Any]):
267
- config.code.action_callbacks[name] = wrap_user_function(func, with_task=True)
268
- return func
269
-
270
- return decorator
271
-
272
-
273
- def on_settings_update(
274
- func: Callable[[Dict[str, Any]], Any]
275
- ) -> Callable[[Dict[str, Any]], Any]:
276
- """
277
- Hook to react to the user changing any settings.
278
-
279
- Args:
280
- func (Callable[], Any]): The hook to execute after settings were changed.
281
-
282
- Returns:
283
- Callable[], Any]: The decorated hook.
284
- """
56
+ from .callbacks import (
57
+ action_callback,
58
+ author_rename,
59
+ data_layer,
60
+ header_auth_callback,
61
+ oauth_callback,
62
+ on_audio_chunk,
63
+ on_audio_end,
64
+ on_audio_start,
65
+ on_chat_end,
66
+ on_chat_resume,
67
+ on_chat_start,
68
+ on_logout,
69
+ on_message,
70
+ on_settings_update,
71
+ on_stop,
72
+ on_window_message,
73
+ password_auth_callback,
74
+ send_window_message,
75
+ set_chat_profiles,
76
+ set_starters,
77
+ )
285
78
 
286
- config.code.on_settings_update = wrap_user_function(func, with_task=True)
287
- return func
79
+ if TYPE_CHECKING:
80
+ from chainlit.haystack.callbacks import HaystackAgentCallbackHandler
81
+ from chainlit.langchain.callbacks import (
82
+ AsyncLangchainCallbackHandler,
83
+ LangchainCallbackHandler,
84
+ )
85
+ from chainlit.llama_index.callbacks import LlamaIndexCallbackHandler
86
+ from chainlit.mistralai import instrument_mistralai
87
+ from chainlit.openai import instrument_openai
288
88
 
289
89
 
290
90
  def sleep(duration: int):
@@ -312,59 +112,78 @@ __getattr__ = make_module_getattr(
312
112
  "LlamaIndexCallbackHandler": "chainlit.llama_index.callbacks",
313
113
  "HaystackAgentCallbackHandler": "chainlit.haystack.callbacks",
314
114
  "instrument_openai": "chainlit.openai",
115
+ "instrument_mistralai": "chainlit.mistralai",
315
116
  }
316
117
  )
317
118
 
318
119
  __all__ = [
319
- "user_session",
320
- "CopilotFunction",
321
120
  "Action",
322
- "User",
323
- "PersistedUser",
121
+ "AskActionMessage",
122
+ "AskFileMessage",
123
+ "AskUserMessage",
124
+ "AsyncLangchainCallbackHandler",
324
125
  "Audio",
126
+ "ChatGeneration",
127
+ "ChatProfile",
128
+ "ChatSettings",
129
+ "CompletionGeneration",
130
+ "CopilotFunction",
131
+ "CustomElement",
132
+ "Dataframe",
133
+ "ErrorMessage",
134
+ "File",
135
+ "GenerationMessage",
136
+ "HaystackAgentCallbackHandler",
137
+ "Image",
138
+ "InputAudioChunk",
139
+ "LangchainCallbackHandler",
140
+ "LlamaIndexCallbackHandler",
141
+ "Message",
142
+ "OutputAudioChunk",
325
143
  "Pdf",
144
+ "PersistedUser",
326
145
  "Plotly",
327
- "Image",
328
- "Text",
329
- "Avatar",
330
146
  "Pyplot",
331
- "File",
147
+ "Starter",
148
+ "Step",
332
149
  "Task",
333
150
  "TaskList",
334
151
  "TaskStatus",
152
+ "Text",
153
+ "User",
335
154
  "Video",
336
- "ChatSettings",
155
+ "__version__",
156
+ "action_callback",
157
+ "author_rename",
158
+ "cache",
159
+ "chat_context",
160
+ "context",
161
+ "data_layer",
162
+ "header_auth_callback",
337
163
  "input_widget",
338
- "Message",
339
- "ErrorMessage",
340
- "AskUserMessage",
341
- "AskActionMessage",
342
- "AskFileMessage",
343
- "Step",
344
- "step",
345
- "ChatGeneration",
346
- "CompletionGeneration",
347
- "GenerationMessage",
348
- "on_logout",
349
- "on_chat_start",
164
+ "instrument_mistralai",
165
+ "instrument_openai",
166
+ "make_async",
167
+ "oauth_callback",
168
+ "on_audio_chunk",
169
+ "on_audio_end",
170
+ "on_audio_start",
350
171
  "on_chat_end",
351
172
  "on_chat_resume",
352
- "on_stop",
353
- "action_callback",
354
- "author_rename",
173
+ "on_chat_start",
174
+ "on_logout",
175
+ "on_message",
355
176
  "on_settings_update",
177
+ "on_stop",
178
+ "on_window_message",
356
179
  "password_auth_callback",
357
- "header_auth_callback",
358
- "sleep",
359
180
  "run_sync",
360
- "make_async",
361
- "cache",
362
- "context",
363
- "LangchainCallbackHandler",
364
- "AsyncLangchainCallbackHandler",
365
- "LlamaIndexCallbackHandler",
366
- "HaystackAgentCallbackHandler",
367
- "instrument_openai",
181
+ "send_window_message",
182
+ "set_chat_profiles",
183
+ "set_starters",
184
+ "sleep",
185
+ "step",
186
+ "user_session",
368
187
  ]
369
188
 
370
189
 
chainlit/_utils.py ADDED
@@ -0,0 +1,8 @@
1
+ """Util functions which are explicitly not part of the public API."""
2
+
3
+ from pathlib import Path
4
+
5
+
6
+ def is_path_inside(child_path: Path, parent_path: Path) -> bool:
7
+ """Check if the child path is inside the parent path."""
8
+ return parent_path.resolve() in child_path.resolve().parents
chainlit/action.py CHANGED
@@ -1,28 +1,30 @@
1
1
  import uuid
2
- from typing import Optional
2
+ from typing import Dict, Optional
3
+
4
+ from dataclasses_json import DataClassJsonMixin
5
+ from pydantic import Field
6
+ from pydantic.dataclasses import dataclass
3
7
 
4
8
  from chainlit.context import context
5
9
  from chainlit.telemetry import trace_event
6
- from dataclasses_json import DataClassJsonMixin
7
- from pydantic.dataclasses import Field, dataclass
8
10
 
9
11
 
10
12
  @dataclass
11
13
  class Action(DataClassJsonMixin):
12
14
  # Name of the action, this should be used in the action_callback
13
15
  name: str
14
- # The value associated with the action. This is useful to differentiate between multiple actions with the same name.
15
- value: str
16
- # The label of the action. This is what the user will see. If not provided the name will be used.
16
+ # The parameters to call this action with.
17
+ payload: Dict
18
+ # The label of the action. This is what the user will see.
17
19
  label: str = ""
18
- # The description of the action. This is what the user will see when they hover the action.
19
- description: str = ""
20
+ # The tooltip of the action button. This is what the user will see when they hover the action.
21
+ tooltip: str = ""
22
+ # The lucid icon name for this action.
23
+ icon: Optional[str] = None
20
24
  # This should not be set manually, only used internally.
21
25
  forId: Optional[str] = None
22
26
  # The ID of the action
23
27
  id: str = Field(default_factory=lambda: str(uuid.uuid4()))
24
- # Show the action in a drawer menu
25
- collapsed: bool = False
26
28
 
27
29
  def __post_init__(self) -> None:
28
30
  trace_event(f"init {self.__class__.__name__}")
chainlit/{auth.py → auth/__init__.py} RENAMED
@@ -1,20 +1,16 @@
1
1
  import os
2
- from datetime import datetime, timedelta
3
- from typing import Any, Dict
4
2
 
5
- import jwt
3
+ from fastapi import Depends, HTTPException
4
+
6
5
  from chainlit.config import config
7
6
  from chainlit.data import get_data_layer
7
+ from chainlit.logger import logger
8
8
  from chainlit.oauth_providers import get_configured_oauth_providers
9
- from chainlit.user import User
10
- from fastapi import Depends, HTTPException
11
- from fastapi.security import OAuth2PasswordBearer
12
-
13
- reuseable_oauth = OAuth2PasswordBearer(tokenUrl="/login", auto_error=False)
14
9
 
10
+ from .cookie import OAuth2PasswordBearerWithCookie
11
+ from .jwt import create_jwt, decode_jwt, get_jwt_secret
15
12
 
16
- def get_jwt_secret():
17
- return os.environ.get("CHAINLIT_AUTH_SECRET")
13
+ reuseable_oauth = OAuth2PasswordBearerWithCookie(tokenUrl="/login", auto_error=False)
18
14
 
19
15
 
20
16
  def ensure_jwt_secret():
@@ -42,46 +38,39 @@ def get_configuration():
42
38
  "requireLogin": require_login(),
43
39
  "passwordAuth": config.code.password_auth_callback is not None,
44
40
  "headerAuth": config.code.header_auth_callback is not None,
45
- "oauthProviders": get_configured_oauth_providers()
46
- if is_oauth_enabled()
47
- else [],
41
+ "oauthProviders": (
42
+ get_configured_oauth_providers() if is_oauth_enabled() else []
43
+ ),
44
+ "default_theme": config.ui.default_theme,
48
45
  }
49
46
 
50
47
 
51
- def create_jwt(data: User) -> str:
52
- to_encode = data.to_dict() # type: Dict[str, Any]
53
- to_encode.update(
54
- {
55
- "exp": datetime.utcnow() + timedelta(minutes=60 * 24 * 15), # 15 days
56
- }
57
- )
58
- encoded_jwt = jwt.encode(to_encode, get_jwt_secret(), algorithm="HS256")
59
- return encoded_jwt
60
-
61
-
62
48
  async def authenticate_user(token: str = Depends(reuseable_oauth)):
63
49
  try:
64
- dict = jwt.decode(
65
- token,
66
- get_jwt_secret(),
67
- algorithms=["HS256"],
68
- options={"verify_signature": True},
69
- )
70
- del dict["exp"]
71
- user = User(**dict)
50
+ user = decode_jwt(token)
72
51
  except Exception as e:
73
- raise HTTPException(status_code=401, detail="Invalid authentication token")
52
+ raise HTTPException(
53
+ status_code=401, detail="Invalid authentication token"
54
+ ) from e
55
+
74
56
  if data_layer := get_data_layer():
57
+ # Get or create persistent user if we've a data layer available.
75
58
  try:
76
59
  persisted_user = await data_layer.get_user(user.identifier)
77
- if persisted_user == None:
60
+ if persisted_user is None:
78
61
  persisted_user = await data_layer.create_user(user)
62
+ assert persisted_user
79
63
  except Exception as e:
64
+ logger.exception("Unable to get persisted_user from data layer: %s", e)
80
65
  return user
81
66
 
67
+ if user and user.display_name:
68
+ # Copy ephemeral display_name from authenticated user to persistent user.
69
+ persisted_user.display_name = user.display_name
70
+
82
71
  return persisted_user
83
- else:
84
- return user
72
+
73
+ return user
85
74
 
86
75
 
87
76
  async def get_current_user(token: str = Depends(reuseable_oauth)):
@@ -89,3 +78,6 @@ async def get_current_user(token: str = Depends(reuseable_oauth)):
89
78
  return None
90
79
 
91
80
  return await authenticate_user(token)
81
+
82
+
83
+ __all__ = ["create_jwt", "get_configuration", "get_current_user"]