chainlit 2.7.0__py3-none-any.whl → 2.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic. Click here for more details.

Files changed (85)
  1. {chainlit-2.7.0.dist-info → chainlit-2.7.1.dist-info}/METADATA +1 -1
  2. chainlit-2.7.1.dist-info/RECORD +4 -0
  3. chainlit/__init__.py +0 -207
  4. chainlit/__main__.py +0 -4
  5. chainlit/_utils.py +0 -8
  6. chainlit/action.py +0 -33
  7. chainlit/auth/__init__.py +0 -95
  8. chainlit/auth/cookie.py +0 -197
  9. chainlit/auth/jwt.py +0 -42
  10. chainlit/cache.py +0 -45
  11. chainlit/callbacks.py +0 -433
  12. chainlit/chat_context.py +0 -64
  13. chainlit/chat_settings.py +0 -34
  14. chainlit/cli/__init__.py +0 -235
  15. chainlit/config.py +0 -621
  16. chainlit/context.py +0 -112
  17. chainlit/data/__init__.py +0 -111
  18. chainlit/data/acl.py +0 -19
  19. chainlit/data/base.py +0 -107
  20. chainlit/data/chainlit_data_layer.py +0 -687
  21. chainlit/data/dynamodb.py +0 -616
  22. chainlit/data/literalai.py +0 -501
  23. chainlit/data/sql_alchemy.py +0 -741
  24. chainlit/data/storage_clients/__init__.py +0 -0
  25. chainlit/data/storage_clients/azure.py +0 -84
  26. chainlit/data/storage_clients/azure_blob.py +0 -94
  27. chainlit/data/storage_clients/base.py +0 -28
  28. chainlit/data/storage_clients/gcs.py +0 -101
  29. chainlit/data/storage_clients/s3.py +0 -88
  30. chainlit/data/utils.py +0 -29
  31. chainlit/discord/__init__.py +0 -6
  32. chainlit/discord/app.py +0 -364
  33. chainlit/element.py +0 -454
  34. chainlit/emitter.py +0 -450
  35. chainlit/hello.py +0 -12
  36. chainlit/input_widget.py +0 -182
  37. chainlit/langchain/__init__.py +0 -6
  38. chainlit/langchain/callbacks.py +0 -682
  39. chainlit/langflow/__init__.py +0 -25
  40. chainlit/llama_index/__init__.py +0 -6
  41. chainlit/llama_index/callbacks.py +0 -206
  42. chainlit/logger.py +0 -16
  43. chainlit/markdown.py +0 -57
  44. chainlit/mcp.py +0 -99
  45. chainlit/message.py +0 -619
  46. chainlit/mistralai/__init__.py +0 -50
  47. chainlit/oauth_providers.py +0 -835
  48. chainlit/openai/__init__.py +0 -53
  49. chainlit/py.typed +0 -0
  50. chainlit/secret.py +0 -9
  51. chainlit/semantic_kernel/__init__.py +0 -111
  52. chainlit/server.py +0 -1616
  53. chainlit/session.py +0 -304
  54. chainlit/sidebar.py +0 -55
  55. chainlit/slack/__init__.py +0 -6
  56. chainlit/slack/app.py +0 -427
  57. chainlit/socket.py +0 -381
  58. chainlit/step.py +0 -490
  59. chainlit/sync.py +0 -43
  60. chainlit/teams/__init__.py +0 -6
  61. chainlit/teams/app.py +0 -348
  62. chainlit/translations/bn.json +0 -214
  63. chainlit/translations/el-GR.json +0 -214
  64. chainlit/translations/en-US.json +0 -214
  65. chainlit/translations/fr-FR.json +0 -214
  66. chainlit/translations/gu.json +0 -214
  67. chainlit/translations/he-IL.json +0 -214
  68. chainlit/translations/hi.json +0 -214
  69. chainlit/translations/ja.json +0 -214
  70. chainlit/translations/kn.json +0 -214
  71. chainlit/translations/ml.json +0 -214
  72. chainlit/translations/mr.json +0 -214
  73. chainlit/translations/nl.json +0 -214
  74. chainlit/translations/ta.json +0 -214
  75. chainlit/translations/te.json +0 -214
  76. chainlit/translations/zh-CN.json +0 -214
  77. chainlit/translations.py +0 -60
  78. chainlit/types.py +0 -334
  79. chainlit/user.py +0 -43
  80. chainlit/user_session.py +0 -153
  81. chainlit/utils.py +0 -173
  82. chainlit/version.py +0 -8
  83. chainlit-2.7.0.dist-info/RECORD +0 -84
  84. {chainlit-2.7.0.dist-info → chainlit-2.7.1.dist-info}/WHEEL +0 -0
  85. {chainlit-2.7.0.dist-info → chainlit-2.7.1.dist-info}/entry_points.txt +0 -0
