chainlit 1.1.300rc4__py3-none-any.whl → 1.1.300rc5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic. Click here for more details.

Files changed (43) hide show
  1. chainlit/__init__.py +20 -1
  2. chainlit/cli/__init__.py +48 -6
  3. chainlit/config.py +5 -0
  4. chainlit/context.py +9 -0
  5. chainlit/copilot/dist/index.js +180 -180
  6. chainlit/data/__init__.py +6 -3
  7. chainlit/data/sql_alchemy.py +3 -3
  8. chainlit/element.py +33 -9
  9. chainlit/emitter.py +3 -3
  10. chainlit/frontend/dist/assets/{DailyMotion-1a2b7d60.js → DailyMotion-f9db5a1d.js} +1 -1
  11. chainlit/frontend/dist/assets/{Facebook-8422f48c.js → Facebook-f95b29c9.js} +1 -1
  12. chainlit/frontend/dist/assets/{FilePlayer-a0a41349.js → FilePlayer-ba3f562c.js} +1 -1
  13. chainlit/frontend/dist/assets/{Kaltura-aa7990f2.js → Kaltura-195ed801.js} +1 -1
  14. chainlit/frontend/dist/assets/{Mixcloud-1647b3e4.js → Mixcloud-f64c6d87.js} +1 -1
  15. chainlit/frontend/dist/assets/{Mux-7e57be81.js → Mux-206cbddc.js} +1 -1
  16. chainlit/frontend/dist/assets/{Preview-cb89b2c6.js → Preview-af249586.js} +1 -1
  17. chainlit/frontend/dist/assets/{SoundCloud-c0d86d55.js → SoundCloud-80a26cdf.js} +1 -1
  18. chainlit/frontend/dist/assets/{Streamable-57c43c18.js → Streamable-f80b255d.js} +1 -1
  19. chainlit/frontend/dist/assets/{Twitch-bed3f21d.js → Twitch-0e2f1d13.js} +1 -1
  20. chainlit/frontend/dist/assets/{Vidyard-4dd76e44.js → Vidyard-bd67bfc6.js} +1 -1
  21. chainlit/frontend/dist/assets/{Vimeo-93bb5ae2.js → Vimeo-f9496a5d.js} +1 -1
  22. chainlit/frontend/dist/assets/{Wistia-97199246.js → Wistia-a943e0aa.js} +1 -1
  23. chainlit/frontend/dist/assets/{YouTube-fe1a7afe.js → YouTube-cf572a1f.js} +1 -1
  24. chainlit/frontend/dist/assets/{index-919bea8f.js → index-5511e258.js} +131 -131
  25. chainlit/frontend/dist/assets/{react-plotly-b416b8f9.js → react-plotly-74b55763.js} +1 -1
  26. chainlit/frontend/dist/index.html +1 -2
  27. chainlit/message.py +13 -8
  28. chainlit/oauth_providers.py +63 -1
  29. chainlit/server.py +170 -48
  30. chainlit/socket.py +40 -20
  31. chainlit/step.py +12 -3
  32. chainlit/teams/__init__.py +6 -0
  33. chainlit/teams/app.py +332 -0
  34. chainlit/translations/en-US.json +1 -1
  35. chainlit/types.py +7 -17
  36. chainlit/user.py +9 -1
  37. chainlit/utils.py +42 -0
  38. {chainlit-1.1.300rc4.dist-info → chainlit-1.1.300rc5.dist-info}/METADATA +2 -2
  39. chainlit-1.1.300rc5.dist-info/RECORD +79 -0
  40. chainlit/cli/utils.py +0 -24
  41. chainlit-1.1.300rc4.dist-info/RECORD +0 -78
  42. {chainlit-1.1.300rc4.dist-info → chainlit-1.1.300rc5.dist-info}/WHEEL +0 -0
  43. {chainlit-1.1.300rc4.dist-info → chainlit-1.1.300rc5.dist-info}/entry_points.txt +0 -0
@@ -4,7 +4,6 @@
4
4
  <meta charset="UTF-8" />
5
5
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
6
6
  <!-- TAG INJECTION PLACEHOLDER -->
