chainlit 1.0.401__py3-none-any.whl → 2.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chainlit might be problematic. Click here for more details.
- chainlit/__init__.py +98 -279
- chainlit/_utils.py +8 -0
- chainlit/action.py +12 -10
- chainlit/{auth.py → auth/__init__.py} +28 -36
- chainlit/auth/cookie.py +122 -0
- chainlit/auth/jwt.py +39 -0
- chainlit/cache.py +4 -6
- chainlit/callbacks.py +362 -0
- chainlit/chat_context.py +64 -0
- chainlit/chat_settings.py +3 -1
- chainlit/cli/__init__.py +77 -8
- chainlit/config.py +181 -101
- chainlit/context.py +42 -13
- chainlit/copilot/dist/index.js +8750 -903
- chainlit/data/__init__.py +101 -416
- chainlit/data/acl.py +6 -2
- chainlit/data/base.py +107 -0
- chainlit/data/chainlit_data_layer.py +608 -0
- chainlit/data/dynamodb.py +590 -0
- chainlit/data/literalai.py +500 -0
- chainlit/data/sql_alchemy.py +721 -0
- chainlit/data/storage_clients/__init__.py +0 -0
- chainlit/data/storage_clients/azure.py +81 -0
- chainlit/data/storage_clients/azure_blob.py +89 -0
- chainlit/data/storage_clients/base.py +26 -0
- chainlit/data/storage_clients/gcs.py +88 -0
- chainlit/data/storage_clients/s3.py +75 -0
- chainlit/data/utils.py +29 -0
- chainlit/discord/__init__.py +6 -0
- chainlit/discord/app.py +354 -0
- chainlit/element.py +91 -33
- chainlit/emitter.py +80 -29
- chainlit/frontend/dist/assets/DailyMotion-C_XC7xJI.js +1 -0
- chainlit/frontend/dist/assets/Dataframe-Cs4l4hA1.js +22 -0
- chainlit/frontend/dist/assets/Facebook-CUeCH7hk.js +1 -0
- chainlit/frontend/dist/assets/FilePlayer-CB-fYkx8.js +1 -0
- chainlit/frontend/dist/assets/Kaltura-YX6qaq72.js +1 -0
- chainlit/frontend/dist/assets/Mixcloud-DGV0ldjP.js +1 -0
- chainlit/frontend/dist/assets/Mux-CmRss5oc.js +1 -0
- chainlit/frontend/dist/assets/Preview-DBVJn7-H.js +1 -0
- chainlit/frontend/dist/assets/SoundCloud-qLUb18oY.js +1 -0
- chainlit/frontend/dist/assets/Streamable-BvYP7bFp.js +1 -0
- chainlit/frontend/dist/assets/Twitch-CTHt-sGZ.js +1 -0
- chainlit/frontend/dist/assets/Vidyard-B-0mCJbm.js +1 -0
- chainlit/frontend/dist/assets/Vimeo-Dnp7ri8q.js +1 -0
- chainlit/frontend/dist/assets/Wistia-DW0x_UBn.js +1 -0
- chainlit/frontend/dist/assets/YouTube--98FipvA.js +1 -0
- chainlit/frontend/dist/assets/index-D71nZ46o.js +8665 -0
- chainlit/frontend/dist/assets/index-g8LTJwwr.css +1 -0
- chainlit/frontend/dist/assets/react-plotly-Cn_BQTQw.js +3484 -0
- chainlit/frontend/dist/index.html +2 -4
- chainlit/haystack/callbacks.py +4 -7
- chainlit/input_widget.py +8 -4
- chainlit/langchain/callbacks.py +103 -68
- chainlit/langflow/__init__.py +1 -0
- chainlit/llama_index/callbacks.py +65 -40
- chainlit/markdown.py +22 -6
- chainlit/message.py +54 -56
- chainlit/mistralai/__init__.py +50 -0
- chainlit/oauth_providers.py +266 -8
- chainlit/openai/__init__.py +10 -18
- chainlit/secret.py +1 -1
- chainlit/server.py +789 -228
- chainlit/session.py +108 -90
- chainlit/slack/__init__.py +6 -0
- chainlit/slack/app.py +397 -0
- chainlit/socket.py +199 -116
- chainlit/step.py +141 -89
- chainlit/sync.py +2 -1
- chainlit/teams/__init__.py +6 -0
- chainlit/teams/app.py +338 -0
- chainlit/translations/bn.json +235 -0
- chainlit/translations/en-US.json +83 -4
- chainlit/translations/gu.json +235 -0
- chainlit/translations/he-IL.json +235 -0
- chainlit/translations/hi.json +235 -0
- chainlit/translations/kn.json +235 -0
- chainlit/translations/ml.json +235 -0
- chainlit/translations/mr.json +235 -0
- chainlit/translations/nl-NL.json +233 -0
- chainlit/translations/ta.json +235 -0
- chainlit/translations/te.json +235 -0
- chainlit/translations/zh-CN.json +233 -0
- chainlit/translations.py +60 -0
- chainlit/types.py +133 -28
- chainlit/user.py +14 -3
- chainlit/user_session.py +6 -3
- chainlit/utils.py +52 -5
- chainlit/version.py +3 -2
- {chainlit-1.0.401.dist-info → chainlit-2.0.3.dist-info}/METADATA +48 -50
- chainlit-2.0.3.dist-info/RECORD +106 -0
- chainlit/cli/utils.py +0 -24
- chainlit/frontend/dist/assets/index-9711593e.js +0 -723
- chainlit/frontend/dist/assets/index-d088547c.css +0 -1
- chainlit/frontend/dist/assets/react-plotly-d8762cc2.js +0 -3602
- chainlit/playground/__init__.py +0 -2
- chainlit/playground/config.py +0 -40
- chainlit/playground/provider.py +0 -108
- chainlit/playground/providers/__init__.py +0 -13
- chainlit/playground/providers/anthropic.py +0 -118
- chainlit/playground/providers/huggingface.py +0 -75
- chainlit/playground/providers/langchain.py +0 -89
- chainlit/playground/providers/openai.py +0 -408
- chainlit/playground/providers/vertexai.py +0 -171
- chainlit/translations/pt-BR.json +0 -155
- chainlit-1.0.401.dist-info/RECORD +0 -66
- /chainlit/copilot/dist/assets/{logo_dark-2a3cf740.svg → logo_dark-IkGJ_IwC.svg} +0 -0
- /chainlit/copilot/dist/assets/{logo_light-b078e7bc.svg → logo_light-Bb_IPh6r.svg} +0 -0
- /chainlit/frontend/dist/assets/{logo_dark-2a3cf740.svg → logo_dark-IkGJ_IwC.svg} +0 -0
- /chainlit/frontend/dist/assets/{logo_light-b078e7bc.svg → logo_light-Bb_IPh6r.svg} +0 -0
- {chainlit-1.0.401.dist-info → chainlit-2.0.3.dist-info}/WHEEL +0 -0
- {chainlit-1.0.401.dist-info → chainlit-2.0.3.dist-info}/entry_points.txt +0 -0
chainlit/oauth_providers.py
CHANGED
|
@@ -4,9 +4,11 @@ import urllib.parse
|
|
|
4
4
|
from typing import Dict, List, Optional, Tuple
|
|
5
5
|
|
|
6
6
|
import httpx
|
|
7
|
-
from chainlit.user import User
|
|
8
7
|
from fastapi import HTTPException
|
|
9
8
|
|
|
9
|
+
from chainlit.secret import random_secret
|
|
10
|
+
from chainlit.user import User
|
|
11
|
+
|
|
10
12
|
|
|
11
13
|
class OAuthProvider:
|
|
12
14
|
id: str
|
|
@@ -15,15 +17,31 @@ class OAuthProvider:
|
|
|
15
17
|
client_secret: str
|
|
16
18
|
authorize_url: str
|
|
17
19
|
authorize_params: Dict[str, str]
|
|
20
|
+
default_prompt: Optional[str] = None
|
|
18
21
|
|
|
19
22
|
def is_configured(self):
|
|
20
23
|
return all([os.environ.get(env) for env in self.env])
|
|
21
24
|
|
|
22
25
|
async def get_token(self, code: str, url: str) -> str:
|
|
23
|
-
raise NotImplementedError
|
|
26
|
+
raise NotImplementedError
|
|
24
27
|
|
|
25
28
|
async def get_user_info(self, token: str) -> Tuple[Dict[str, str], User]:
|
|
26
|
-
raise NotImplementedError
|
|
29
|
+
raise NotImplementedError
|
|
30
|
+
|
|
31
|
+
def get_env_prefix(self) -> str:
|
|
32
|
+
"""Return environment prefix, like AZURE_AD."""
|
|
33
|
+
|
|
34
|
+
return self.id.replace("-", "_").upper()
|
|
35
|
+
|
|
36
|
+
def get_prompt(self) -> Optional[str]:
|
|
37
|
+
"""Return OAuth prompt param."""
|
|
38
|
+
if prompt := os.environ.get(f"OAUTH_{self.get_env_prefix()}_PROMPT"):
|
|
39
|
+
return prompt
|
|
40
|
+
|
|
41
|
+
if prompt := os.environ.get("OAUTH_PROMPT"):
|
|
42
|
+
return prompt
|
|
43
|
+
|
|
44
|
+
return self.default_prompt
|
|
27
45
|
|
|
28
46
|
|
|
29
47
|
class GithubOAuthProvider(OAuthProvider):
|
|
@@ -38,6 +56,9 @@ class GithubOAuthProvider(OAuthProvider):
|
|
|
38
56
|
"scope": "user:email",
|
|
39
57
|
}
|
|
40
58
|
|
|
59
|
+
if prompt := self.get_prompt():
|
|
60
|
+
self.authorize_params["prompt"] = prompt
|
|
61
|
+
|
|
41
62
|
async def get_token(self, code: str, url: str):
|
|
42
63
|
payload = {
|
|
43
64
|
"client_id": self.client_id,
|
|
@@ -96,6 +117,9 @@ class GoogleOAuthProvider(OAuthProvider):
|
|
|
96
117
|
"access_type": "offline",
|
|
97
118
|
}
|
|
98
119
|
|
|
120
|
+
if prompt := self.get_prompt():
|
|
121
|
+
self.authorize_params["prompt"] = prompt
|
|
122
|
+
|
|
99
123
|
async def get_token(self, code: str, url: str):
|
|
100
124
|
payload = {
|
|
101
125
|
"client_id": self.client_id,
|
|
@@ -163,6 +187,9 @@ class AzureADOAuthProvider(OAuthProvider):
|
|
|
163
187
|
"response_mode": "query",
|
|
164
188
|
}
|
|
165
189
|
|
|
190
|
+
if prompt := self.get_prompt():
|
|
191
|
+
self.authorize_params["prompt"] = prompt
|
|
192
|
+
|
|
166
193
|
async def get_token(self, code: str, url: str):
|
|
167
194
|
payload = {
|
|
168
195
|
"client_id": self.client_id,
|
|
@@ -203,10 +230,97 @@ class AzureADOAuthProvider(OAuthProvider):
|
|
|
203
230
|
)
|
|
204
231
|
photo_data = await photo_response.aread()
|
|
205
232
|
base64_image = base64.b64encode(photo_data)
|
|
206
|
-
azure_user[
|
|
207
|
-
"
|
|
208
|
-
|
|
209
|
-
except Exception
|
|
233
|
+
azure_user["image"] = (
|
|
234
|
+
f"data:{photo_response.headers['Content-Type']};base64,{base64_image.decode('utf-8')}"
|
|
235
|
+
)
|
|
236
|
+
except Exception:
|
|
237
|
+
# Ignore errors getting the photo
|
|
238
|
+
pass
|
|
239
|
+
|
|
240
|
+
user = User(
|
|
241
|
+
identifier=azure_user["userPrincipalName"],
|
|
242
|
+
metadata={"image": azure_user.get("image"), "provider": "azure-ad"},
|
|
243
|
+
)
|
|
244
|
+
return (azure_user, user)
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
class AzureADHybridOAuthProvider(OAuthProvider):
|
|
248
|
+
id = "azure-ad-hybrid"
|
|
249
|
+
env = [
|
|
250
|
+
"OAUTH_AZURE_AD_HYBRID_CLIENT_ID",
|
|
251
|
+
"OAUTH_AZURE_AD_HYBRID_CLIENT_SECRET",
|
|
252
|
+
"OAUTH_AZURE_AD_HYBRID_TENANT_ID",
|
|
253
|
+
]
|
|
254
|
+
authorize_url = (
|
|
255
|
+
f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_HYBRID_TENANT_ID', '')}/oauth2/v2.0/authorize"
|
|
256
|
+
if os.environ.get("OAUTH_AZURE_AD_HYBRID_ENABLE_SINGLE_TENANT")
|
|
257
|
+
else "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
|
|
258
|
+
)
|
|
259
|
+
token_url = (
|
|
260
|
+
f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_HYBRID_TENANT_ID', '')}/oauth2/v2.0/token"
|
|
261
|
+
if os.environ.get("OAUTH_AZURE_AD_HYBRID_ENABLE_SINGLE_TENANT")
|
|
262
|
+
else "https://login.microsoftonline.com/common/oauth2/v2.0/token"
|
|
263
|
+
)
|
|
264
|
+
|
|
265
|
+
def __init__(self):
|
|
266
|
+
self.client_id = os.environ.get("OAUTH_AZURE_AD_HYBRID_CLIENT_ID")
|
|
267
|
+
self.client_secret = os.environ.get("OAUTH_AZURE_AD_HYBRID_CLIENT_SECRET")
|
|
268
|
+
nonce = random_secret(16)
|
|
269
|
+
self.authorize_params = {
|
|
270
|
+
"tenant": os.environ.get("OAUTH_AZURE_AD_HYBRID_TENANT_ID"),
|
|
271
|
+
"response_type": "code id_token",
|
|
272
|
+
"scope": "https://graph.microsoft.com/User.Read https://graph.microsoft.com/openid",
|
|
273
|
+
"response_mode": "form_post",
|
|
274
|
+
"nonce": nonce,
|
|
275
|
+
}
|
|
276
|
+
|
|
277
|
+
if prompt := self.get_prompt():
|
|
278
|
+
self.authorize_params["prompt"] = prompt
|
|
279
|
+
|
|
280
|
+
async def get_token(self, code: str, url: str):
|
|
281
|
+
payload = {
|
|
282
|
+
"client_id": self.client_id,
|
|
283
|
+
"client_secret": self.client_secret,
|
|
284
|
+
"code": code,
|
|
285
|
+
"grant_type": "authorization_code",
|
|
286
|
+
"redirect_uri": url,
|
|
287
|
+
}
|
|
288
|
+
async with httpx.AsyncClient() as client:
|
|
289
|
+
response = await client.post(
|
|
290
|
+
self.token_url,
|
|
291
|
+
data=payload,
|
|
292
|
+
)
|
|
293
|
+
response.raise_for_status()
|
|
294
|
+
json = response.json()
|
|
295
|
+
|
|
296
|
+
token = json["access_token"]
|
|
297
|
+
if not token:
|
|
298
|
+
raise HTTPException(
|
|
299
|
+
status_code=400, detail="Failed to get the access token"
|
|
300
|
+
)
|
|
301
|
+
return token
|
|
302
|
+
|
|
303
|
+
async def get_user_info(self, token: str):
|
|
304
|
+
async with httpx.AsyncClient() as client:
|
|
305
|
+
response = await client.get(
|
|
306
|
+
"https://graph.microsoft.com/v1.0/me",
|
|
307
|
+
headers={"Authorization": f"Bearer {token}"},
|
|
308
|
+
)
|
|
309
|
+
response.raise_for_status()
|
|
310
|
+
|
|
311
|
+
azure_user = response.json()
|
|
312
|
+
|
|
313
|
+
try:
|
|
314
|
+
photo_response = await client.get(
|
|
315
|
+
"https://graph.microsoft.com/v1.0/me/photos/48x48/$value",
|
|
316
|
+
headers={"Authorization": f"Bearer {token}"},
|
|
317
|
+
)
|
|
318
|
+
photo_data = await photo_response.aread()
|
|
319
|
+
base64_image = base64.b64encode(photo_data)
|
|
320
|
+
azure_user["image"] = (
|
|
321
|
+
f"data:{photo_response.headers['Content-Type']};base64,{base64_image.decode('utf-8')}"
|
|
322
|
+
)
|
|
323
|
+
except Exception:
|
|
210
324
|
# Ignore errors getting the photo
|
|
211
325
|
pass
|
|
212
326
|
|
|
@@ -242,6 +356,9 @@ class OktaOAuthProvider(OAuthProvider):
|
|
|
242
356
|
"response_mode": "query",
|
|
243
357
|
}
|
|
244
358
|
|
|
359
|
+
if prompt := self.get_prompt():
|
|
360
|
+
self.authorize_params["prompt"] = prompt
|
|
361
|
+
|
|
245
362
|
def get_authorization_server_path(self):
|
|
246
363
|
if not self.authorization_server_id:
|
|
247
364
|
return "/default"
|
|
@@ -313,6 +430,9 @@ class Auth0OAuthProvider(OAuthProvider):
|
|
|
313
430
|
"audience": f"{self.original_domain}/userinfo",
|
|
314
431
|
}
|
|
315
432
|
|
|
433
|
+
if prompt := self.get_prompt():
|
|
434
|
+
self.authorize_params["prompt"] = prompt
|
|
435
|
+
|
|
316
436
|
async def get_token(self, code: str, url: str):
|
|
317
437
|
payload = {
|
|
318
438
|
"client_id": self.client_id,
|
|
@@ -357,7 +477,7 @@ class DescopeOAuthProvider(OAuthProvider):
|
|
|
357
477
|
id = "descope"
|
|
358
478
|
env = ["OAUTH_DESCOPE_CLIENT_ID", "OAUTH_DESCOPE_CLIENT_SECRET"]
|
|
359
479
|
# Ensure that the domain does not have a trailing slash
|
|
360
|
-
domain =
|
|
480
|
+
domain = "https://api.descope.com/oauth2/v1"
|
|
361
481
|
|
|
362
482
|
authorize_url = f"{domain}/authorize"
|
|
363
483
|
|
|
@@ -370,6 +490,9 @@ class DescopeOAuthProvider(OAuthProvider):
|
|
|
370
490
|
"audience": f"{self.domain}/userinfo",
|
|
371
491
|
}
|
|
372
492
|
|
|
493
|
+
if prompt := self.get_prompt():
|
|
494
|
+
self.authorize_params["prompt"] = prompt
|
|
495
|
+
|
|
373
496
|
async def get_token(self, code: str, url: str):
|
|
374
497
|
payload = {
|
|
375
498
|
"client_id": self.client_id,
|
|
@@ -428,6 +551,9 @@ class AWSCognitoOAuthProvider(OAuthProvider):
|
|
|
428
551
|
"scope": "openid profile email",
|
|
429
552
|
}
|
|
430
553
|
|
|
554
|
+
if prompt := self.get_prompt():
|
|
555
|
+
self.authorize_params["prompt"] = prompt
|
|
556
|
+
|
|
431
557
|
async def get_token(self, code: str, url: str):
|
|
432
558
|
payload = {
|
|
433
559
|
"client_id": self.client_id,
|
|
@@ -475,14 +601,146 @@ class AWSCognitoOAuthProvider(OAuthProvider):
|
|
|
475
601
|
return (cognito_user, user)
|
|
476
602
|
|
|
477
603
|
|
|
604
|
+
class GitlabOAuthProvider(OAuthProvider):
|
|
605
|
+
id = "gitlab"
|
|
606
|
+
env = [
|
|
607
|
+
"OAUTH_GITLAB_CLIENT_ID",
|
|
608
|
+
"OAUTH_GITLAB_CLIENT_SECRET",
|
|
609
|
+
"OAUTH_GITLAB_DOMAIN",
|
|
610
|
+
]
|
|
611
|
+
|
|
612
|
+
def __init__(self):
|
|
613
|
+
self.client_id = os.environ.get("OAUTH_GITLAB_CLIENT_ID")
|
|
614
|
+
self.client_secret = os.environ.get("OAUTH_GITLAB_CLIENT_SECRET")
|
|
615
|
+
# Ensure that the domain does not have a trailing slash
|
|
616
|
+
self.domain = f"https://{os.environ.get('OAUTH_GITLAB_DOMAIN', '').rstrip('/')}"
|
|
617
|
+
|
|
618
|
+
self.authorize_url = f"{self.domain}/oauth/authorize"
|
|
619
|
+
|
|
620
|
+
self.authorize_params = {
|
|
621
|
+
"scope": "openid profile email",
|
|
622
|
+
"response_type": "code",
|
|
623
|
+
}
|
|
624
|
+
|
|
625
|
+
if prompt := self.get_prompt():
|
|
626
|
+
self.authorize_params["prompt"] = prompt
|
|
627
|
+
|
|
628
|
+
async def get_token(self, code: str, url: str):
|
|
629
|
+
payload = {
|
|
630
|
+
"client_id": self.client_id,
|
|
631
|
+
"client_secret": self.client_secret,
|
|
632
|
+
"code": code,
|
|
633
|
+
"grant_type": "authorization_code",
|
|
634
|
+
"redirect_uri": url,
|
|
635
|
+
}
|
|
636
|
+
async with httpx.AsyncClient() as client:
|
|
637
|
+
response = await client.post(
|
|
638
|
+
f"{self.domain}/oauth/token",
|
|
639
|
+
data=payload,
|
|
640
|
+
)
|
|
641
|
+
response.raise_for_status()
|
|
642
|
+
json_content = response.json()
|
|
643
|
+
token = json_content.get("access_token")
|
|
644
|
+
if not token:
|
|
645
|
+
raise HTTPException(
|
|
646
|
+
status_code=400, detail="Failed to get the access token"
|
|
647
|
+
)
|
|
648
|
+
return token
|
|
649
|
+
|
|
650
|
+
async def get_user_info(self, token: str):
|
|
651
|
+
async with httpx.AsyncClient() as client:
|
|
652
|
+
response = await client.get(
|
|
653
|
+
f"{self.domain}/oauth/userinfo",
|
|
654
|
+
headers={"Authorization": f"Bearer {token}"},
|
|
655
|
+
)
|
|
656
|
+
response.raise_for_status()
|
|
657
|
+
gitlab_user = response.json()
|
|
658
|
+
user = User(
|
|
659
|
+
identifier=gitlab_user.get("email"),
|
|
660
|
+
metadata={
|
|
661
|
+
"image": gitlab_user.get("picture", ""),
|
|
662
|
+
"provider": "gitlab",
|
|
663
|
+
},
|
|
664
|
+
)
|
|
665
|
+
return (gitlab_user, user)
|
|
666
|
+
|
|
667
|
+
|
|
668
|
+
class KeycloakOAuthProvider(OAuthProvider):
|
|
669
|
+
env = [
|
|
670
|
+
"OAUTH_KEYCLOAK_CLIENT_ID",
|
|
671
|
+
"OAUTH_KEYCLOAK_CLIENT_SECRET",
|
|
672
|
+
"OAUTH_KEYCLOAK_REALM",
|
|
673
|
+
"OAUTH_KEYCLOAK_BASE_URL",
|
|
674
|
+
]
|
|
675
|
+
id = os.environ.get("OAUTH_KEYCLOAK_NAME", "keycloak")
|
|
676
|
+
|
|
677
|
+
def __init__(self):
|
|
678
|
+
self.client_id = os.environ.get("OAUTH_KEYCLOAK_CLIENT_ID")
|
|
679
|
+
self.client_secret = os.environ.get("OAUTH_KEYCLOAK_CLIENT_SECRET")
|
|
680
|
+
self.realm = os.environ.get("OAUTH_KEYCLOAK_REALM")
|
|
681
|
+
self.base_url = os.environ.get("OAUTH_KEYCLOAK_BASE_URL")
|
|
682
|
+
self.authorize_url = (
|
|
683
|
+
f"{self.base_url}/realms/{self.realm}/protocol/openid-connect/auth"
|
|
684
|
+
)
|
|
685
|
+
|
|
686
|
+
self.authorize_params = {
|
|
687
|
+
"scope": "profile email openid",
|
|
688
|
+
"response_type": "code",
|
|
689
|
+
}
|
|
690
|
+
|
|
691
|
+
if prompt := self.get_prompt():
|
|
692
|
+
self.authorize_params["prompt"] = prompt
|
|
693
|
+
|
|
694
|
+
async def get_token(self, code: str, url: str):
|
|
695
|
+
payload = {
|
|
696
|
+
"client_id": self.client_id,
|
|
697
|
+
"client_secret": self.client_secret,
|
|
698
|
+
"code": code,
|
|
699
|
+
"grant_type": "authorization_code",
|
|
700
|
+
"redirect_uri": url,
|
|
701
|
+
}
|
|
702
|
+
async with httpx.AsyncClient() as client:
|
|
703
|
+
response = await client.post(
|
|
704
|
+
f"{self.base_url}/realms/{self.realm}/protocol/openid-connect/token",
|
|
705
|
+
data=payload,
|
|
706
|
+
)
|
|
707
|
+
response.raise_for_status()
|
|
708
|
+
json = response.json()
|
|
709
|
+
token = json.get("access_token")
|
|
710
|
+
if not token:
|
|
711
|
+
raise httpx.HTTPStatusError(
|
|
712
|
+
"Failed to get the access token",
|
|
713
|
+
request=response.request,
|
|
714
|
+
response=response,
|
|
715
|
+
)
|
|
716
|
+
return token
|
|
717
|
+
|
|
718
|
+
async def get_user_info(self, token: str):
|
|
719
|
+
async with httpx.AsyncClient() as client:
|
|
720
|
+
response = await client.get(
|
|
721
|
+
f"{self.base_url}/realms/{self.realm}/protocol/openid-connect/userinfo",
|
|
722
|
+
headers={"Authorization": f"Bearer {token}"},
|
|
723
|
+
)
|
|
724
|
+
response.raise_for_status()
|
|
725
|
+
kc_user = response.json()
|
|
726
|
+
user = User(
|
|
727
|
+
identifier=kc_user["email"],
|
|
728
|
+
metadata={"provider": "keycloak"},
|
|
729
|
+
)
|
|
730
|
+
return (kc_user, user)
|
|
731
|
+
|
|
732
|
+
|
|
478
733
|
providers = [
|
|
479
734
|
GithubOAuthProvider(),
|
|
480
735
|
GoogleOAuthProvider(),
|
|
481
736
|
AzureADOAuthProvider(),
|
|
737
|
+
AzureADHybridOAuthProvider(),
|
|
482
738
|
OktaOAuthProvider(),
|
|
483
739
|
Auth0OAuthProvider(),
|
|
484
740
|
DescopeOAuthProvider(),
|
|
485
741
|
AWSCognitoOAuthProvider(),
|
|
742
|
+
GitlabOAuthProvider(),
|
|
743
|
+
KeycloakOAuthProvider(),
|
|
486
744
|
]
|
|
487
745
|
|
|
488
746
|
|
chainlit/openai/__init__.py
CHANGED
|
@@ -1,12 +1,13 @@
|
|
|
1
|
+
import asyncio
|
|
1
2
|
from typing import Union
|
|
2
3
|
|
|
3
|
-
from chainlit.context import get_context
|
|
4
|
-
from chainlit.step import Step
|
|
5
|
-
from chainlit.sync import run_sync
|
|
6
|
-
from chainlit.utils import check_module_version
|
|
7
4
|
from literalai import ChatGeneration, CompletionGeneration
|
|
8
5
|
from literalai.helper import timestamp_utc
|
|
9
6
|
|
|
7
|
+
from chainlit.context import local_steps
|
|
8
|
+
from chainlit.step import Step
|
|
9
|
+
from chainlit.utils import check_module_version
|
|
10
|
+
|
|
10
11
|
|
|
11
12
|
def instrument_openai():
|
|
12
13
|
if not check_module_version("openai", "1.0.0"):
|
|
@@ -16,16 +17,12 @@ def instrument_openai():
|
|
|
16
17
|
|
|
17
18
|
from literalai.instrumentation.openai import instrument_openai
|
|
18
19
|
|
|
19
|
-
async def on_new_generation(
|
20
|
+
def on_new_generation(
|
|
20
21
|
generation: Union["ChatGeneration", "CompletionGeneration"], timing
|
|
21
22
|
):
|
|
22
|
-
context = get_context()
|
23
|
+
previous_steps = local_steps.get()
|
|
23
24
|
|
|
24
|
-
parent_id = None
|
|
25
|
-
if context.current_step:
|
|
26
|
-
parent_id = context.current_step.id
|
|
27
|
-
elif context.session.root_message:
|
|
28
|
-
parent_id = context.session.root_message.id
|
|
25
|
+
parent_id = previous_steps[-1].id if previous_steps else None
|
|
29
26
|
|
|
30
27
|
step = Step(
|
|
31
28
|
name=generation.model if generation.model else generation.provider,
|
|
@@ -52,11 +49,6 @@ def instrument_openai():
|
|
|
52
49
|
step.input = generation.prompt
|
|
53
50
|
step.output = generation.completion
|
|
54
51
|
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
def on_new_generation_sync(
|
|
58
|
-
generation: Union["ChatGeneration", "CompletionGeneration"], timing
|
|
59
|
-
):
|
|
60
|
-
run_sync(on_new_generation(generation, timing))
|
|
52
|
+
asyncio.create_task(step.send())
|
|
61
53
|
|
|
62
|
-
instrument_openai(None, on_new_generation_sync)
|
|
54
|
+
instrument_openai(None, on_new_generation)
|
chainlit/secret.py
CHANGED