@@ -1,835 +0,0 @@
1
- import base64
2
- import os
3
- import urllib.parse
4
- from typing import Dict, List, Optional, Tuple
5
-
6
- import httpx
7
- from fastapi import HTTPException
8
-
9
- from chainlit.secret import random_secret
10
- from chainlit.user import User
11
-
12
-
13
class OAuthProvider:
    """Base class shared by every OAuth provider integration.

    Subclasses set ``id``, ``env`` and (usually) ``authorize_url``, fill in the
    client credentials and ``authorize_params`` in ``__init__``, and implement
    the token exchange (``get_token``) and profile lookup (``get_user_info``).
    """

    id: str
    # Environment variables that must all be set (and non-empty) for this
    # provider to be considered configured.
    env: List[str]
    client_id: str
    client_secret: str
    authorize_url: str
    authorize_params: Dict[str, str]
    # Fallback for the OAuth ``prompt`` parameter when no env override exists.
    default_prompt: Optional[str] = None

    def is_configured(self):
        """Return True when every variable named in ``env`` is set and non-empty."""
        return all(os.environ.get(name) for name in self.env)

    async def get_token(self, code: str, url: str) -> str:
        """Exchange the authorization ``code`` for an access token.

        ``url`` is the redirect URI of the authorization request. Subclasses
        must override this.
        """
        raise NotImplementedError

    async def get_user_info(self, token: str) -> Tuple[Dict[str, str], User]:
        """Return the raw provider profile and the corresponding ``User``.

        Subclasses must override this.
        """
        raise NotImplementedError

    def get_env_prefix(self) -> str:
        """Return the environment-variable prefix for this provider, e.g. ``AZURE_AD``."""
        return self.id.upper().replace("-", "_")

    def get_prompt(self) -> Optional[str]:
        """Return the OAuth ``prompt`` parameter to send, if any.

        Lookup order: ``OAUTH_<PREFIX>_PROMPT``, then the global
        ``OAUTH_PROMPT``, then ``default_prompt``. Empty values count as unset.
        """
        for variable in (f"OAUTH_{self.get_env_prefix()}_PROMPT", "OAUTH_PROMPT"):
            value = os.environ.get(variable)
            if value:
                return value
        return self.default_prompt
45
-
46
-
47
class GithubOAuthProvider(OAuthProvider):
    """OAuth integration for GitHub accounts."""

    id = "github"
    env = ["OAUTH_GITHUB_CLIENT_ID", "OAUTH_GITHUB_CLIENT_SECRET"]
    authorize_url = "https://github.com/login/oauth/authorize"

    def __init__(self):
        self.client_id = os.environ.get("OAUTH_GITHUB_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_GITHUB_CLIENT_SECRET")
        # ``user:email`` is requested so get_user_info can read /user/emails.
        self.authorize_params = {"scope": "user:email"}

        prompt = self.get_prompt()
        if prompt:
            self.authorize_params["prompt"] = prompt

    async def get_token(self, code: str, url: str):
        """Exchange ``code`` for a GitHub access token.

        ``url`` is accepted for interface compatibility but is not sent.

        Raises:
            HTTPException: 400 when the response carries no ``access_token``.
        """
        async with httpx.AsyncClient() as client:
            response = await client.post(
                "https://github.com/login/oauth/access_token",
                data={
                    "client_id": self.client_id,
                    "client_secret": self.client_secret,
                    "code": code,
                },
            )
            response.raise_for_status()
            # The body is parsed as a URL-encoded query string.
            fields = urllib.parse.parse_qs(response.text)
            access_token = fields.get("access_token", [""])[0]
            if not access_token:
                raise HTTPException(
                    status_code=400, detail="Failed to get the access token"
                )
            return access_token

    async def get_user_info(self, token: str):
        """Return ``(github_profile, User)`` for ``token``.

        The raw profile is augmented with an ``emails`` key holding the JSON
        returned by ``/user/emails``. The user is identified by login.
        """
        auth_headers = {"Authorization": f"token {token}"}
        async with httpx.AsyncClient() as client:
            profile_response = await client.get(
                "https://api.github.com/user", headers=auth_headers
            )
            profile_response.raise_for_status()
            github_user = profile_response.json()

            emails_response = await client.get(
                "https://api.github.com/user/emails", headers=auth_headers
            )
            emails_response.raise_for_status()
            github_user["emails"] = emails_response.json()

        user = User(
            identifier=github_user["login"],
            metadata={"image": github_user["avatar_url"], "provider": "github"},
        )
        return github_user, user