7
- <link rel="icon" href="/favicon" />
8
7
  <link rel="preconnect" href="https://fonts.googleapis.com" />
9
8
  <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin />
10
9
  <!-- FONT START -->
@@ -22,7 +21,7 @@
22
21
  <script>
23
22
  const global = globalThis;
24
23
  </script>
25
- <script type="module" crossorigin src="/assets/index-919bea8f.js"></script>
24
+ <script type="module" crossorigin src="/assets/index-5511e258.js"></script>
26
25
  <link rel="stylesheet" href="/assets/index-aaf974a9.css">
27
26
  </head>
28
27
  <body>
chainlit/message.py CHANGED
@@ -7,7 +7,7 @@ from typing import Dict, List, Optional, Union, cast
7
7
 
8
8
  from chainlit.action import Action
9
9
  from chainlit.config import config
10
- from chainlit.context import context
10
+ from chainlit.context import context, local_steps
11
11
  from chainlit.data import get_data_layer
12
12
  from chainlit.element import ElementBased
13
13
  from chainlit.logger import logger
@@ -21,7 +21,6 @@ from chainlit.types import (
21
21
  AskSpec,
22
22
  FileDict,
23
23
  )
24
- from literalai import BaseGeneration
25
24
  from literalai.helper import utc_now
26
25
  from literalai.step import MessageStepType
27
26
 
@@ -38,17 +37,22 @@ class MessageBase(ABC):
38
37
  fail_on_persist_error: bool = False
39
38
  persisted = False
40
39
  is_error = False
40
+ parent_id: Optional[str] = None
41
41
  language: Optional[str] = None
42
42
  metadata: Optional[Dict] = None
43
43
  tags: Optional[List[str]] = None
44
44
  wait_for_answer = False
45
45
  indent: Optional[int] = None
46
- generation: Optional[BaseGeneration] = None
47
46
 
48
47
  def __post_init__(self) -> None:
49
48
  trace_event(f"init {self.__class__.__name__}")
50
49
  self.thread_id = context.session.thread_id
51
50
 
51
+ previous_steps = local_steps.get() or []
52
+ parent_step = previous_steps[-1] if previous_steps else None
53
+ if parent_step:
54
+ self.parent_id = parent_step.id
55
+
52
56
  if not getattr(self, "id", None):
53
57
  self.id = str(uuid.uuid4())
54
58
 
@@ -57,6 +61,7 @@ class MessageBase(ABC):
57
61
  type = _dict.get("type", "assistant_message")
58
62
  message = Message(
59
63
  id=_dict["id"],
64
+ parent_id=_dict.get("parentId"),
60
65
  created_at=_dict["createdAt"],
61
66
  content=_dict["output"],
62
67
  author=_dict.get("name", config.ui.name),
@@ -71,6 +76,7 @@ class MessageBase(ABC):
71
76
  _dict: StepDict = {
72
77
  "id": self.id,
73
78
  "threadId": self.thread_id,
79
+ "parentId": self.parent_id,
74
80
  "createdAt": self.created_at,
75
81
  "start": self.created_at,
76
82
  "end": self.created_at,
@@ -84,7 +90,6 @@ class MessageBase(ABC):
84
90
  "isError": self.is_error,
85
91
  "waitForAnswer": self.wait_for_answer,
86
92
  "indent": self.indent,
87
- "generation": self.generation.to_dict() if self.generation else None,
88
93
  "metadata": self.metadata or {},
89
94
  "tags": self.tags,
90
95
  }
@@ -212,15 +217,14 @@ class Message(MessageBase):
212
217
  elements: Optional[List[ElementBased]] = None,
213
218
  disable_feedback: bool = False,
214
219
  type: MessageStepType = "assistant_message",
215
- generation: Optional[BaseGeneration] = None,
216
220
  metadata: Optional[Dict] = None,
217
221
  tags: Optional[List[str]] = None,
218
222
  id: Optional[str] = None,
223
+ parent_id: Optional[str] = None,
219
224
  created_at: Union[str, None] = None,
220
225
  ):
221
226
  time.sleep(0.001)
222
227
  self.language = language
223
- self.generation = generation
224
228
  if isinstance(content, dict):
