chainlit 2.7.0__py3-none-any.whl → 2.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chainlit might be problematic. Click here for more details.
- {chainlit-2.7.0.dist-info → chainlit-2.7.1.dist-info}/METADATA +1 -1
- chainlit-2.7.1.dist-info/RECORD +4 -0
- chainlit/__init__.py +0 -207
- chainlit/__main__.py +0 -4
- chainlit/_utils.py +0 -8
- chainlit/action.py +0 -33
- chainlit/auth/__init__.py +0 -95
- chainlit/auth/cookie.py +0 -197
- chainlit/auth/jwt.py +0 -42
- chainlit/cache.py +0 -45
- chainlit/callbacks.py +0 -433
- chainlit/chat_context.py +0 -64
- chainlit/chat_settings.py +0 -34
- chainlit/cli/__init__.py +0 -235
- chainlit/config.py +0 -621
- chainlit/context.py +0 -112
- chainlit/data/__init__.py +0 -111
- chainlit/data/acl.py +0 -19
- chainlit/data/base.py +0 -107
- chainlit/data/chainlit_data_layer.py +0 -687
- chainlit/data/dynamodb.py +0 -616
- chainlit/data/literalai.py +0 -501
- chainlit/data/sql_alchemy.py +0 -741
- chainlit/data/storage_clients/__init__.py +0 -0
- chainlit/data/storage_clients/azure.py +0 -84
- chainlit/data/storage_clients/azure_blob.py +0 -94
- chainlit/data/storage_clients/base.py +0 -28
- chainlit/data/storage_clients/gcs.py +0 -101
- chainlit/data/storage_clients/s3.py +0 -88
- chainlit/data/utils.py +0 -29
- chainlit/discord/__init__.py +0 -6
- chainlit/discord/app.py +0 -364
- chainlit/element.py +0 -454
- chainlit/emitter.py +0 -450
- chainlit/hello.py +0 -12
- chainlit/input_widget.py +0 -182
- chainlit/langchain/__init__.py +0 -6
- chainlit/langchain/callbacks.py +0 -682
- chainlit/langflow/__init__.py +0 -25
- chainlit/llama_index/__init__.py +0 -6
- chainlit/llama_index/callbacks.py +0 -206
- chainlit/logger.py +0 -16
- chainlit/markdown.py +0 -57
- chainlit/mcp.py +0 -99
- chainlit/message.py +0 -619
- chainlit/mistralai/__init__.py +0 -50
- chainlit/oauth_providers.py +0 -835
- chainlit/openai/__init__.py +0 -53
- chainlit/py.typed +0 -0
- chainlit/secret.py +0 -9
- chainlit/semantic_kernel/__init__.py +0 -111
- chainlit/server.py +0 -1616
- chainlit/session.py +0 -304
- chainlit/sidebar.py +0 -55
- chainlit/slack/__init__.py +0 -6
- chainlit/slack/app.py +0 -427
- chainlit/socket.py +0 -381
- chainlit/step.py +0 -490
- chainlit/sync.py +0 -43
- chainlit/teams/__init__.py +0 -6
- chainlit/teams/app.py +0 -348
- chainlit/translations/bn.json +0 -214
- chainlit/translations/el-GR.json +0 -214
- chainlit/translations/en-US.json +0 -214
- chainlit/translations/fr-FR.json +0 -214
- chainlit/translations/gu.json +0 -214
- chainlit/translations/he-IL.json +0 -214
- chainlit/translations/hi.json +0 -214
- chainlit/translations/ja.json +0 -214
- chainlit/translations/kn.json +0 -214
- chainlit/translations/ml.json +0 -214
- chainlit/translations/mr.json +0 -214
- chainlit/translations/nl.json +0 -214
- chainlit/translations/ta.json +0 -214
- chainlit/translations/te.json +0 -214
- chainlit/translations/zh-CN.json +0 -214
- chainlit/translations.py +0 -60
- chainlit/types.py +0 -334
- chainlit/user.py +0 -43
- chainlit/user_session.py +0 -153
- chainlit/utils.py +0 -173
- chainlit/version.py +0 -8
- chainlit-2.7.0.dist-info/RECORD +0 -84
- {chainlit-2.7.0.dist-info → chainlit-2.7.1.dist-info}/WHEEL +0 -0
- {chainlit-2.7.0.dist-info → chainlit-2.7.1.dist-info}/entry_points.txt +0 -0
chainlit/oauth_providers.py
DELETED
|
@@ -1,835 +0,0 @@
|
|
|
1
|
-
import base64
|
|
2
|
-
import os
|
|
3
|
-
import urllib.parse
|
|
4
|
-
from typing import Dict, List, Optional, Tuple
|
|
5
|
-
|
|
6
|
-
import httpx
|
|
7
|
-
from fastapi import HTTPException
|
|
8
|
-
|
|
9
|
-
from chainlit.secret import random_secret
|
|
10
|
-
from chainlit.user import User
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class OAuthProvider:
    """Base class describing an OAuth 2.0 identity provider.

    Subclasses set ``id`` (the provider slug), ``env`` (environment variables
    that must all be non-empty for the provider to count as configured), the
    client credentials, ``authorize_url`` and ``authorize_params``, and
    implement :meth:`get_token` and :meth:`get_user_info`.
    """

    id: str
    env: List[str]
    client_id: str
    client_secret: str
    authorize_url: str
    authorize_params: Dict[str, str]
    # Fallback value for the ``prompt`` authorize parameter when no
    # environment override is present.
    default_prompt: Optional[str] = None

    def is_configured(self):
        """Return True when every variable named in ``env`` is set and non-empty."""
        return all(os.environ.get(name) for name in self.env)

    async def get_token(self, code: str, url: str) -> str:
        """Exchange an authorization ``code`` for an access token.

        ``url`` is the redirect URI used for the authorization request.
        Subclasses must override this.
        """
        raise NotImplementedError

    async def get_user_info(self, token: str) -> Tuple[Dict[str, str], User]:
        """Return the provider's raw profile and the mapped :class:`User`.

        Subclasses must override this.
        """
        raise NotImplementedError

    def get_env_prefix(self) -> str:
        """Return the environment-variable prefix derived from ``id``.

        Dashes become underscores and the result is upper-cased, e.g.
        ``"azure-ad"`` -> ``"AZURE_AD"``.
        """
        return self.id.upper().replace("-", "_")

    def get_prompt(self) -> Optional[str]:
        """Return the OAuth ``prompt`` parameter to send, if any.

        Checks ``OAUTH_<PREFIX>_PROMPT`` first, then ``OAUTH_PROMPT``, and
        finally falls back to ``default_prompt``. Empty values count as unset.
        """
        lookup_order = (f"OAUTH_{self.get_env_prefix()}_PROMPT", "OAUTH_PROMPT")
        for var_name in lookup_order:
            value = os.environ.get(var_name)
            if value:
                return value
        return self.default_prompt
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
class GithubOAuthProvider(OAuthProvider):
    """GitHub OAuth App provider configured via ``OAUTH_GITHUB_*`` variables."""

    id = "github"
    env = ["OAUTH_GITHUB_CLIENT_ID", "OAUTH_GITHUB_CLIENT_SECRET"]
    authorize_url = "https://github.com/login/oauth/authorize"

    def __init__(self):
        self.client_id = os.environ.get("OAUTH_GITHUB_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_GITHUB_CLIENT_SECRET")
        # "user:email" is needed for the /user/emails call in get_user_info.
        self.authorize_params = {"scope": "user:email"}

        prompt = self.get_prompt()
        if prompt:
            self.authorize_params["prompt"] = prompt

    async def get_token(self, code: str, url: str):
        """Exchange ``code`` for an access token.

        ``url`` is accepted for interface compatibility but is not sent. The
        token endpoint's body is parsed as a URL-encoded query string.

        Raises:
            HTTPException: 400 when the response contains no ``access_token``.
        """
        form = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            code=code,
        )
        async with httpx.AsyncClient() as client:
            resp = await client.post(
                "https://github.com/login/oauth/access_token", data=form
            )
            resp.raise_for_status()
            parsed = urllib.parse.parse_qs(resp.text)
            access_token = parsed.get("access_token", [""])[0]
            if access_token:
                return access_token
            raise HTTPException(
                status_code=400, detail="Failed to get the access token"
            )

    async def get_user_info(self, token: str):
        """Return ``(github_profile, User)`` for the token's owner.

        The ``/user`` profile is extended with an ``"emails"`` key holding the
        JSON returned by ``/user/emails``. The ``User`` identifier is the
        GitHub login.
        """
        auth_headers = {"Authorization": f"token {token}"}
        async with httpx.AsyncClient() as client:
            profile_resp = await client.get(
                "https://api.github.com/user", headers=auth_headers
            )
            profile_resp.raise_for_status()
            github_user = profile_resp.json()

            emails_resp = await client.get(
                "https://api.github.com/user/emails", headers=auth_headers
            )
            emails_resp.raise_for_status()
            github_user["emails"] = emails_resp.json()

            user = User(
                identifier=github_user["login"],
                metadata={"image": github_user["avatar_url"], "provider": "github"},
            )
            return (github_user, user)
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
class GoogleOAuthProvider(OAuthProvider):
    """Google OAuth 2.0 provider configured via ``OAUTH_GOOGLE_*`` variables."""

    id = "google"
    env = ["OAUTH_GOOGLE_CLIENT_ID", "OAUTH_GOOGLE_CLIENT_SECRET"]
    authorize_url = "https://accounts.google.com/o/oauth2/v2/auth"

    def __init__(self):
        self.client_id = os.environ.get("OAUTH_GOOGLE_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_GOOGLE_CLIENT_SECRET")
        self.authorize_params = {
            "scope": "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email",
            "response_type": "code",
            "access_type": "offline",
        }

        if prompt := self.get_prompt():
            self.authorize_params["prompt"] = prompt

    async def get_token(self, code: str, url: str):
        """Exchange an authorization ``code`` for an access token.

        Args:
            code: Authorization code received on the redirect.
            url: Redirect URI used for the authorization request.

        Raises:
            httpx.HTTPStatusError: On a non-2xx response, or when the response
                body contains no ``access_token``.
        """
        payload = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": url,
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                "https://oauth2.googleapis.com/token",
                data=payload,
            )
            response.raise_for_status()
            # Named token_data rather than ``json`` to avoid shadowing the module name.
            token_data = response.json()
            token = token_data.get("access_token")
            if not token:
                raise httpx.HTTPStatusError(
                    "Failed to get the access token",
                    request=response.request,
                    response=response,
                )
            return token

    async def get_user_info(self, token: str):
        """Return ``(google_profile, User)``; the identifier is the account email.

        ``picture`` is read with ``.get`` so a profile without it no longer
        aborts the login with a KeyError. ``email`` is still required.
        """
        async with httpx.AsyncClient() as client:
            response = await client.get(
                "https://www.googleapis.com/userinfo/v2/me",
                headers={"Authorization": f"Bearer {token}"},
            )
            response.raise_for_status()
            google_user = response.json()
            user = User(
                identifier=google_user["email"],
                metadata={
                    "image": google_user.get("picture", ""),
                    "provider": "google",
                },
            )
            return (google_user, user)
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
class AzureADOAuthProvider(OAuthProvider):
    """Microsoft Entra ID (Azure AD) provider using the authorization-code flow.

    Both endpoint URLs are resolved once, when the class body executes. With
    ``OAUTH_AZURE_AD_ENABLE_SINGLE_TENANT`` set they target the tenant in
    ``OAUTH_AZURE_AD_TENANT_ID``; otherwise the ``common`` endpoints are used.
    """

    id = "azure-ad"
    env = [
        "OAUTH_AZURE_AD_CLIENT_ID",
        "OAUTH_AZURE_AD_CLIENT_SECRET",
        "OAUTH_AZURE_AD_TENANT_ID",
    ]
    authorize_url = (
        f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_TENANT_ID', '')}/oauth2/v2.0/authorize"
        if os.environ.get("OAUTH_AZURE_AD_ENABLE_SINGLE_TENANT")
        else "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
    )
    token_url = (
        f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_TENANT_ID', '')}/oauth2/v2.0/token"
        if os.environ.get("OAUTH_AZURE_AD_ENABLE_SINGLE_TENANT")
        else "https://login.microsoftonline.com/common/oauth2/v2.0/token"
    )

    def __init__(self):
        self.client_id = os.environ.get("OAUTH_AZURE_AD_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_AZURE_AD_CLIENT_SECRET")
        self.authorize_params = {
            "tenant": os.environ.get("OAUTH_AZURE_AD_TENANT_ID"),
            "response_type": "code",
            "scope": "https://graph.microsoft.com/User.Read offline_access",
            "response_mode": "query",
        }

        if prompt := self.get_prompt():
            self.authorize_params["prompt"] = prompt

        # Refresh tokens keyed by the access token they were issued with.
        # One provider instance serves every login (the module builds a single
        # instance per provider). A plain ``self._refresh_token`` attribute
        # could therefore be overwritten by a concurrent login between
        # get_token and get_user_info, leaking one user's refresh token into
        # another user's metadata. get_user_info removes its entry.
        self._refresh_tokens: Dict[str, Optional[str]] = {}

    async def get_token(self, code: str, url: str):
        """Exchange an authorization ``code`` for an access token.

        The accompanying refresh token (if any) is remembered for the matching
        get_user_info call.

        Raises:
            HTTPException: 400 when the response contains no ``access_token``.
        """
        payload = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": url,
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.token_url,
                data=payload,
            )
            response.raise_for_status()
            token_data = response.json()

            # .get() so a missing key reaches the intended 400 instead of a KeyError.
            token = token_data.get("access_token")
            if not token:
                raise HTTPException(
                    status_code=400, detail="Failed to get the access token"
                )
            self._refresh_tokens[token] = token_data.get("refresh_token")
            return token

    async def get_user_info(self, token: str):
        """Return ``(graph_profile, User)`` from Microsoft Graph ``/me``.

        The identifier is ``userPrincipalName``. A 48x48 profile photo is
        attached as a ``data:`` URI when it can be fetched. The refresh token
        stored by get_token for this access token is put in the metadata.
        """
        # Take this login's refresh token first so the entry is released even if
        # a request below fails.
        refresh_token = self._refresh_tokens.pop(token, None)
        async with httpx.AsyncClient() as client:
            response = await client.get(
                "https://graph.microsoft.com/v1.0/me",
                headers={"Authorization": f"Bearer {token}"},
            )
            response.raise_for_status()

            azure_user = response.json()

            try:
                photo_response = await client.get(
                    "https://graph.microsoft.com/v1.0/me/photos/48x48/$value",
                    headers={"Authorization": f"Bearer {token}"},
                )
                # Graph answers with an error status when there is no photo.
                # Without this check the error body would be encoded as the image.
                photo_response.raise_for_status()
                photo_data = await photo_response.aread()
                base64_image = base64.b64encode(photo_data)
                azure_user["image"] = (
                    f"data:{photo_response.headers['Content-Type']};base64,{base64_image.decode('utf-8')}"
                )
            except Exception:
                # The photo is optional: any failure just leaves "image" unset.
                pass

            user = User(
                identifier=azure_user["userPrincipalName"],
                metadata={
                    "image": azure_user.get("image"),
                    "provider": "azure-ad",
                    "refresh_token": refresh_token,
                },
            )
            return (azure_user, user)
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
class AzureADHybridOAuthProvider(OAuthProvider):
    """Microsoft Entra ID (Azure AD) provider using the hybrid flow.

    Requests ``code id_token`` with ``response_mode=form_post``. Both endpoint
    URLs are resolved once, when the class body executes. With
    ``OAUTH_AZURE_AD_HYBRID_ENABLE_SINGLE_TENANT`` set they target the tenant
    in ``OAUTH_AZURE_AD_HYBRID_TENANT_ID``; otherwise the ``common`` endpoints
    are used.
    """

    id = "azure-ad-hybrid"
    env = [
        "OAUTH_AZURE_AD_HYBRID_CLIENT_ID",
        "OAUTH_AZURE_AD_HYBRID_CLIENT_SECRET",
        "OAUTH_AZURE_AD_HYBRID_TENANT_ID",
    ]
    authorize_url = (
        f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_HYBRID_TENANT_ID', '')}/oauth2/v2.0/authorize"
        if os.environ.get("OAUTH_AZURE_AD_HYBRID_ENABLE_SINGLE_TENANT")
        else "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
    )
    token_url = (
        f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_HYBRID_TENANT_ID', '')}/oauth2/v2.0/token"
        if os.environ.get("OAUTH_AZURE_AD_HYBRID_ENABLE_SINGLE_TENANT")
        else "https://login.microsoftonline.com/common/oauth2/v2.0/token"
    )

    def __init__(self):
        self.client_id = os.environ.get("OAUTH_AZURE_AD_HYBRID_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_AZURE_AD_HYBRID_CLIENT_SECRET")
        # NOTE(review): this nonce is generated once per provider instance and
        # then reused for every authorization request. Nothing in this class
        # checks the returned id_token against it — confirm whether that happens
        # elsewhere.
        nonce = random_secret(16)
        self.authorize_params = {
            "tenant": os.environ.get("OAUTH_AZURE_AD_HYBRID_TENANT_ID"),
            "response_type": "code id_token",
            "scope": "https://graph.microsoft.com/User.Read https://graph.microsoft.com/openid offline_access",
            "response_mode": "form_post",
            "nonce": nonce,
        }

        if prompt := self.get_prompt():
            self.authorize_params["prompt"] = prompt

        # Refresh tokens keyed by the access token they were issued with.
        # One provider instance serves every login (the module builds a single
        # instance per provider). A plain ``self._refresh_token`` attribute
        # could therefore be overwritten by a concurrent login between
        # get_token and get_user_info, leaking one user's refresh token into
        # another user's metadata. get_user_info removes its entry.
        self._refresh_tokens: Dict[str, Optional[str]] = {}

    async def get_token(self, code: str, url: str):
        """Exchange an authorization ``code`` for an access token.

        The accompanying refresh token (if any) is remembered for the matching
        get_user_info call.

        Raises:
            HTTPException: 400 when the response contains no ``access_token``.
        """
        payload = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": url,
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.token_url,
                data=payload,
            )
            response.raise_for_status()
            token_data = response.json()

            # .get() so a missing key reaches the intended 400 instead of a KeyError.
            token = token_data.get("access_token")
            if not token:
                raise HTTPException(
                    status_code=400, detail="Failed to get the access token"
                )
            self._refresh_tokens[token] = token_data.get("refresh_token")
            return token

    async def get_user_info(self, token: str):
        """Return ``(graph_profile, User)`` from Microsoft Graph ``/me``.

        The identifier is ``userPrincipalName``. A 48x48 profile photo is
        attached as a ``data:`` URI when it can be fetched. The refresh token
        stored by get_token for this access token is put in the metadata.
        """
        # Take this login's refresh token first so the entry is released even if
        # a request below fails.
        refresh_token = self._refresh_tokens.pop(token, None)
        async with httpx.AsyncClient() as client:
            response = await client.get(
                "https://graph.microsoft.com/v1.0/me",
                headers={"Authorization": f"Bearer {token}"},
            )
            response.raise_for_status()

            azure_user = response.json()

            try:
                photo_response = await client.get(
                    "https://graph.microsoft.com/v1.0/me/photos/48x48/$value",
                    headers={"Authorization": f"Bearer {token}"},
                )
                # Graph answers with an error status when there is no photo.
                # Without this check the error body would be encoded as the image.
                photo_response.raise_for_status()
                photo_data = await photo_response.aread()
                base64_image = base64.b64encode(photo_data)
                azure_user["image"] = (
                    f"data:{photo_response.headers['Content-Type']};base64,{base64_image.decode('utf-8')}"
                )
            except Exception:
                # The photo is optional: any failure just leaves "image" unset.
                pass

            user = User(
                identifier=azure_user["userPrincipalName"],
                metadata={
                    "image": azure_user.get("image"),
                    # NOTE(review): reports "azure-ad", not self.id
                    # ("azure-ad-hybrid"). Kept as-is because stored user
                    # metadata may depend on it — confirm this is intentional.
                    "provider": "azure-ad",
                    "refresh_token": refresh_token,
                },
            )
            return (azure_user, user)
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
class OktaOAuthProvider(OAuthProvider):
    """Okta provider; endpoints live under ``<domain>/oauth2<server path>/v1/``.

    ``OAUTH_OKTA_AUTHORIZATION_SERVER_ID`` selects the server path segment:
    unset or empty gives ``/default``, the literal ``"false"`` gives no segment,
    and any other value is used verbatim.
    """

    id = "okta"
    env = [
        "OAUTH_OKTA_CLIENT_ID",
        "OAUTH_OKTA_CLIENT_SECRET",
        "OAUTH_OKTA_DOMAIN",
    ]
    # Resolved once at class definition; a trailing slash is stripped so the
    # paths appended below don't produce "//".
    domain = f"https://{os.environ.get('OAUTH_OKTA_DOMAIN', '').rstrip('/')}"

    def __init__(self):
        self.client_id = os.environ.get("OAUTH_OKTA_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_OKTA_CLIENT_SECRET")
        self.authorization_server_id = os.environ.get(
            "OAUTH_OKTA_AUTHORIZATION_SERVER_ID", ""
        )
        # Must come after authorization_server_id is assigned.
        self.authorize_url = self._endpoint("authorize")
        self.authorize_params = dict(
            response_type="code",
            scope="openid profile email",
            response_mode="query",
        )

        prompt = self.get_prompt()
        if prompt:
            self.authorize_params["prompt"] = prompt

    def get_authorization_server_path(self):
        """Return the authorization-server path segment (see class docstring)."""
        server_id = self.authorization_server_id
        if server_id == "false":
            return ""
        if server_id:
            return f"/{server_id}"
        return "/default"

    def _endpoint(self, name: str) -> str:
        """Return ``<domain>/oauth2<server path>/v1/<name>``."""
        return f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/{name}"

    async def get_token(self, code: str, url: str):
        """Exchange an authorization ``code`` for an access token.

        Raises:
            httpx.HTTPStatusError: On a non-2xx response, or when no
                ``access_token`` is returned.
        """
        form = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            code=code,
            grant_type="authorization_code",
            redirect_uri=url,
        )
        async with httpx.AsyncClient() as client:
            resp = await client.post(self._endpoint("token"), data=form)
            resp.raise_for_status()
            access_token = resp.json().get("access_token")
            if access_token:
                return access_token
            raise httpx.HTTPStatusError(
                "Failed to get the access token",
                request=resp.request,
                response=resp,
            )

    async def get_user_info(self, token: str):
        """Return ``(okta_userinfo, User)``; the identifier is the ``email`` claim."""
        async with httpx.AsyncClient() as client:
            resp = await client.get(
                self._endpoint("userinfo"),
                headers={"Authorization": f"Bearer {token}"},
            )
            resp.raise_for_status()
            okta_user = resp.json()

            return (
                okta_user,
                User(
                    identifier=okta_user.get("email"),
                    metadata={"image": "", "provider": "okta"},
                ),
            )
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
class Auth0OAuthProvider(OAuthProvider):
    """Auth0 provider configured via ``OAUTH_AUTH0_*`` variables.

    ``OAUTH_AUTH0_DOMAIN`` hosts the authorize/token endpoints. The optional
    ``OAUTH_AUTH0_ORIGINAL_DOMAIN`` (e.g. when the first is a custom domain)
    is used for the ``/userinfo`` audience and request; it defaults to the
    first domain.
    """

    id = "auth0"
    env = ["OAUTH_AUTH0_CLIENT_ID", "OAUTH_AUTH0_CLIENT_SECRET", "OAUTH_AUTH0_DOMAIN"]

    def __init__(self):
        self.client_id = os.environ.get("OAUTH_AUTH0_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_AUTH0_CLIENT_SECRET")
        # Trailing slashes are stripped so appended paths don't produce "//".
        self.domain = f"https://{os.environ.get('OAUTH_AUTH0_DOMAIN', '').rstrip('/')}"
        original = os.environ.get("OAUTH_AUTH0_ORIGINAL_DOMAIN")
        if original:
            self.original_domain = f"https://{original.rstrip('/')}"
        else:
            self.original_domain = self.domain

        self.authorize_url = f"{self.domain}/authorize"
        self.authorize_params = dict(
            response_type="code",
            scope="openid profile email",
            audience=f"{self.original_domain}/userinfo",
        )

        prompt = self.get_prompt()
        if prompt:
            self.authorize_params["prompt"] = prompt

    async def get_token(self, code: str, url: str):
        """Exchange an authorization ``code`` for an access token.

        Raises:
            HTTPException: 400 when the response contains no ``access_token``.
        """
        form = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            code=code,
            grant_type="authorization_code",
            redirect_uri=url,
        )
        async with httpx.AsyncClient() as client:
            resp = await client.post(f"{self.domain}/oauth/token", data=form)
            resp.raise_for_status()
            access_token = resp.json().get("access_token")
            if access_token:
                return access_token
            raise HTTPException(
                status_code=400, detail="Failed to get the access token"
            )

    async def get_user_info(self, token: str):
        """Return ``(auth0_userinfo, User)``; the identifier is the ``email`` claim."""
        async with httpx.AsyncClient() as client:
            resp = await client.get(
                f"{self.original_domain}/userinfo",
                headers={"Authorization": f"Bearer {token}"},
            )
            resp.raise_for_status()
            auth0_user = resp.json()
            metadata = {
                "image": auth0_user.get("picture", ""),
                "provider": "auth0",
            }
            user = User(identifier=auth0_user.get("email"), metadata=metadata)
            return (auth0_user, user)
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
class DescopeOAuthProvider(OAuthProvider):
    """Descope provider using Descope's hosted OAuth 2.0 endpoints."""

    id = "descope"
    env = ["OAUTH_DESCOPE_CLIENT_ID", "OAUTH_DESCOPE_CLIENT_SECRET"]
    # Fixed base URL with no trailing slash; endpoint paths are appended to it.
    domain = "https://api.descope.com/oauth2/v1"

    authorize_url = f"{domain}/authorize"

    def __init__(self):
        self.client_id = os.environ.get("OAUTH_DESCOPE_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_DESCOPE_CLIENT_SECRET")
        self.authorize_params = dict(
            response_type="code",
            scope="openid profile email",
            audience=f"{self.domain}/userinfo",
        )

        prompt = self.get_prompt()
        if prompt:
            self.authorize_params["prompt"] = prompt

    async def get_token(self, code: str, url: str):
        """Exchange an authorization ``code`` for an access token.

        Raises:
            httpx.HTTPStatusError: On a non-2xx response, or when no
                ``access_token`` is returned.
        """
        form = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            code=code,
            grant_type="authorization_code",
            redirect_uri=url,
        )
        async with httpx.AsyncClient() as client:
            resp = await client.post(f"{self.domain}/token", data=form)
            resp.raise_for_status()
            access_token = resp.json().get("access_token")
            if access_token:
                return access_token
            raise httpx.HTTPStatusError(
                "Failed to get the access token",
                request=resp.request,
                response=resp,
            )

    async def get_user_info(self, token: str):
        """Return ``(descope_userinfo, User)``; the identifier is the ``email`` claim.

        Raises:
            httpx.HTTPStatusError: On a 4xx/5xx userinfo response.
        """
        async with httpx.AsyncClient() as client:
            resp = await client.get(
                f"{self.domain}/userinfo",
                headers={"Authorization": f"Bearer {token}"},
            )
            resp.raise_for_status()
            descope_user = resp.json()

            return (
                descope_user,
                User(
                    identifier=descope_user.get("email"),
                    metadata={"image": "", "provider": "descope"},
                ),
            )
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
class AWSCognitoOAuthProvider(OAuthProvider):
    """AWS Cognito hosted-UI provider configured via ``OAUTH_COGNITO_*`` variables.

    ``OAUTH_COGNITO_SCOPE`` overrides the requested scopes
    (default ``"openid profile email"``).
    """

    id = "aws-cognito"
    env = [
        "OAUTH_COGNITO_CLIENT_ID",
        "OAUTH_COGNITO_CLIENT_SECRET",
        "OAUTH_COGNITO_DOMAIN",
    ]
    # Resolved once at class definition from OAUTH_COGNITO_DOMAIN.
    authorize_url = f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/login"
    token_url = f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/oauth2/token"

    def __init__(self):
        self.client_id = os.environ.get("OAUTH_COGNITO_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_COGNITO_CLIENT_SECRET")
        self.scopes = os.environ.get("OAUTH_COGNITO_SCOPE", "openid profile email")
        self.authorize_params = dict(
            response_type="code",
            client_id=self.client_id,
            scope=self.scopes,
        )

        prompt = self.get_prompt()
        if prompt:
            self.authorize_params["prompt"] = prompt

    async def get_token(self, code: str, url: str):
        """Exchange an authorization ``code`` for an access token.

        Raises:
            HTTPException: 400 when the response contains no ``access_token``.
        """
        form = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            code=code,
            grant_type="authorization_code",
            redirect_uri=url,
        )
        async with httpx.AsyncClient() as client:
            resp = await client.post(self.token_url, data=form)
            resp.raise_for_status()
            access_token = resp.json().get("access_token")
            if access_token:
                return access_token
            raise HTTPException(
                status_code=400, detail="Failed to get the access token"
            )

    async def get_user_info(self, token: str):
        """Return ``(cognito_userinfo, User)``; the identifier is the ``email`` claim.

        The userInfo URL is rebuilt from ``OAUTH_COGNITO_DOMAIN`` on each call.
        """
        domain = os.environ.get("OAUTH_COGNITO_DOMAIN")
        async with httpx.AsyncClient() as client:
            resp = await client.get(
                f"https://{domain}/oauth2/userInfo",
                headers={"Authorization": f"Bearer {token}"},
            )
            resp.raise_for_status()

            cognito_user = resp.json()

            # Customize user metadata as needed.
            metadata = {
                "image": cognito_user.get("picture", ""),
                "provider": "aws-cognito",
            }
            user = User(identifier=cognito_user["email"], metadata=metadata)
            return (cognito_user, user)
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
class GitlabOAuthProvider(OAuthProvider):
    """GitLab provider (self-hosted or gitlab.com) configured via ``OAUTH_GITLAB_*``."""

    id = "gitlab"
    env = [
        "OAUTH_GITLAB_CLIENT_ID",
        "OAUTH_GITLAB_CLIENT_SECRET",
        "OAUTH_GITLAB_DOMAIN",
    ]

    def __init__(self):
        self.client_id = os.environ.get("OAUTH_GITLAB_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_GITLAB_CLIENT_SECRET")
        # Trailing slash stripped so appended paths don't produce "//".
        host = os.environ.get("OAUTH_GITLAB_DOMAIN", "").rstrip("/")
        self.domain = f"https://{host}"

        self.authorize_url = f"{self.domain}/oauth/authorize"
        self.authorize_params = dict(
            scope="openid profile email",
            response_type="code",
        )

        prompt = self.get_prompt()
        if prompt:
            self.authorize_params["prompt"] = prompt

    async def get_token(self, code: str, url: str):
        """Exchange an authorization ``code`` for an access token.

        Raises:
            HTTPException: 400 when the response contains no ``access_token``.
        """
        form = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            code=code,
            grant_type="authorization_code",
            redirect_uri=url,
        )
        async with httpx.AsyncClient() as client:
            resp = await client.post(f"{self.domain}/oauth/token", data=form)
            resp.raise_for_status()
            access_token = resp.json().get("access_token")
            if access_token:
                return access_token
            raise HTTPException(
                status_code=400, detail="Failed to get the access token"
            )

    async def get_user_info(self, token: str):
        """Return ``(gitlab_userinfo, User)``; the identifier is the ``email`` claim."""
        async with httpx.AsyncClient() as client:
            resp = await client.get(
                f"{self.domain}/oauth/userinfo",
                headers={"Authorization": f"Bearer {token}"},
            )
            resp.raise_for_status()
            gitlab_user = resp.json()
            metadata = {
                "image": gitlab_user.get("picture", ""),
                "provider": "gitlab",
            }
            user = User(identifier=gitlab_user.get("email"), metadata=metadata)
            return (gitlab_user, user)
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
class KeycloakOAuthProvider(OAuthProvider):
    """Keycloak provider using a realm's OpenID Connect endpoints.

    The provider id comes from ``OAUTH_KEYCLOAK_NAME`` (default ``"keycloak"``),
    read once at class definition.
    """

    env = [
        "OAUTH_KEYCLOAK_CLIENT_ID",
        "OAUTH_KEYCLOAK_CLIENT_SECRET",
        "OAUTH_KEYCLOAK_REALM",
        "OAUTH_KEYCLOAK_BASE_URL",
    ]
    id = os.environ.get("OAUTH_KEYCLOAK_NAME", "keycloak")

    def __init__(self):
        self.client_id = os.environ.get("OAUTH_KEYCLOAK_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_KEYCLOAK_CLIENT_SECRET")
        self.realm = os.environ.get("OAUTH_KEYCLOAK_REALM")
        self.base_url = os.environ.get("OAUTH_KEYCLOAK_BASE_URL")
        self.authorize_url = self._oidc_url("auth")

        self.authorize_params = dict(
            scope="profile email openid",
            response_type="code",
        )

        prompt = self.get_prompt()
        if prompt:
            self.authorize_params["prompt"] = prompt

    def _oidc_url(self, endpoint: str) -> str:
        """Return ``<base_url>/realms/<realm>/protocol/openid-connect/<endpoint>``."""
        return f"{self.base_url}/realms/{self.realm}/protocol/openid-connect/{endpoint}"

    async def get_token(self, code: str, url: str):
        """Exchange an authorization ``code`` for an access token.

        Raises:
            httpx.HTTPStatusError: On a non-2xx response, or when no
                ``access_token`` is returned.
        """
        form = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            code=code,
            grant_type="authorization_code",
            redirect_uri=url,
        )
        async with httpx.AsyncClient() as client:
            resp = await client.post(self._oidc_url("token"), data=form)
            resp.raise_for_status()
            access_token = resp.json().get("access_token")
            if access_token:
                return access_token
            raise httpx.HTTPStatusError(
                "Failed to get the access token",
                request=resp.request,
                response=resp,
            )

    async def get_user_info(self, token: str):
        """Return ``(keycloak_userinfo, User)``; the identifier is the ``email`` claim."""
        async with httpx.AsyncClient() as client:
            resp = await client.get(
                self._oidc_url("userinfo"),
                headers={"Authorization": f"Bearer {token}"},
            )
            resp.raise_for_status()
            kc_user = resp.json()
            # NOTE(review): the provider is reported as "keycloak" even when
            # OAUTH_KEYCLOAK_NAME changes self.id — confirm this is intended.
            user = User(identifier=kc_user["email"], metadata={"provider": "keycloak"})
            return (kc_user, user)
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
class GenericOAuthProvider(OAuthProvider):
    """Provider for any OAuth 2.0 server whose endpoints come from ``OAUTH_GENERIC_*``.

    The provider id comes from ``OAUTH_GENERIC_NAME`` (default ``"generic"``),
    read once at class definition. ``OAUTH_GENERIC_USER_IDENTIFIER`` names the
    userinfo field used as the identifier (default ``"email"``).
    """

    env = [
        "OAUTH_GENERIC_CLIENT_ID",
        "OAUTH_GENERIC_CLIENT_SECRET",
        "OAUTH_GENERIC_AUTH_URL",
        "OAUTH_GENERIC_TOKEN_URL",
        "OAUTH_GENERIC_USER_INFO_URL",
        "OAUTH_GENERIC_SCOPES",
    ]
    id = os.environ.get("OAUTH_GENERIC_NAME", "generic")

    def __init__(self):
        getenv = os.environ.get
        self.client_id = getenv("OAUTH_GENERIC_CLIENT_ID")
        self.client_secret = getenv("OAUTH_GENERIC_CLIENT_SECRET")
        self.authorize_url = getenv("OAUTH_GENERIC_AUTH_URL")
        self.token_url = getenv("OAUTH_GENERIC_TOKEN_URL")
        self.user_info_url = getenv("OAUTH_GENERIC_USER_INFO_URL")
        self.scopes = getenv("OAUTH_GENERIC_SCOPES")
        self.user_identifier = getenv("OAUTH_GENERIC_USER_IDENTIFIER", "email")

        self.authorize_params = dict(scope=self.scopes, response_type="code")

        prompt = self.get_prompt()
        if prompt:
            self.authorize_params["prompt"] = prompt

    async def get_token(self, code: str, url: str):
        """Exchange an authorization ``code`` for an access token.

        Raises:
            httpx.HTTPStatusError: On a non-2xx response, or when no
                ``access_token`` is returned.
        """
        form = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            code=code,
            grant_type="authorization_code",
            redirect_uri=url,
        )
        async with httpx.AsyncClient() as client:
            resp = await client.post(self.token_url, data=form)
            resp.raise_for_status()
            access_token = resp.json().get("access_token")
            if access_token:
                return access_token
            raise httpx.HTTPStatusError(
                "Failed to get the access token",
                request=resp.request,
                response=resp,
            )

    async def get_user_info(self, token: str):
        """Return ``(userinfo, User)``; the identifier is the configured field."""
        async with httpx.AsyncClient() as client:
            resp = await client.get(
                self.user_info_url,
                headers={"Authorization": f"Bearer {token}"},
            )
            resp.raise_for_status()
            server_user = resp.json()
            user = User(
                identifier=server_user.get(self.user_identifier),
                metadata={"provider": self.id},
            )
            return (server_user, user)
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
# One shared instance per supported provider, created at import time, in the
# order they are searched and reported.
providers = [
    provider_cls()
    for provider_cls in (
        GithubOAuthProvider,
        GoogleOAuthProvider,
        AzureADOAuthProvider,
        AzureADHybridOAuthProvider,
        OktaOAuthProvider,
        Auth0OAuthProvider,
        DescopeOAuthProvider,
        AWSCognitoOAuthProvider,
        GitlabOAuthProvider,
        KeycloakOAuthProvider,
        GenericOAuthProvider,
    )
]
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
def get_oauth_provider(provider: str) -> Optional[OAuthProvider]:
    """Return the first registered provider whose ``id`` equals ``provider``.

    Returns None when no registered provider matches.
    """
    matches = (candidate for candidate in providers if candidate.id == provider)
    return next(matches, None)
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
def get_configured_oauth_providers():
    """Return the ids of registered providers whose required env vars are all set.

    Ids are returned in registration order.
    """
    configured_ids = []
    for candidate in providers:
        if candidate.is_configured():
            configured_ids.append(candidate.id)
    return configured_ids
|