104
-
105
-
106
class GoogleOAuthProvider(OAuthProvider):
    """OAuth integration for Google accounts."""

    id = "google"
    env = ["OAUTH_GOOGLE_CLIENT_ID", "OAUTH_GOOGLE_CLIENT_SECRET"]
    authorize_url = "https://accounts.google.com/o/oauth2/v2/auth"

    def __init__(self):
        self.client_id = os.environ.get("OAUTH_GOOGLE_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_GOOGLE_CLIENT_SECRET")
        requested_scopes = [
            "https://www.googleapis.com/auth/userinfo.profile",
            "https://www.googleapis.com/auth/userinfo.email",
        ]
        self.authorize_params = {
            "scope": " ".join(requested_scopes),
            "response_type": "code",
            "access_type": "offline",
        }

        prompt = self.get_prompt()
        if prompt:
            self.authorize_params["prompt"] = prompt

    async def get_token(self, code: str, url: str):
        """Exchange ``code`` (issued for redirect URI ``url``) for an access token.

        Raises:
            httpx.HTTPStatusError: on an error status, or when the JSON body
                has no ``access_token``.
        """
        form = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": url,
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                "https://oauth2.googleapis.com/token", data=form
            )
            response.raise_for_status()
            access_token = response.json().get("access_token")
            if access_token:
                return access_token
            raise httpx.HTTPStatusError(
                "Failed to get the access token",
                request=response.request,
                response=response,
            )

    async def get_user_info(self, token: str):
        """Return ``(google_profile, User)``; the user is identified by email."""
        async with httpx.AsyncClient() as client:
            response = await client.get(
                "https://www.googleapis.com/userinfo/v2/me",
                headers={"Authorization": f"Bearer {token}"},
            )
            response.raise_for_status()
            profile = response.json()
        user = User(
            identifier=profile["email"],
            metadata={"image": profile["picture"], "provider": "google"},
        )
        return profile, user
160
-
161
-
162
class AzureADOAuthProvider(OAuthProvider):
    """OAuth integration for Microsoft Entra ID (Azure AD) v2.0 endpoints.

    When ``OAUTH_AZURE_AD_ENABLE_SINGLE_TENANT`` is set, the tenant-specific
    authorize/token URLs are used; otherwise the ``common`` endpoints are.
    """

    id = "azure-ad"
    env = [
        "OAUTH_AZURE_AD_CLIENT_ID",
        "OAUTH_AZURE_AD_CLIENT_SECRET",
        "OAUTH_AZURE_AD_TENANT_ID",
    ]
    authorize_url = (
        f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_TENANT_ID', '')}/oauth2/v2.0/authorize"
        if os.environ.get("OAUTH_AZURE_AD_ENABLE_SINGLE_TENANT")
        else "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
    )
    token_url = (
        f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_TENANT_ID', '')}/oauth2/v2.0/token"
        if os.environ.get("OAUTH_AZURE_AD_ENABLE_SINGLE_TENANT")
        else "https://login.microsoftonline.com/common/oauth2/v2.0/token"
    )

    def __init__(self):
        self.client_id = os.environ.get("OAUTH_AZURE_AD_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_AZURE_AD_CLIENT_SECRET")
        self.authorize_params = {
            "tenant": os.environ.get("OAUTH_AZURE_AD_TENANT_ID"),
            "response_type": "code",
            "scope": "https://graph.microsoft.com/User.Read offline_access",
            "response_mode": "query",
        }
        # Refresh tokens from get_token, keyed by the access token they were
        # issued with. The provider instance is a module-level singleton, so a
        # single attribute would let concurrent logins read each other's token.
        self._refresh_tokens: Dict[str, Optional[str]] = {}

        if prompt := self.get_prompt():
            self.authorize_params["prompt"] = prompt

    async def get_token(self, code: str, url: str):
        """Exchange ``code`` (issued for redirect URI ``url``) for an access token.

        Any refresh token in the response is remembered for the following
        ``get_user_info`` call made with the same access token.

        Raises:
            HTTPException: 400 when the response has no ``access_token``.
        """
        payload = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": url,
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.token_url,
                data=payload,
            )
            response.raise_for_status()
            json_content = response.json()

            # .get() so a missing token reaches the explicit 400 below instead
            # of escaping as a KeyError.
            token = json_content.get("access_token")
            if not token:
                raise HTTPException(
                    status_code=400, detail="Failed to get the access token"
                )
            self._refresh_tokens[token] = json_content.get("refresh_token")
            return token

    async def get_user_info(self, token: str):
        """Return ``(graph_profile, User)`` for ``token``.

        The profile comes from Microsoft Graph ``/me``. A 48x48 photo is added
        under ``image`` as a data URL when it can be fetched (best effort).
        The user is identified by ``userPrincipalName``.
        """
        # Pop first so the stored entry is released even if a request fails.
        refresh_token = self._refresh_tokens.pop(token, None)
        headers = {"Authorization": f"Bearer {token}"}
        async with httpx.AsyncClient() as client:
            response = await client.get(
                "https://graph.microsoft.com/v1.0/me",
                headers=headers,
            )
            response.raise_for_status()

            azure_user = response.json()

            try:
                photo_response = await client.get(
                    "https://graph.microsoft.com/v1.0/me/photos/48x48/$value",
                    headers=headers,
                )
                # Without this check an error reply body would be encoded
                # into a bogus data URL.
                photo_response.raise_for_status()
                photo_data = await photo_response.aread()
                base64_image = base64.b64encode(photo_data).decode("utf-8")
                azure_user["image"] = (
                    f"data:{photo_response.headers['Content-Type']};base64,{base64_image}"
                )
            except Exception:
                # The photo is optional: ignore any failure while fetching it.
                pass

        user = User(
            identifier=azure_user["userPrincipalName"],
            metadata={
                "image": azure_user.get("image"),
                "provider": "azure-ad",
                "refresh_token": refresh_token,
            },
        )
        return (azure_user, user)