225
229
  try:
226
230
  self.content = json.dumps(content, indent=4, ensure_ascii=False)
@@ -237,6 +241,9 @@ class Message(MessageBase):
237
241
  if id:
238
242
  self.id = str(id)
239
243
 
244
+ if parent_id:
245
+ self.parent_id = str(parent_id)
246
+
240
247
  if created_at:
241
248
  self.created_at = created_at
242
249
 
@@ -304,8 +311,6 @@ class ErrorMessage(MessageBase):
304
311
  Args:
305
312
  content (str): Text displayed above the upload button.
306
313
  author (str, optional): The author of the message, this will be used in the UI. Defaults to the assistant name (see config).
307
- parent_id (str, optional): If provided, the message will be nested inside the parent in the UI.
308
- indent (int, optional): If positive, the message will be nested in the UI.
309
314
  """
310
315
 
311
316
  def __init__(
@@ -4,6 +4,7 @@ import urllib.parse
4
4
  from typing import Dict, List, Optional, Tuple
5
5
 
6
6
  import httpx
7
+ from chainlit.secret import random_secret
7
8
  from chainlit.user import User
8
9
  from fastapi import HTTPException
9
10
 
@@ -186,6 +187,60 @@ class AzureADOAuthProvider(OAuthProvider):
186
187
  )
187
188
  return token
188
189
 
190
+
191
+ class AzureADHybridOAuthProvider(OAuthProvider):
192
+ id = "azure-ad-hybrid"
193
+ env = [
194
+ "OAUTH_AZURE_AD_HYBRID_CLIENT_ID",
195
+ "OAUTH_AZURE_AD_HYBRID_CLIENT_SECRET",
196
+ "OAUTH_AZURE_AD_HYBRID_TENANT_ID",
197
+ ]
198
+ authorize_url = (
199
+ f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_HYBRID_TENANT_ID', '')}/oauth2/v2.0/authorize"
200
+ if os.environ.get("OAUTH_AZURE_AD_HYBRID_ENABLE_SINGLE_TENANT")
201
+ else "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
202
+ )
203
+ token_url = (
204
+ f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_HYBRID_TENANT_ID', '')}/oauth2/v2.0/token"
205
+ if os.environ.get("OAUTH_AZURE_AD_HYBRID_ENABLE_SINGLE_TENANT")
206
+ else "https://login.microsoftonline.com/common/oauth2/v2.0/token"
207
+ )
208
+
209
+ def __init__(self):
210
+ self.client_id = os.environ.get("OAUTH_AZURE_AD_HYBRID_CLIENT_ID")
211
+ self.client_secret = os.environ.get("OAUTH_AZURE_AD_HYBRID_CLIENT_SECRET")
212
+ nonce = random_secret(16)
213
+ self.authorize_params = {
214
+ "tenant": os.environ.get("OAUTH_AZURE_AD_HYBRID_TENANT_ID"),
215
+ "response_type": "code id_token",
216
+ "scope": "https://graph.microsoft.com/User.Read https://graph.microsoft.com/openid",
217
+ "response_mode": "form_post",
218
+ "nonce": nonce,
219
+ }
220
+
221
+ async def get_token(self, code: str, url: str):
222
+ payload = {
223
+ "client_id": self.client_id,
224
+ "client_secret": self.client_secret,
225
+ "code": code,
226
+ "grant_type": "authorization_code",
227
+ "redirect_uri": url,
228
+ }
229
+ async with httpx.AsyncClient() as client:
230
+ response = await client.post(
231
+ self.token_url,
232
+ data=payload,
233
+ )
234
+ response.raise_for_status()
235
+ json = response.json()
236
+
237
+ token = json["access_token"]
238
+ if not token:
239
+ raise HTTPException(
240
+ status_code=400, detail="Failed to get the access token"
241
+ )
242
+ return token
243
+
189
244
  async def get_user_info(self, token: str):
190
245
  async with httpx.AsyncClient() as client:
191
246
  response = await client.get(
@@ -474,9 +529,14 @@ class AWSCognitoOAuthProvider(OAuthProvider):
474
529
  )
475
530
  return (cognito_user, user)
476
531
 
532
+
477
533
  class GitlabOAuthProvider(OAuthProvider):
478
534
  id = "gitlab"
479
- env = ["OAUTH_GITLAB_CLIENT_ID", "OAUTH_GITLAB_CLIENT_SECRET", "OAUTH_GITLAB_DOMAIN"]
535
+ env = [
536
+ "OAUTH_GITLAB_CLIENT_ID",
537
+ "OAUTH_GITLAB_CLIENT_SECRET",
538
+ "OAUTH_GITLAB_DOMAIN",
539
+ ]
480
540
 
481
541
  def __init__(self):
482
542
  self.client_id = os.environ.get("OAUTH_GITLAB_CLIENT_ID")
@@ -530,10 +590,12 @@ class GitlabOAuthProvider(OAuthProvider):
530
590
  )
531
591
  return (gitlab_user, user)
532
592
 
593
+
533
594
  providers = [
534
595
  GithubOAuthProvider(),
535
596
  GoogleOAuthProvider(),
536
597
  AzureADOAuthProvider(),
598
+ AzureADHybridOAuthProvider(),
537
599
  OktaOAuthProvider(),
538
600
  Auth0OAuthProvider(),
539
601
  DescopeOAuthProvider(),
chainlit/server.py CHANGED
@@ -18,6 +18,7 @@ import webbrowser
18
18
  from contextlib import asynccontextmanager
19
19
  from pathlib import Path
20
20
 
21
+ import socketio
21
22
  from chainlit.auth import create_jwt, get_configuration, get_current_user
22
23
  from chainlit.config import (
23
24
  APP_ROOT,
@@ -33,19 +34,19 @@ from chainlit.data import get_data_layer
33
34
  from chainlit.data.acl import is_thread_author
34
35
  from chainlit.logger import logger
35
36
  from chainlit.markdown import get_markdown_str
36
- from chainlit.telemetry import trace_event
37
37
  from chainlit.types import (
38
38
  DeleteFeedbackRequest,
39
39
  DeleteThreadRequest,
40
- GenerationRequest,
41
40
  GetThreadsRequest,
42
41
  Theme,
43
42
  UpdateFeedbackRequest,
44
43
  )
45
44
  from chainlit.user import PersistedUser, User
46
45
  from fastapi import (
46
+ APIRouter,
47
47
  Depends,
48
48
  FastAPI,
49
+ Form,
49
50
  HTTPException,
50
51
  Query,
51
52
  Request,
@@ -56,12 +57,13 @@ from fastapi import (
56
57
  from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
57
58
  from fastapi.security import OAuth2PasswordRequestForm
58
59
  from fastapi.staticfiles import StaticFiles
59
- from fastapi_socketio import SocketManager
60
60
  from starlette.datastructures import URL
61
61
  from starlette.middleware.cors import CORSMiddleware
62
62
  from typing_extensions import Annotated
63
63
  from watchfiles import awatch
64
64
 
65
+ ROOT_PATH = os.environ.get("CHAINLIT_ROOT_PATH", "")
66
+
65
67
 
66
68
  @asynccontextmanager
67
69
  async def lifespan(app: FastAPI):
@@ -69,9 +71,9 @@ async def lifespan(app: FastAPI):
69
71
  port = config.run.port
70
72
 
71
73
  if host == DEFAULT_HOST:
72
- url = f"http://localhost:{port}"
74
+ url = f"http://localhost:{port}{ROOT_PATH}"
73
75
  else:
74
- url = f"http://{host}:{port}"
76
+ url = f"http://{host}:{port}{ROOT_PATH}"
75
77
 
76
78
  logger.info(f"Your app is available at {url}")
77
79
 
@@ -166,12 +168,26 @@ def get_build_dir(local_target: str, packaged_target: str):
166
168
  build_dir = get_build_dir("frontend", "frontend")
167
169
  copilot_build_dir = get_build_dir(os.path.join("libs", "copilot"), "copilot")
168
170
 
169
-
170
171
  app = FastAPI(lifespan=lifespan)
171
172
 
172
- app.mount("/public", StaticFiles(directory="public", check_dir=False), name="public")
173
+ sio = socketio.AsyncServer(cors_allowed_origins="*", async_mode="asgi")
174
+
175
+ combined_asgi_app = socketio.ASGIApp(
176
+ sio,
177
+ app,
178
+ socketio_path=f"{ROOT_PATH}/ws/socket.io" if ROOT_PATH else "/ws/socket.io",
179
+ )
180
+
181
+ router = APIRouter(prefix=ROOT_PATH)
182
+
183
+ app.mount(
184
+ f"{ROOT_PATH}/public",
185
+ StaticFiles(directory="public", check_dir=False),
186
+ name="public",
187
+ )
188
+
173
189
  app.mount(
174
- "/assets",
190
+ f"{ROOT_PATH}/assets",
175
191
  StaticFiles(
176
192
  packages=[("chainlit", os.path.join(build_dir, "assets"))],
177
193
  follow_symlink=config.project.follow_symlink,
@@ -180,7 +196,7 @@ app.mount(
180
196
  )
181
197
 
182
198
  app.mount(
183
- "/copilot",
199
+ f"{ROOT_PATH}/copilot",
184
200
  StaticFiles(
185
201
  packages=[("chainlit", copilot_build_dir)],
186
202
  follow_symlink=config.project.follow_symlink,
@@ -188,7 +204,6 @@ app.mount(
188
204
  name="copilot",
189
205
  )
190
206
 
191
-
192
207
  app.add_middleware(
193
208
  CORSMiddleware,
194
209
  allow_origins=config.project.allow_origins,
@@ -197,13 +212,6 @@ app.add_middleware(
197
212
  allow_headers=["*"],
198
213
  )
199
214
 
200
- socket = SocketManager(
201
- app,
202
- cors_allowed_origins=[],
203
- async_mode="asgi",
204
- socketio_path="/ws/socket.io",
205
- )
206
-
207
215
 
208
216
  # -------------------------------------------------------------------------------
209
217
  # SLACK HANDLER
@@ -212,11 +220,28 @@ socket = SocketManager(
212
220
  if os.environ.get("SLACK_BOT_TOKEN") and os.environ.get("SLACK_SIGNING_SECRET"):
213
221
  from chainlit.slack.app import slack_app_handler
214
222
 
215
- @app.post("/slack/events")
216
- async def endpoint(req: Request):
223
+ @router.post("/slack/events")
224
+ async def slack_endpoint(req: Request):
217
225
  return await slack_app_handler.handle(req)
218
226
 
219
227
 
228
+ # -------------------------------------------------------------------------------
229
+ # TEAMS HANDLER
230
+ # -------------------------------------------------------------------------------
231
+
232
+ if os.environ.get("TEAMS_APP_ID") and os.environ.get("TEAMS_APP_PASSWORD"):
233
+ from botbuilder.schema import Activity
234
+ from chainlit.teams.app import adapter, bot
235
+
236
+ @router.post("/teams/events")
237
+ async def teams_endpoint(req: Request):
238
+ body = await req.json()
239
+ activity = Activity().deserialize(body)
240
+ auth_header = req.headers.get("Authorization", "")
241
+ response = await adapter.process_activity(activity, auth_header, bot.on_turn)
242
+ return response
243
+
244
+
220
245
  # -------------------------------------------------------------------------------
221
246
  # HTTP HANDLERS
222
247
  # -------------------------------------------------------------------------------
@@ -238,14 +263,17 @@ def get_html_template():
238
263
  )
239
264
  url = config.ui.github or default_url
240
265
  meta_image_url = config.ui.custom_meta_image_url or default_meta_image_url
266
+ favicon_path = ROOT_PATH + "/favicon" if ROOT_PATH else "/favicon"
241
267
 
242
268
  tags = f"""<title>{config.ui.name}</title>