251
-
252
-
253
class AzureADHybridOAuthProvider(OAuthProvider):
    """Hybrid-flow (``code id_token``) OAuth integration for Microsoft Entra ID.

    When ``OAUTH_AZURE_AD_HYBRID_ENABLE_SINGLE_TENANT`` is set, the
    tenant-specific authorize/token URLs are used; otherwise the ``common``
    endpoints are. The authorization response is requested via ``form_post``.
    """

    id = "azure-ad-hybrid"
    env = [
        "OAUTH_AZURE_AD_HYBRID_CLIENT_ID",
        "OAUTH_AZURE_AD_HYBRID_CLIENT_SECRET",
        "OAUTH_AZURE_AD_HYBRID_TENANT_ID",
    ]
    authorize_url = (
        f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_HYBRID_TENANT_ID', '')}/oauth2/v2.0/authorize"
        if os.environ.get("OAUTH_AZURE_AD_HYBRID_ENABLE_SINGLE_TENANT")
        else "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
    )
    token_url = (
        f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_HYBRID_TENANT_ID', '')}/oauth2/v2.0/token"
        if os.environ.get("OAUTH_AZURE_AD_HYBRID_ENABLE_SINGLE_TENANT")
        else "https://login.microsoftonline.com/common/oauth2/v2.0/token"
    )

    def __init__(self):
        self.client_id = os.environ.get("OAUTH_AZURE_AD_HYBRID_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_AZURE_AD_HYBRID_CLIENT_SECRET")
        # NOTE(review): the nonce is generated once per instance (i.e. once per
        # process for the module-level singleton) and reused for every
        # authorization request; the returned id_token is not validated here.
        nonce = random_secret(16)
        self.authorize_params = {
            "tenant": os.environ.get("OAUTH_AZURE_AD_HYBRID_TENANT_ID"),
            "response_type": "code id_token",
            "scope": "https://graph.microsoft.com/User.Read https://graph.microsoft.com/openid offline_access",
            "response_mode": "form_post",
            "nonce": nonce,
        }
        # Refresh tokens from get_token, keyed by the access token they were
        # issued with. The provider instance is a module-level singleton, so a
        # single attribute would let concurrent logins read each other's token.
        self._refresh_tokens: Dict[str, Optional[str]] = {}

        if prompt := self.get_prompt():
            self.authorize_params["prompt"] = prompt

    async def get_token(self, code: str, url: str):
        """Exchange ``code`` (issued for redirect URI ``url``) for an access token.

        Any refresh token in the response is remembered for the following
        ``get_user_info`` call made with the same access token.

        Raises:
            HTTPException: 400 when the response has no ``access_token``.
        """
        payload = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": url,
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.token_url,
                data=payload,
            )
            response.raise_for_status()
            json_content = response.json()

            # .get() so a missing token reaches the explicit 400 below instead
            # of escaping as a KeyError.
            token = json_content.get("access_token")
            if not token:
                raise HTTPException(
                    status_code=400, detail="Failed to get the access token"
                )
            self._refresh_tokens[token] = json_content.get("refresh_token")
            return token

    async def get_user_info(self, token: str):
        """Return ``(graph_profile, User)`` for ``token``.

        The profile comes from Microsoft Graph ``/me``. A 48x48 photo is added
        under ``image`` as a data URL when it can be fetched (best effort).
        The user is identified by ``userPrincipalName``.
        """
        # Pop first so the stored entry is released even if a request fails.
        refresh_token = self._refresh_tokens.pop(token, None)
        headers = {"Authorization": f"Bearer {token}"}
        async with httpx.AsyncClient() as client:
            response = await client.get(
                "https://graph.microsoft.com/v1.0/me",
                headers=headers,
            )
            response.raise_for_status()

            azure_user = response.json()

            try:
                photo_response = await client.get(
                    "https://graph.microsoft.com/v1.0/me/photos/48x48/$value",
                    headers=headers,
                )
                # Without this check an error reply body would be encoded
                # into a bogus data URL.
                photo_response.raise_for_status()
                photo_data = await photo_response.aread()
                base64_image = base64.b64encode(photo_data).decode("utf-8")
                azure_user["image"] = (
                    f"data:{photo_response.headers['Content-Type']};base64,{base64_image}"
                )
            except Exception:
                # The photo is optional: ignore any failure while fetching it.
                pass

        user = User(
            identifier=azure_user["userPrincipalName"],
            metadata={
                "image": azure_user.get("image"),
                # NOTE(review): reports "azure-ad" rather than this provider's
                # id ("azure-ad-hybrid"); kept as-is since consumers of the
                # metadata may rely on it.
                "provider": "azure-ad",
                "refresh_token": refresh_token,
            },
        )
        return (azure_user, user)
344
-
345
-
346
class OktaOAuthProvider(OAuthProvider):
    """OAuth/OIDC integration for Okta.

    ``OAUTH_OKTA_AUTHORIZATION_SERVER_ID`` selects the authorization server
    path: unset/empty means ``/default``, the literal ``"false"`` means no
    server segment at all, and any other value is used as the server id.
    """

    id = "okta"
    env = [
        "OAUTH_OKTA_CLIENT_ID",
        "OAUTH_OKTA_CLIENT_SECRET",
        "OAUTH_OKTA_DOMAIN",
    ]
    # A trailing slash on the configured domain is dropped so paths join cleanly.
    domain = f"https://{os.environ.get('OAUTH_OKTA_DOMAIN', '').rstrip('/')}"

    def __init__(self):
        self.client_id = os.environ.get("OAUTH_OKTA_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_OKTA_CLIENT_SECRET")
        self.authorization_server_id = os.environ.get(
            "OAUTH_OKTA_AUTHORIZATION_SERVER_ID", ""
        )
        self.authorize_url = self._endpoint("authorize")
        self.authorize_params = {
            "response_type": "code",
            "scope": "openid profile email",
            "response_mode": "query",
        }

        prompt = self.get_prompt()
        if prompt:
            self.authorize_params["prompt"] = prompt

    def get_authorization_server_path(self):
        """Return the ``/<server>`` path segment ("" when disabled with "false")."""
        server_id = self.authorization_server_id
        if server_id == "false":
            return ""
        return f"/{server_id}" if server_id else "/default"

    def _endpoint(self, name: str) -> str:
        """Build ``<domain>/oauth2<server path>/v1/<name>``."""
        return f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/{name}"

    async def get_token(self, code: str, url: str):
        """Exchange ``code`` (issued for redirect URI ``url``) for an access token.

        Raises:
            httpx.HTTPStatusError: on an error status, or when the JSON body
                has no ``access_token``.
        """
        form = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": url,
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(self._endpoint("token"), data=form)
            response.raise_for_status()
            access_token = response.json().get("access_token")
            if access_token:
                return access_token
            raise httpx.HTTPStatusError(
                "Failed to get the access token",
                request=response.request,
                response=response,
            )

    async def get_user_info(self, token: str):
        """Return ``(okta_claims, User)``; the user is identified by email."""
        async with httpx.AsyncClient() as client:
            response = await client.get(
                self._endpoint("userinfo"),
                headers={"Authorization": f"Bearer {token}"},
            )
            response.raise_for_status()
            claims = response.json()

        user = User(
            identifier=claims.get("email"),
            metadata={"image": "", "provider": "okta"},
        )
        return claims, user