269
+ <link rel="icon" href="{favicon_path}" />
243
270
  <meta name="description" content="{config.ui.description}">
244
271
  <meta property="og:type" content="website">
245
272
  <meta property="og:title" content="{config.ui.name}">
246
273
  <meta property="og:description" content="{config.ui.description}">
247
274
  <meta property="og:image" content="{meta_image_url}">
248
- <meta property="og:url" content="{url}">"""
275
+ <meta property="og:url" content="{url}">
276
+ <meta property="og:root_path" content="{ROOT_PATH}">"""
249
277
 
250
278
  js = f"""<script>{f"window.theme = {json.dumps(config.ui.theme.to_dict())}; " if config.ui.theme else ""}</script>"""
251
279
 
@@ -275,6 +303,9 @@ def get_html_template():
275
303
  content = replace_between_tags(
276
304
  content, "<!-- FONT START -->", "<!-- FONT END -->", font
277
305
  )
306
+ if ROOT_PATH:
307
+ content = content.replace('href="/', f'href="{ROOT_PATH}/')
308
+ content = content.replace('src="/', f'src="{ROOT_PATH}/')
278
309
  return content
279
310
 
280
311
 
@@ -284,6 +315,7 @@ def get_user_facing_url(url: URL):
284
315
  Handles deployment with proxies (like cloud run).