420
-
421
-
422
class Auth0OAuthProvider(OAuthProvider):
    """OAuth/OIDC integration for Auth0.

    ``OAUTH_AUTH0_DOMAIN`` hosts the authorize and token endpoints.
    ``OAUTH_AUTH0_ORIGINAL_DOMAIN``, when set, is used instead for the
    ``audience`` parameter and the ``/userinfo`` endpoint (presumably the
    tenant's canonical domain when a custom domain is in front — confirm).
    """

    id = "auth0"
    env = ["OAUTH_AUTH0_CLIENT_ID", "OAUTH_AUTH0_CLIENT_SECRET", "OAUTH_AUTH0_DOMAIN"]

    def __init__(self):
        self.client_id = os.environ.get("OAUTH_AUTH0_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_AUTH0_CLIENT_SECRET")
        # Trailing slashes are stripped so endpoint paths can be appended.
        self.domain = f"https://{os.environ.get('OAUTH_AUTH0_DOMAIN', '').rstrip('/')}"
        original_domain = os.environ.get("OAUTH_AUTH0_ORIGINAL_DOMAIN")
        if original_domain:
            self.original_domain = f"https://{original_domain.rstrip('/')}"
        else:
            self.original_domain = self.domain

        self.authorize_url = f"{self.domain}/authorize"
        self.authorize_params = {
            "response_type": "code",
            "scope": "openid profile email",
            "audience": f"{self.original_domain}/userinfo",
        }

        prompt = self.get_prompt()
        if prompt:
            self.authorize_params["prompt"] = prompt

    async def get_token(self, code: str, url: str):
        """Exchange ``code`` (issued for redirect URI ``url``) for an access token.

        Raises:
            HTTPException: 400 when the response has no ``access_token``.
        """
        form = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": url,
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(f"{self.domain}/oauth/token", data=form)
            response.raise_for_status()
            access_token = response.json().get("access_token")
            if not access_token:
                raise HTTPException(
                    status_code=400, detail="Failed to get the access token"
                )
            return access_token

    async def get_user_info(self, token: str):
        """Return ``(auth0_profile, User)``; the user is identified by email."""
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{self.original_domain}/userinfo",
                headers={"Authorization": f"Bearer {token}"},
            )
            response.raise_for_status()
            profile = response.json()
            user = User(
                identifier=profile.get("email"),
                metadata={
                    "image": profile.get("picture", ""),
                    "provider": "auth0",
                },
            )
            return profile, user
486
-
487
-
488
class DescopeOAuthProvider(OAuthProvider):
    """OAuth/OIDC integration for Descope's hosted OAuth2 API."""

    id = "descope"
    env = ["OAUTH_DESCOPE_CLIENT_ID", "OAUTH_DESCOPE_CLIENT_SECRET"]
    # Fixed base URL of the OAuth2 endpoints; written without a trailing slash
    # so endpoint paths can be appended directly.
    domain = "https://api.descope.com/oauth2/v1"

    authorize_url = f"{domain}/authorize"

    def __init__(self):
        self.client_id = os.environ.get("OAUTH_DESCOPE_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_DESCOPE_CLIENT_SECRET")
        self.authorize_params = {
            "response_type": "code",
            "scope": "openid profile email",
            "audience": f"{self.domain}/userinfo",
        }

        prompt = self.get_prompt()
        if prompt:
            self.authorize_params["prompt"] = prompt

    async def get_token(self, code: str, url: str):
        """Exchange ``code`` (issued for redirect URI ``url``) for an access token.

        Raises:
            httpx.HTTPStatusError: on an error status, or when the JSON body
                has no ``access_token``.
        """
        form = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": url,
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(f"{self.domain}/token", data=form)
            response.raise_for_status()
            access_token = response.json().get("access_token")
            if access_token:
                return access_token
            raise httpx.HTTPStatusError(
                "Failed to get the access token",
                request=response.request,
                response=response,
            )

    async def get_user_info(self, token: str):
        """Return ``(descope_claims, User)``; the user is identified by email."""
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{self.domain}/userinfo",
                headers={"Authorization": f"Bearer {token}"},
            )
            # Error statuses (4xx/5xx) abort the login here.
            response.raise_for_status()
            claims = response.json()

        user = User(
            identifier=claims.get("email"),
            metadata={"image": "", "provider": "descope"},
        )
        return claims, user
545
-
546
-
547
class AWSCognitoOAuthProvider(OAuthProvider):
    """OAuth/OIDC integration for an Amazon Cognito user-pool domain.

    Scopes default to ``openid profile email`` and can be overridden with
    ``OAUTH_COGNITO_SCOPE``.
    """

    id = "aws-cognito"
    env = [
        "OAUTH_COGNITO_CLIENT_ID",
        "OAUTH_COGNITO_CLIENT_SECRET",
        "OAUTH_COGNITO_DOMAIN",
    ]
    # Normalised like the Okta/Auth0/GitLab domains: without rstrip a trailing
    # slash in OAUTH_COGNITO_DOMAIN would put "//" into every endpoint URL, and
    # without the '' default an unset variable would embed the text "None".
    domain = f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN', '').rstrip('/')}"
    authorize_url = f"{domain}/login"
    token_url = f"{domain}/oauth2/token"

    def __init__(self):
        self.client_id = os.environ.get("OAUTH_COGNITO_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_COGNITO_CLIENT_SECRET")
        self.scopes = os.environ.get("OAUTH_COGNITO_SCOPE", "openid profile email")
        self.authorize_params = {
            "response_type": "code",
            "client_id": self.client_id,
            "scope": self.scopes,
        }

        if prompt := self.get_prompt():
            self.authorize_params["prompt"] = prompt

    async def get_token(self, code: str, url: str):
        """Exchange ``code`` (issued for redirect URI ``url``) for an access token.

        Raises:
            HTTPException: 400 when the response has no ``access_token``.
        """
        payload = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": url,
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.token_url,
                data=payload,
            )
            response.raise_for_status()
            json_content = response.json()

            token = json_content.get("access_token")
            if not token:
                raise HTTPException(
                    status_code=400, detail="Failed to get the access token"
                )
            return token

    async def get_user_info(self, token: str):
        """Return ``(cognito_claims, User)``; the user is identified by email.

        Raises KeyError when the claims have no ``email`` (presumably when a
        custom ``OAUTH_COGNITO_SCOPE`` omits the ``email`` scope — confirm).
        """
        # Same normalised domain as the authorize/token URLs.
        user_info_url = f"{self.domain}/oauth2/userInfo"
        async with httpx.AsyncClient() as client:
            response = await client.get(
                user_info_url,
                headers={"Authorization": f"Bearer {token}"},
            )
            response.raise_for_status()

            cognito_user = response.json()

            user = User(
                identifier=cognito_user["email"],
                metadata={
                    "image": cognito_user.get("picture", ""),
                    "provider": "aws-cognito",
                },
            )
            return (cognito_user, user)
615
-
616
-
617
class GitlabOAuthProvider(OAuthProvider):
    """OAuth/OIDC integration for a GitLab instance (``OAUTH_GITLAB_DOMAIN``)."""

    id = "gitlab"
    env = [
        "OAUTH_GITLAB_CLIENT_ID",
        "OAUTH_GITLAB_CLIENT_SECRET",
        "OAUTH_GITLAB_DOMAIN",
    ]

    def __init__(self):
        self.client_id = os.environ.get("OAUTH_GITLAB_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_GITLAB_CLIENT_SECRET")
        # Trailing slashes are stripped so endpoint paths can be appended.
        gitlab_host = os.environ.get("OAUTH_GITLAB_DOMAIN", "").rstrip("/")
        self.domain = f"https://{gitlab_host}"

        self.authorize_url = f"{self.domain}/oauth/authorize"
        self.authorize_params = {
            "scope": "openid profile email",
            "response_type": "code",
        }

        prompt = self.get_prompt()
        if prompt:
            self.authorize_params["prompt"] = prompt

    async def get_token(self, code: str, url: str):
        """Exchange ``code`` (issued for redirect URI ``url``) for an access token.

        Raises:
            HTTPException: 400 when the response has no ``access_token``.
        """
        form = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": url,
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(f"{self.domain}/oauth/token", data=form)
            response.raise_for_status()
            access_token = response.json().get("access_token")
            if not access_token:
                raise HTTPException(
                    status_code=400, detail="Failed to get the access token"
                )
            return access_token

    async def get_user_info(self, token: str):
        """Return ``(gitlab_claims, User)``; the user is identified by email."""
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{self.domain}/oauth/userinfo",
                headers={"Authorization": f"Bearer {token}"},
            )
            response.raise_for_status()
            claims = response.json()
            user = User(
                identifier=claims.get("email"),
                metadata={
                    "image": claims.get("picture", ""),
                    "provider": "gitlab",
                },
            )
            return claims, user
679
-
680
-
681
class KeycloakOAuthProvider(OAuthProvider):
    """OAuth/OIDC integration for a Keycloak realm.

    The provider id defaults to ``keycloak`` and can be renamed with
    ``OAUTH_KEYCLOAK_NAME``.
    """

    env = [
        "OAUTH_KEYCLOAK_CLIENT_ID",
        "OAUTH_KEYCLOAK_CLIENT_SECRET",
        "OAUTH_KEYCLOAK_REALM",
        "OAUTH_KEYCLOAK_BASE_URL",
    ]
    id = os.environ.get("OAUTH_KEYCLOAK_NAME", "keycloak")

    def __init__(self):
        self.client_id = os.environ.get("OAUTH_KEYCLOAK_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_KEYCLOAK_CLIENT_SECRET")
        self.realm = os.environ.get("OAUTH_KEYCLOAK_REALM")
        self.base_url = os.environ.get("OAUTH_KEYCLOAK_BASE_URL")
        self.authorize_url = self._oidc_url("auth")

        self.authorize_params = {
            "scope": "profile email openid",
            "response_type": "code",
        }

        prompt = self.get_prompt()
        if prompt:
            self.authorize_params["prompt"] = prompt

    def _oidc_url(self, endpoint: str) -> str:
        """Build ``<base_url>/realms/<realm>/protocol/openid-connect/<endpoint>``."""
        return f"{self.base_url}/realms/{self.realm}/protocol/openid-connect/{endpoint}"

    async def get_token(self, code: str, url: str):
        """Exchange ``code`` (issued for redirect URI ``url``) for an access token.

        Raises:
            httpx.HTTPStatusError: on an error status, or when the JSON body
                has no ``access_token``.
        """
        form = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": url,
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(self._oidc_url("token"), data=form)
            response.raise_for_status()
            access_token = response.json().get("access_token")
            if access_token:
                return access_token
            raise httpx.HTTPStatusError(
                "Failed to get the access token",
                request=response.request,
                response=response,
            )

    async def get_user_info(self, token: str):
        """Return ``(keycloak_claims, User)``; the user is identified by email."""
        async with httpx.AsyncClient() as client:
            response = await client.get(
                self._oidc_url("userinfo"),
                headers={"Authorization": f"Bearer {token}"},
            )
            response.raise_for_status()
            claims = response.json()
            # NOTE(review): metadata always says "keycloak", even when the id
            # was renamed via OAUTH_KEYCLOAK_NAME (GenericOAuthProvider uses
            # self.id instead); left unchanged for compatibility.
            user = User(
                identifier=claims["email"],
                metadata={"provider": "keycloak"},
            )
            return claims, user