285
316
  """
286
317
 
318
+ ROOT_PATH = os.environ.get("CHAINLIT_ROOT_PATH", "")
287
319
  chainlit_url = os.environ.get("CHAINLIT_URL")
288
320
 
289
321
  # No config, we keep the URL as is
@@ -299,15 +331,26 @@ def get_user_facing_url(url: URL):
299
331
  if config_url.path.endswith("/"):
300
332
  config_url = config_url.replace(path=config_url.path[:-1])
301
333
 
334
+ # Add ROOT_PATH to the final URL if it exists
335
+ if ROOT_PATH:
336
+ # Ensure ROOT_PATH starts with a slash
337
+ if not ROOT_PATH.startswith("/"):
338
+ ROOT_PATH = "/" + ROOT_PATH
339
+ # Ensure ROOT_PATH does not end with a slash
340
+ if ROOT_PATH.endswith("/"):
341
+ ROOT_PATH = ROOT_PATH[:-1]
342
+
343
+ return config_url.__str__() + ROOT_PATH + url.path
344
+
302
345
  return config_url.__str__() + url.path
303
346
 
304
347
 
305
- @app.get("/auth/config")
348
+ @router.get("/auth/config")
306
349
  async def auth(request: Request):
307
350
  return get_configuration()
308
351
 
309
352
 
310
- @app.post("/login")
353
+ @router.post("/login")
311
354
  async def login(form_data: OAuth2PasswordRequestForm = Depends()):
312
355
  if not config.code.password_auth_callback:
313
356
  raise HTTPException(
@@ -336,14 +379,14 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends()):
336
379
  }
337
380
 
338
381
 
339
- @app.post("/logout")
382
+ @router.post("/logout")
340
383
  async def logout(request: Request, response: Response):
341
384
  if config.code.on_logout:
342
385
  return await config.code.on_logout(request, response)
343
386
  return {"success": True}
344
387
 
345
388
 
346
- @app.post("/auth/header")
389
+ @router.post("/auth/header")
347
390
  async def header_auth(request: Request):
348
391
  if not config.code.header_auth_callback:
349
392
  raise HTTPException(
@@ -372,7 +415,7 @@ async def header_auth(request: Request):
372
415
  }
373
416
 
374
417
 
375
- @app.get("/auth/oauth/{provider_id}")
418
+ @router.get("/auth/oauth/{provider_id}")
376
419
  async def oauth_login(provider_id: str, request: Request):
377
420
  if config.code.oauth_callback is None:
378
421
  raise HTTPException(
@@ -413,7 +456,7 @@ async def oauth_login(provider_id: str, request: Request):
413
456
  return response
414
457
 
415
458
 
416
- @app.get("/auth/oauth/{provider_id}/callback")
459
+ @router.get("/auth/oauth/{provider_id}/callback")
417
460
  async def oauth_callback(
418
461
  provider_id: str,
419
462
  request: Request,
@@ -497,7 +540,85 @@ async def oauth_callback(
497
540
  return response
498
541
 
499
542
 
500
- @app.get("/project/translations")
543
+ # specific route for azure ad hybrid flow
544
+ @router.post("/auth/oauth/azure-ad-hybrid/callback")
545
+ async def oauth_azure_hf_callback(
546
+ request: Request,
547
+ error: Optional[str] = None,
548
+ code: Annotated[Optional[str], Form()] = None,
549
+ id_token: Annotated[Optional[str], Form()] = None,
550
+ ):
551
+ provider_id = "azure-ad-hybrid"
552
+ if config.code.oauth_callback is None:
553
+ raise HTTPException(
554
+ status_code=status.HTTP_400_BAD_REQUEST,
555
+ detail="No oauth_callback defined",
556
+ )
557
+
558
+ provider = get_oauth_provider(provider_id)
559
+ if not provider:
560
+ raise HTTPException(
561
+ status_code=status.HTTP_404_NOT_FOUND,
562
+ detail=f"Provider {provider_id} not found",
563
+ )
564
+
565
+ if error:
566
+ params = urllib.parse.urlencode(
567
+ {
568
+ "error": error,
569
+ }
570
+ )
571
+ response = RedirectResponse(
572
+ # FIXME: redirect to the right frontend base url to improve the dev environment
573
+ url=f"/login?{params}",
574
+ )
575
+ return response
576
+
577
+ if not code:
578
+ raise HTTPException(
579
+ status_code=status.HTTP_400_BAD_REQUEST,
580
+ detail="Missing code",
581
+ )
582
+
583
+ url = get_user_facing_url(request.url)
584
+ token = await provider.get_token(code, url)
585
+
586
+ (raw_user_data, default_user) = await provider.get_user_info(token)
587
+
588
+ user = await config.code.oauth_callback(
589
+ provider_id, token, raw_user_data, default_user, id_token
590
+ )
591
+
592
+ if not user:
593
+ raise HTTPException(
594
+ status_code=status.HTTP_401_UNAUTHORIZED,
595
+ detail="Unauthorized",
596
+ )
597
+
598
+ access_token = create_jwt(user)
599
+
600
+ if data_layer := get_data_layer():
601
+ try:
602
+ await data_layer.create_user(user)
603
+ except Exception as e:
604
+ logger.error(f"Error creating user: {e}")
605
+
606
+ params = urllib.parse.urlencode(
607
+ {
608
+ "access_token": access_token,
609
+ "token_type": "bearer",
610
+ }
611
+ )
612
+ response = RedirectResponse(
613
+ # FIXME: redirect to the right frontend base url to improve the dev environment
614
+ url=f"/login/callback?{params}",
615
+ status_code=302,
616
+ )
617
+ response.delete_cookie("oauth_state")
618
+ return response
619
+
620
+
621
+ @router.get("/project/translations")
501
622
  async def project_translations(
502
623
  language: str = Query(default="en-US", description="Language code"),
503
624
  ):
@@ -513,7 +634,7 @@ async def project_translations(
513
634
  )
514
635
 
515
636
 
516
- @app.get("/project/settings")
637
+ @router.get("/project/settings")
517
638
  async def project_settings(
518
639
  current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
519
640
  language: str = Query(default="en-US", description="Language code"),
@@ -559,7 +680,7 @@ async def project_settings(
559
680
  )
560
681
 
561
682
 
562
- @app.put("/feedback")
683
+ @router.put("/feedback")
563
684
  async def update_feedback(
564
685
  request: Request,
565
686
  update: UpdateFeedbackRequest,
@@ -578,7 +699,7 @@ async def update_feedback(
578
699
  return JSONResponse(content={"success": True, "feedbackId": feedback_id})
579
700
 
580
701
 
581
- @app.delete("/feedback")
702
+ @router.delete("/feedback")
582
703
  async def delete_feedback(
583
704
  request: Request,
584
705
  payload: DeleteFeedbackRequest,
@@ -597,7 +718,7 @@ async def delete_feedback(
597
718
  return JSONResponse(content={"success": True})
598
719
 
599
720
 
600
- @app.post("/project/threads")
721
+ @router.post("/project/threads")
601
722
  async def get_user_threads(
602
723
  request: Request,
603
724
  payload: GetThreadsRequest,
@@ -622,7 +743,7 @@ async def get_user_threads(
622
743
  return JSONResponse(content=res.to_dict())
623
744
 
624
745
 
625
- @app.get("/project/thread/{thread_id}")
746
+ @router.get("/project/thread/{thread_id}")
626
747
  async def get_thread(
627
748
  request: Request,
628
749
  thread_id: str,
@@ -640,7 +761,7 @@ async def get_thread(
640
761
  return JSONResponse(content=res)
641
762
 
642
763
 
643
- @app.get("/project/thread/{thread_id}/element/{element_id}")
764
+ @router.get("/project/thread/{thread_id}/element/{element_id}")
644
765
  async def get_thread_element(
645
766
  request: Request,
646
767
  thread_id: str,
@@ -659,7 +780,7 @@ async def get_thread_element(
659
780
  return JSONResponse(content=res)
660
781
 
661
782
 
662
- @app.delete("/project/thread")
783
+ @router.delete("/project/thread")
663
784
  async def delete_thread(
664
785
  request: Request,
665
786
  payload: DeleteThreadRequest,
@@ -680,7 +801,7 @@ async def delete_thread(
680
801
  return JSONResponse(content={"success": True})
681
802
 
682
803
 
683
- @app.post("/project/file")
804
+ @router.post("/project/file")
684
805
  async def upload_file(
685
806
  session_id: str,
686
807
  file: UploadFile,
@@ -716,7 +837,7 @@ async def upload_file(
716
837
  return JSONResponse(file_response)
717
838
 
718
839
 
719
- @app.get("/project/file/{file_id}")
840
+ @router.get("/project/file/{file_id}")
720
841
  async def get_file(
721
842
  file_id: str,
722
843
  session_id: Optional[str] = None,
@@ -738,7 +859,7 @@ async def get_file(
738
859
  raise HTTPException(status_code=404, detail="File not found")
739
860
 
740
861
 
741
- @app.get("/files/{filename:path}")
862
+ @router.get("/files/{filename:path}")
742
863
  async def serve_file(
743
864
  filename: str,
744
865
  current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
@@ -756,7 +877,7 @@ async def serve_file(
756
877
  raise HTTPException(status_code=404, detail="File not found")
757
878
 
758
879
 
759
- @app.get("/favicon")
880
+ @router.get("/favicon")
760
881
  async def get_favicon():
761
882
  custom_favicon_path = os.path.join(APP_ROOT, "public", "favicon.*")
762
883
  files = glob.glob(custom_favicon_path)
@@ -771,7 +892,7 @@ async def get_favicon():
771
892
  return FileResponse(favicon_path, media_type=media_type)
772
893
 
773
894
 
774
- @app.get("/logo")
895
+ @router.get("/logo")
775
896
  async def get_logo(theme: Optional[Theme] = Query(Theme.light)):
776
897
  theme_value = theme.value if theme else Theme.light.value
777
898
  logo_path = None
@@ -793,7 +914,7 @@ async def get_logo(theme: Optional[Theme] = Query(Theme.light)):
793
914
  return FileResponse(logo_path, media_type=media_type)
794
915
 
795
916
 
796
- @app.get("/avatars/{avatar_id}")
917
+ @router.get("/avatars/{avatar_id}")
797
918
  async def get_avatar(avatar_id: str):
798
919
  if avatar_id == "default":
799
920
  avatar_id = config.ui.name
@@ -812,19 +933,20 @@ async def get_avatar(avatar_id: str):
812
933
  return await get_favicon()
813
934
 
814
935
 
815
- @app.head("/")
936
+ @router.head("/")
816
937
  def status_check():
817
938
  return {"message": "Site is operational"}
818
939
 
819
940
 
820
- def register_wildcard_route_handler():
821
- @app.get("/{path:path}")
822
- async def serve(request: Request, path: str):
823
- html_template = get_html_template()
824
- """Serve the UI files."""
825
- response = HTMLResponse(content=html_template, status_code=200)
941
+ @router.get("/{full_path:path}")
942
+ async def serve():
943
+ html_template = get_html_template()
944
+ """Serve the UI files."""
945
+ response = HTMLResponse(content=html_template, status_code=200)
946
+
947
+ return response
826
948
 
827
- return response
828
949
 
950
+ app.include_router(router)
829
951
 
830
952
  import chainlit.socket # noqa