744
-
745
-
746
class GenericOAuthProvider(OAuthProvider):
    """Configurable OAuth 2.0 provider driven entirely by environment variables.

    Endpoints and scopes come from the ``OAUTH_GENERIC_*`` variables. The user
    identifier is read from the userinfo field named by
    ``OAUTH_GENERIC_USER_IDENTIFIER`` (default ``email``), and the provider id
    from ``OAUTH_GENERIC_NAME`` (default ``generic``).
    """

    env = [
        "OAUTH_GENERIC_CLIENT_ID",
        "OAUTH_GENERIC_CLIENT_SECRET",
        "OAUTH_GENERIC_AUTH_URL",
        "OAUTH_GENERIC_TOKEN_URL",
        "OAUTH_GENERIC_USER_INFO_URL",
        "OAUTH_GENERIC_SCOPES",
    ]
    id = os.environ.get("OAUTH_GENERIC_NAME", "generic")

    def __init__(self):
        getenv = os.environ.get
        self.client_id = getenv("OAUTH_GENERIC_CLIENT_ID")
        self.client_secret = getenv("OAUTH_GENERIC_CLIENT_SECRET")
        self.authorize_url = getenv("OAUTH_GENERIC_AUTH_URL")
        self.token_url = getenv("OAUTH_GENERIC_TOKEN_URL")
        self.user_info_url = getenv("OAUTH_GENERIC_USER_INFO_URL")
        self.scopes = getenv("OAUTH_GENERIC_SCOPES")
        self.user_identifier = getenv("OAUTH_GENERIC_USER_IDENTIFIER", "email")

        self.authorize_params = {"scope": self.scopes, "response_type": "code"}

        prompt = self.get_prompt()
        if prompt:
            self.authorize_params["prompt"] = prompt

    async def get_token(self, code: str, url: str):
        """Exchange ``code`` (issued for redirect URI ``url``) for an access token.

        Raises:
            httpx.HTTPStatusError: on an error status, or when the JSON body
                has no ``access_token``.
        """
        form = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": url,
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(self.token_url, data=form)
            response.raise_for_status()
            access_token = response.json().get("access_token")
            if access_token:
                return access_token
            raise httpx.HTTPStatusError(
                "Failed to get the access token",
                request=response.request,
                response=response,
            )

    async def get_user_info(self, token: str):
        """Return ``(server_profile, User)``.

        The ``User`` identifier is the profile field named by
        ``user_identifier`` (None when the field is absent).
        """
        async with httpx.AsyncClient() as client:
            response = await client.get(
                self.user_info_url,
                headers={"Authorization": f"Bearer {token}"},
            )
            response.raise_for_status()
            profile = response.json()
            user = User(
                identifier=profile.get(self.user_identifier),
                metadata={"provider": self.id},
            )
            return profile, user
810
-
811
-
812
# One shared instance of every supported provider, in lookup order. They are
# created at import time, so the environment values each __init__ reads are
# captured at that point.
providers = [
    provider_cls()
    for provider_cls in (
        GithubOAuthProvider,
        GoogleOAuthProvider,
        AzureADOAuthProvider,
        AzureADHybridOAuthProvider,
        OktaOAuthProvider,
        Auth0OAuthProvider,
        DescopeOAuthProvider,
        AWSCognitoOAuthProvider,
        GitlabOAuthProvider,
        KeycloakOAuthProvider,
        GenericOAuthProvider,
    )
]
825
-
826
-
827
def get_oauth_provider(provider: str) -> Optional[OAuthProvider]:
    """Return the registered provider whose ``id`` equals ``provider``.

    The first match in ``providers`` wins. Returns None when no provider has
    that id.
    """
    return next(
        (candidate for candidate in providers if candidate.id == provider), None
    )
832
-
833
-
834
def get_configured_oauth_providers() -> List[str]:
    """Return, in registry order, the ids of providers whose env variables are all set."""
    return [candidate.id for candidate in providers if candidate.is_configured()]