arize-phoenix 4.36.0__py3-none-any.whl → 5.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.0.0.dist-info}/METADATA +10 -12
- {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.0.0.dist-info}/RECORD +68 -59
- phoenix/__init__.py +86 -0
- phoenix/auth.py +275 -14
- phoenix/config.py +277 -25
- phoenix/db/enums.py +20 -0
- phoenix/db/facilitator.py +112 -0
- phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
- phoenix/db/models.py +145 -60
- phoenix/experiments/evaluators/code_evaluators.py +9 -3
- phoenix/experiments/functions.py +1 -4
- phoenix/server/api/README.md +28 -0
- phoenix/server/api/auth.py +32 -0
- phoenix/server/api/context.py +50 -2
- phoenix/server/api/dataloaders/__init__.py +4 -0
- phoenix/server/api/dataloaders/user_roles.py +30 -0
- phoenix/server/api/dataloaders/users.py +33 -0
- phoenix/server/api/exceptions.py +7 -0
- phoenix/server/api/mutations/__init__.py +0 -2
- phoenix/server/api/mutations/api_key_mutations.py +104 -86
- phoenix/server/api/mutations/dataset_mutations.py +8 -8
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/project_mutations.py +3 -3
- phoenix/server/api/mutations/span_annotations_mutations.py +4 -4
- phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
- phoenix/server/api/mutations/user_mutations.py +282 -42
- phoenix/server/api/openapi/schema.py +2 -2
- phoenix/server/api/queries.py +48 -39
- phoenix/server/api/routers/__init__.py +11 -0
- phoenix/server/api/routers/auth.py +284 -0
- phoenix/server/api/routers/embeddings.py +26 -0
- phoenix/server/api/routers/oauth2.py +456 -0
- phoenix/server/api/routers/v1/__init__.py +38 -16
- phoenix/server/api/types/ApiKey.py +11 -0
- phoenix/server/api/types/AuthMethod.py +9 -0
- phoenix/server/api/types/User.py +48 -4
- phoenix/server/api/types/UserApiKey.py +35 -1
- phoenix/server/api/types/UserRole.py +7 -0
- phoenix/server/app.py +103 -31
- phoenix/server/bearer_auth.py +161 -0
- phoenix/server/email/__init__.py +0 -0
- phoenix/server/email/sender.py +26 -0
- phoenix/server/email/templates/__init__.py +0 -0
- phoenix/server/email/templates/password_reset.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/grpc_server.py +6 -0
- phoenix/server/jwt_store.py +504 -0
- phoenix/server/main.py +40 -9
- phoenix/server/oauth2.py +51 -0
- phoenix/server/prometheus.py +20 -0
- phoenix/server/rate_limiters.py +191 -0
- phoenix/server/static/.vite/manifest.json +31 -31
- phoenix/server/static/assets/{components-Dte7_KRd.js → components-REunxTt6.js} +348 -286
- phoenix/server/static/assets/index-DAPJxlCw.js +101 -0
- phoenix/server/static/assets/{pages-CnTvEGEN.js → pages-1VrMk2pW.js} +559 -291
- phoenix/server/static/assets/{vendor-BC3OPQuM.js → vendor-B5IC0ivG.js} +5 -5
- phoenix/server/static/assets/{vendor-arizeai-NjB3cZzD.js → vendor-arizeai-aFbT4kl1.js} +2 -2
- phoenix/server/static/assets/{vendor-codemirror-gE_JCOgX.js → vendor-codemirror-BEGorXSV.js} +1 -1
- phoenix/server/static/assets/{vendor-recharts-BXLYwcXF.js → vendor-recharts-6nUU7gU_.js} +1 -1
- phoenix/server/templates/index.html +1 -0
- phoenix/server/types.py +157 -1
- phoenix/session/client.py +7 -2
- phoenix/utilities/client.py +16 -0
- phoenix/version.py +1 -1
- phoenix/db/migrations/future_versions/README.md +0 -4
- phoenix/db/migrations/future_versions/cd164e83824f_users_and_tokens.py +0 -293
- phoenix/db/migrations/versions/.gitignore +0 -1
- phoenix/server/api/mutations/auth.py +0 -18
- phoenix/server/api/mutations/auth_mutations.py +0 -65
- phoenix/server/static/assets/index-fq1-hCK4.js +0 -100
- phoenix/trace/langchain/__init__.py +0 -3
- phoenix/trace/langchain/instrumentor.py +0 -34
- phoenix/trace/llama_index/__init__.py +0 -3
- phoenix/trace/llama_index/callback.py +0 -102
- phoenix/trace/openai/__init__.py +0 -3
- phoenix/trace/openai/instrumentor.py +0 -30
- {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.0.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.0.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.0.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,456 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from datetime import timedelta
|
|
4
|
+
from random import randrange
|
|
5
|
+
from typing import Any, Dict, Optional, Tuple, TypedDict
|
|
6
|
+
from urllib.parse import unquote, urlparse
|
|
7
|
+
|
|
8
|
+
from authlib.common.security import generate_token
|
|
9
|
+
from authlib.integrations.starlette_client import OAuthError
|
|
10
|
+
from authlib.jose import jwt
|
|
11
|
+
from authlib.jose.errors import JoseError
|
|
12
|
+
from fastapi import APIRouter, Cookie, Depends, Path, Query, Request
|
|
13
|
+
from sqlalchemy import Boolean, and_, case, cast, func, insert, or_, select, update
|
|
14
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
15
|
+
from sqlalchemy.orm import joinedload
|
|
16
|
+
from sqlean.dbapi2 import IntegrityError # type: ignore[import-untyped]
|
|
17
|
+
from starlette.datastructures import URL
|
|
18
|
+
from starlette.responses import RedirectResponse
|
|
19
|
+
from starlette.routing import Router
|
|
20
|
+
from starlette.status import HTTP_302_FOUND
|
|
21
|
+
from typing_extensions import Annotated, NotRequired, TypeGuard
|
|
22
|
+
|
|
23
|
+
from phoenix.auth import (
|
|
24
|
+
DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES,
|
|
25
|
+
PHOENIX_OAUTH2_NONCE_COOKIE_NAME,
|
|
26
|
+
PHOENIX_OAUTH2_STATE_COOKIE_NAME,
|
|
27
|
+
delete_oauth2_nonce_cookie,
|
|
28
|
+
delete_oauth2_state_cookie,
|
|
29
|
+
set_access_token_cookie,
|
|
30
|
+
set_oauth2_nonce_cookie,
|
|
31
|
+
set_oauth2_state_cookie,
|
|
32
|
+
set_refresh_token_cookie,
|
|
33
|
+
)
|
|
34
|
+
from phoenix.config import get_env_disable_rate_limit
|
|
35
|
+
from phoenix.db import models
|
|
36
|
+
from phoenix.db.enums import UserRole
|
|
37
|
+
from phoenix.server.bearer_auth import create_access_and_refresh_tokens
|
|
38
|
+
from phoenix.server.oauth2 import OAuth2Client
|
|
39
|
+
from phoenix.server.rate_limiters import (
|
|
40
|
+
ServerRateLimiter,
|
|
41
|
+
fastapi_ip_rate_limiter,
|
|
42
|
+
fastapi_route_rate_limiter,
|
|
43
|
+
)
|
|
44
|
+
from phoenix.server.types import TokenStore
|
|
45
|
+
|
|
46
|
+
_LOWERCASE_ALPHANUMS_AND_UNDERSCORES = r"[a-z0-9_]+"
|
|
47
|
+
|
|
48
|
+
# Rate limiter attached to the login route (presumably keyed by client IP,
# judging by the factory name — confirm in phoenix.server.rate_limiters).
login_rate_limiter = fastapi_ip_rate_limiter(
    ServerRateLimiter(
        per_second_rate_limit=0.2,
        enforcement_window_seconds=30,
        partition_seconds=60,
        active_partitions=2,
    ),
)

# Rate limiter attached to the token-creation (OAuth2 callback) route.
create_tokens_rate_limiter = fastapi_route_rate_limiter(
    ServerRateLimiter(
        per_second_rate_limit=0.5,
        enforcement_window_seconds=30,
        partition_seconds=60,
        active_partitions=2,
    )
)

# All OAuth2 routes live under /oauth2 and are hidden from the OpenAPI schema.
router = APIRouter(prefix="/oauth2", include_in_schema=False)

# Rate limiting can be switched off via configuration, in which case the
# routes are registered without any rate-limiting dependencies.
_rate_limiting_enabled = not get_env_disable_rate_limit()
login_dependencies = [Depends(login_rate_limiter)] if _rate_limiting_enabled else []
create_tokens_dependencies = (
    [Depends(create_tokens_rate_limiter)] if _rate_limiting_enabled else []
)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@router.post("/{idp_name}/login", dependencies=login_dependencies)
async def login(
    request: Request,
    idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)],
    return_url: Optional[str] = Query(default=None, alias="returnUrl"),
) -> RedirectResponse:
    """
    Starts the OAuth2 authorization code flow for the given identity provider.

    Redirects the browser to the IDP's authorization URL and sets the state
    and nonce cookies so the callback route can verify them later. If no
    OAuth2 client is registered under `idp_name`, redirects to the login page
    with an error message instead.
    """
    secret = request.app.state.get_secret()
    oauth2_client = request.app.state.oauth2_clients.get_client(idp_name)
    if not isinstance(oauth2_client, OAuth2Client):
        return _redirect_to_login(error=f"Unknown IDP: {idp_name}.")
    origin_url = _get_origin_url(request)
    authorization_url_data = await oauth2_client.create_authorization_url(
        redirect_uri=_get_create_tokens_endpoint(
            request=request, origin_url=origin_url, idp_name=idp_name
        ),
        state=_generate_state_for_oauth2_authorization_code_flow(
            secret=secret, origin_url=origin_url, return_url=return_url
        ),
    )
    # Bind the values *outside* the assert statements: asserts are removed under
    # `python -O`, and a walrus inside one would leave these names undefined.
    authorization_url = authorization_url_data.get("url")
    state = authorization_url_data.get("state")
    nonce = authorization_url_data.get("nonce")
    assert isinstance(authorization_url, str)
    assert isinstance(state, str)
    assert isinstance(nonce, str)
    login_expiry = timedelta(minutes=DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES)
    response = RedirectResponse(url=authorization_url, status_code=HTTP_302_FOUND)
    response = set_oauth2_state_cookie(
        response=response,
        state=state,
        max_age=login_expiry,
    )
    response = set_oauth2_nonce_cookie(
        response=response,
        nonce=nonce,
        max_age=login_expiry,
    )
    return response
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
@router.get("/{idp_name}/tokens", dependencies=create_tokens_dependencies)
async def create_tokens(
    request: Request,
    idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)],
    state: str = Query(),
    authorization_code: str = Query(alias="code"),
    stored_state: str = Cookie(alias=PHOENIX_OAUTH2_STATE_COOKIE_NAME),
    stored_nonce: str = Cookie(alias=PHOENIX_OAUTH2_NONCE_COOKIE_NAME),
) -> RedirectResponse:
    """
    Handles the OAuth2 callback: verifies the state, exchanges the
    authorization code for tokens, ensures a matching user exists, and sets
    the access and refresh token cookies.

    On success, redirects to the return URL carried in the state (or "/").
    Recoverable failures (invalid state, unsafe return URL, unknown IDP, OAuth
    error, missing ID token, email already in use) redirect to the login page
    with an error message.
    """
    secret = request.app.state.get_secret()
    # The message template carries an `{idp_name}` placeholder; fill it in so
    # the user does not see the raw placeholder text.
    invalid_state_message = _INVALID_OAUTH2_STATE_MESSAGE.format(idp_name=idp_name)
    if state != stored_state:
        return _redirect_to_login(error=invalid_state_message)
    try:
        payload = _parse_state_payload(secret=secret, state=state)
    except (JoseError, ValueError):
        # JoseError: bad signature / malformed JWT.
        # ValueError: validly signed JWT whose payload has an unexpected shape.
        return _redirect_to_login(error=invalid_state_message)
    return_url = payload.get("return_url")
    if return_url is not None and not _is_relative_url(unquote(return_url)):
        return _redirect_to_login(error="Attempting login with unsafe return URL.")
    # Bind outside the asserts so the names still exist under `python -O`.
    access_token_expiry = request.app.state.access_token_expiry
    refresh_token_expiry = request.app.state.refresh_token_expiry
    assert isinstance(access_token_expiry, timedelta)
    assert isinstance(refresh_token_expiry, timedelta)
    token_store: TokenStore = request.app.state.get_token_store()
    oauth2_client = request.app.state.oauth2_clients.get_client(idp_name)
    if not isinstance(oauth2_client, OAuth2Client):
        return _redirect_to_login(error=f"Unknown IDP: {idp_name}.")
    try:
        token_data = await oauth2_client.fetch_access_token(
            state=state,
            code=authorization_code,
            # Must match the redirect URI used when the flow was started, so it
            # is rebuilt from the origin URL recorded in the signed state.
            redirect_uri=_get_create_tokens_endpoint(
                request=request, origin_url=payload["origin_url"], idp_name=idp_name
            ),
        )
    except OAuthError as error:
        return _redirect_to_login(error=str(error))
    _validate_token_data(token_data)
    if "id_token" not in token_data:
        return _redirect_to_login(
            error=f"OAuth2 IDP {idp_name} does not appear to support OpenID Connect."
        )
    user_info = await oauth2_client.parse_id_token(token_data, nonce=stored_nonce)
    user_info = _parse_user_info(user_info)
    try:
        async with request.app.state.db() as session:
            user = await _ensure_user_exists_and_is_up_to_date(
                session,
                oauth2_client_id=str(oauth2_client.client_id),
                user_info=user_info,
            )
    except EmailAlreadyInUse as error:
        return _redirect_to_login(error=str(error))
    access_token, refresh_token = await create_access_and_refresh_tokens(
        user=user,
        token_store=token_store,
        access_token_expiry=access_token_expiry,
        refresh_token_expiry=refresh_token_expiry,
    )
    response = RedirectResponse(url=return_url or "/", status_code=HTTP_302_FOUND)
    response = set_access_token_cookie(
        response=response, access_token=access_token, max_age=access_token_expiry
    )
    response = set_refresh_token_cookie(
        response=response, refresh_token=refresh_token, max_age=refresh_token_expiry
    )
    # The one-time state and nonce cookies are no longer needed.
    response = delete_oauth2_state_cookie(response)
    response = delete_oauth2_nonce_cookie(response)
    return response
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
@dataclass
class UserInfo:
    """
    User attributes extracted from the IDP's ID token.
    """

    # The IDP's identifier for the user (the `sub` claim, stringified).
    idp_user_id: str
    email: str
    # Display name (`name` claim); may be absent.
    username: Optional[str]
    # Avatar URL (`picture` claim); may be absent.
    profile_picture_url: Optional[str]
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def _validate_token_data(token_data: Dict[str, Any]) -> None:
|
|
196
|
+
"""
|
|
197
|
+
Performs basic validations on the token data returned by the IDP.
|
|
198
|
+
"""
|
|
199
|
+
assert isinstance(token_data.get("access_token"), str)
|
|
200
|
+
assert isinstance(token_type := token_data.get("token_type"), str)
|
|
201
|
+
assert token_type.lower() == "bearer"
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def _parse_user_info(user_info: Dict[str, Any]) -> UserInfo:
    """
    Parses user info from the IDP's ID token.

    Values are bound and checked with explicit statements rather than inside
    asserts, so the function keeps working (and keeps validating) when Python
    runs with optimizations (`-O`) enabled.

    Raises:
        ValueError: if `sub` is not a string or integer, `email` is not a
            string, or `name` / `picture` are present but not strings.
    """
    subject = user_info.get("sub")
    if not isinstance(subject, (str, int)):
        raise ValueError("ID token is missing a valid `sub` claim.")
    email = user_info.get("email")
    if not isinstance(email, str):
        raise ValueError("ID token is missing a valid `email` claim.")
    username = user_info.get("name")
    if username is not None and not isinstance(username, str):
        raise ValueError("ID token contains an invalid `name` claim.")
    profile_picture_url = user_info.get("picture")
    if profile_picture_url is not None and not isinstance(profile_picture_url, str):
        raise ValueError("ID token contains an invalid `picture` claim.")
    return UserInfo(
        idp_user_id=str(subject),  # the subject may be an integer; store it as a string
        email=email,
        username=username,
        profile_picture_url=profile_picture_url,
    )
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
async def _ensure_user_exists_and_is_up_to_date(
    session: AsyncSession, /, *, oauth2_client_id: str, user_info: UserInfo
) -> models.User:
    """
    Returns the user for the given OAuth2 client and IDP user info, creating
    the user if none exists and updating the stored email if it differs from
    the one reported by the IDP.
    """
    existing_user = await _get_user(
        session,
        oauth2_client_id=oauth2_client_id,
        idp_user_id=user_info.idp_user_id,
    )
    if existing_user is None:
        return await _create_user(session, oauth2_client_id=oauth2_client_id, user_info=user_info)
    if existing_user.email != user_info.email:
        return await _update_user_email(
            session, user_id=existing_user.id, email=user_info.email
        )
    return existing_user
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
async def _get_user(
    session: AsyncSession, /, *, oauth2_client_id: str, idp_user_id: str
) -> Optional[models.User]:
    """
    Looks up the user identified by the combination of OAuth2 client ID and
    IDP user ID, eagerly loading the user's role. Returns None if there is no
    such user.
    """
    matches_identity = and_(
        models.User.oauth2_client_id == oauth2_client_id,
        models.User.oauth2_user_id == idp_user_id,
    )
    statement = (
        select(models.User).where(matches_identity).options(joinedload(models.User.role))
    )
    return await session.scalar(statement)
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
async def _create_user(
    session: AsyncSession,
    /,
    *,
    oauth2_client_id: str,
    user_info: UserInfo,
) -> models.User:
    """
    Inserts a new user with the MEMBER role from the IDP-provided user info and
    returns it with its role loaded.

    Raises EmailAlreadyInUse if the email already belongs to an existing user.
    If the username is already taken, a random numeric suffix is appended.
    """
    email = user_info.email
    username = user_info.username
    email_exists, username_exists = await _email_and_username_exist(
        session,
        email=email,
        username=username,
    )
    if email_exists:
        raise EmailAlreadyInUse(f"An account for {email} is already in use.")
    if username and username_exists:
        username = _with_random_suffix(username)
    member_role_id = (
        select(models.UserRole.id)
        .where(models.UserRole.name == UserRole.MEMBER.value)
        .scalar_subquery()
    )
    insert_statement = (
        insert(models.User)
        .returning(models.User.id)
        .values(
            user_role_id=member_role_id,
            oauth2_client_id=oauth2_client_id,
            oauth2_user_id=user_info.idp_user_id,
            username=username,
            email=email,
            profile_picture_url=user_info.profile_picture_url,
            reset_password=False,
        )
    )
    user_id = await session.scalar(insert_statement)
    assert isinstance(user_id, int)
    # Re-select the freshly inserted row so the role relationship is loaded.
    reload_statement = (
        select(models.User).where(models.User.id == user_id).options(joinedload(models.User.role))
    )
    user = await session.scalar(reload_statement)
    assert isinstance(user, models.User)
    return user
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
async def _update_user_email(session: AsyncSession, /, *, user_id: int, email: str) -> models.User:
    """
    Updates an existing user's email and returns the user with its role loaded.

    Raises:
        EmailAlreadyInUse: if the database rejects the update with an integrity
            error (presumably the unique constraint on the email column —
            confirm against the users table definition).
    """
    # Imported locally so this function is self-contained with respect to the
    # module-level imports.
    from sqlalchemy.exc import IntegrityError as SQLAlchemyIntegrityError

    try:
        await session.execute(
            update(models.User)
            .where(models.User.id == user_id)
            .values(email=email)
            .options(joinedload(models.User.role))
        )
    except (IntegrityError, SQLAlchemyIntegrityError) as error:
        # The module-level `IntegrityError` is the sqlean (SQLite) DB-API
        # exception. Violations reported by other drivers arrive wrapped in
        # SQLAlchemy's own IntegrityError, so both must be handled here.
        raise EmailAlreadyInUse(f"An account for {email} is already in use.") from error
    user = await session.scalar(
        select(models.User).where(models.User.id == user_id).options(joinedload(models.User.role))
    )  # query user again for joined load
    assert isinstance(user, models.User)
    return user
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
async def _email_and_username_exist(
    session: AsyncSession, /, *, email: str, username: Optional[str]
) -> Tuple[bool, bool]:
    """
    Reports, with a single query, whether any user already has the given email
    and whether any user already has the given username.

    Returns a pair `(email_exists, username_exists)`.
    """

    def _any_row_matches(condition: Any, label: str) -> Any:
        # MAX over a 0/1 CASE is 1 if at least one row satisfies the condition;
        # COALESCE covers the no-rows case, where MAX yields NULL.
        matched = func.max(case((condition, 1), else_=0))
        return cast(func.coalesce(matched, 0), Boolean).label(label)

    statement = select(
        _any_row_matches(models.User.email == email, "email_exists"),
        _any_row_matches(models.User.username == username, "username_exists"),
    ).where(or_(models.User.email == email, models.User.username == username))
    result = await session.execute(statement)
    # An aggregate query without GROUP BY produces exactly one row.
    [(email_exists, username_exists)] = result.all()
    return email_exists, username_exists
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
class EmailAlreadyInUse(Exception):
    """
    Raised when an email address is already associated with an existing user.
    """
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
def _redirect_to_login(*, error: str) -> RedirectResponse:
    """
    Builds a redirect to the login page that carries the error message as a
    query parameter and clears the OAuth2 state and nonce cookies.
    """
    login_url = URL("/login").include_query_params(error=error)
    redirect = RedirectResponse(url=login_url)
    redirect = delete_oauth2_state_cookie(redirect)
    return delete_oauth2_nonce_cookie(redirect)
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
def _get_create_tokens_endpoint(*, request: Request, origin_url: str, idp_name: str) -> str:
    """
    Returns the absolute URL of the create-tokens route for the given IDP,
    resolved against the given origin URL.
    """
    app_router: Router = request.scope["router"]
    tokens_path = app_router.url_path_for(create_tokens.__name__, idp_name=idp_name)
    absolute_url = tokens_path.make_absolute_url(base_url=origin_url)
    return str(absolute_url)
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
def _generate_state_for_oauth2_authorization_code_flow(
    *, secret: str, origin_url: str, return_url: Optional[str]
) -> str:
    """
    Builds the OAuth2 `state` parameter as a signed JWT.

    The payload holds a random token (generated with the `authlib` default
    algorithm), the origin URL, and, when given, the return URL. Carrying the
    return URL inside the state lets the authorization server hand it back in
    the callback, so nothing has to be stored server-side.
    """
    claims = _OAuth2StatePayload(random=generate_token(), origin_url=origin_url)
    if return_url is not None:
        claims["return_url"] = return_url
    encoded: bytes = jwt.encode(header={"alg": _JWT_ALGORITHM}, payload=claims, key=secret)
    return encoded.decode()
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
class _OAuth2StatePayload(TypedDict):
|
|
397
|
+
"""
|
|
398
|
+
Represents the OAuth2 state payload.
|
|
399
|
+
"""
|
|
400
|
+
|
|
401
|
+
random: str
|
|
402
|
+
origin_url: str
|
|
403
|
+
return_url: NotRequired[str]
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
def _parse_state_payload(*, secret: str, state: str) -> _OAuth2StatePayload:
    """
    Decodes the OAuth2 state JWT with the given secret and returns its payload.

    Raises ValueError if the decoded payload does not have the expected keys.
    """
    decoded = jwt.decode(s=state, key=secret)
    if not _is_oauth2_state_payload(decoded):
        raise ValueError("Invalid OAuth2 state payload.")
    return decoded
|
|
414
|
+
|
|
415
|
+
|
|
416
|
+
def _is_relative_url(url: str) -> bool:
    """
    Returns True if the URL matches the module's relative-URL pattern.
    """
    return _RELATIVE_URL_PATTERN.match(url) is not None
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
def _with_random_suffix(string: str) -> str:
|
|
424
|
+
"""
|
|
425
|
+
Appends a random suffix.
|
|
426
|
+
"""
|
|
427
|
+
return f"{string}-{randrange(10_000, 100_000)}"
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
def _get_origin_url(request: Request) -> str:
    """
    Infers the origin URL from the request.

    Uses the scheme and host of the `Referer` header when it is present and
    well-formed; otherwise falls back to the request's base URL.
    """
    referer = request.headers.get("referer")
    if referer is None:
        return str(request.base_url)
    parsed_url = urlparse(referer)
    if not parsed_url.scheme or not parsed_url.netloc:
        # A Referer without a scheme or host would yield a bogus origin such as
        # "://", so treat it the same as a missing header.
        return str(request.base_url)
    return f"{parsed_url.scheme}://{parsed_url.netloc}"
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
def _is_oauth2_state_payload(maybe_state_payload: Any) -> TypeGuard[_OAuth2StatePayload]:
    """
    Returns True if the object is a dict that has the required OAuth2 state
    keys and no keys other than the known ones. Only keys are inspected, not
    the types of the values.
    """
    if not isinstance(maybe_state_payload, dict):
        return False
    required_keys = {"random", "origin_url"}
    allowed_keys = {"random", "origin_url", "return_url"}
    present_keys = set(maybe_state_payload.keys())
    return required_keys.issubset(present_keys) and present_keys.issubset(allowed_keys)
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
_JWT_ALGORITHM = "HS256"
|
|
453
|
+
_INVALID_OAUTH2_STATE_MESSAGE = (
|
|
454
|
+
"Received invalid state parameter during OAuth2 authorization code flow for IDP {idp_name}."
|
|
455
|
+
)
|
|
456
|
+
_RELATIVE_URL_PATTERN = re.compile(r"^/($|\w)")
|
|
@@ -1,6 +1,9 @@
|
|
|
1
1
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
|
2
|
+
from fastapi.security import APIKeyHeader
|
|
2
3
|
from starlette.status import HTTP_403_FORBIDDEN
|
|
3
4
|
|
|
5
|
+
from phoenix.server.bearer_auth import is_authenticated
|
|
6
|
+
|
|
4
7
|
from .datasets import router as datasets_router
|
|
5
8
|
from .evaluations import router as evaluations_router
|
|
6
9
|
from .experiment_evaluations import router as experiment_evaluations_router
|
|
@@ -24,19 +27,38 @@ async def prevent_access_in_read_only_mode(request: Request) -> None:
|
|
|
24
27
|
)
|
|
25
28
|
|
|
26
29
|
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
30
|
+
def create_v1_router(authentication_enabled: bool) -> APIRouter:
    """
    Builds the router for the v1 REST API.

    Every route is blocked in read-only mode. When authentication is enabled,
    routes additionally declare a bearer `Authorization` header and require
    the request to be authenticated.
    """
    dependencies = [Depends(prevent_access_in_read_only_mode)]
    if authentication_enabled:
        bearer_header = APIKeyHeader(
            name="Authorization",
            scheme_name="Bearer",
            auto_error=False,
            description="Enter `Bearer` followed by a space and then the token.",
        )
        dependencies.extend([Depends(bearer_header), Depends(is_authenticated)])

    error_responses = add_errors_to_responses(
        [
            HTTP_403_FORBIDDEN  # adds a 403 response to routes in the generated OpenAPI schema
        ]
    )
    router = APIRouter(
        prefix="/v1",
        dependencies=dependencies,
        responses=error_responses,
    )
    for sub_router in (
        datasets_router,
        experiments_router,
        experiment_runs_router,
        experiment_evaluations_router,
        traces_router,
        spans_router,
        evaluations_router,
    ):
        router.include_router(sub_router)
    return router
|
|
@@ -3,6 +3,8 @@ from typing import Optional
|
|
|
3
3
|
|
|
4
4
|
import strawberry
|
|
5
5
|
|
|
6
|
+
from phoenix.db.models import ApiKey as ORMApiKey
|
|
7
|
+
|
|
6
8
|
|
|
7
9
|
@strawberry.interface
|
|
8
10
|
class ApiKey:
|
|
@@ -14,3 +16,12 @@ class ApiKey:
|
|
|
14
16
|
expires_at: Optional[datetime] = strawberry.field(
|
|
15
17
|
description="The date and time the API key will expire."
|
|
16
18
|
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def to_gql_api_key(api_key: ORMApiKey) -> ApiKey:
    """
    Converts an ORM API key to a GraphQL ApiKey by copying its name,
    description, creation time and expiry.
    """
    field_values = {
        "name": api_key.name,
        "description": api_key.description,
        "created_at": api_key.created_at,
        "expires_at": api_key.expires_at,
    }
    return ApiKey(**field_values)
|
phoenix/server/api/types/User.py
CHANGED
|
@@ -1,16 +1,60 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
|
-
from typing import Optional
|
|
2
|
+
from typing import List, Optional
|
|
3
3
|
|
|
4
4
|
import strawberry
|
|
5
|
+
from sqlalchemy import select
|
|
6
|
+
from strawberry import Private
|
|
5
7
|
from strawberry.relay import Node, NodeID
|
|
8
|
+
from strawberry.types import Info
|
|
6
9
|
|
|
7
|
-
from .
|
|
10
|
+
from phoenix.db import models
|
|
11
|
+
from phoenix.server.api.context import Context
|
|
12
|
+
from phoenix.server.api.exceptions import NotFound
|
|
13
|
+
from phoenix.server.api.types.AuthMethod import AuthMethod
|
|
14
|
+
from phoenix.server.api.types.UserApiKey import UserApiKey, to_gql_api_key
|
|
15
|
+
|
|
16
|
+
from .UserRole import UserRole, to_gql_user_role
|
|
8
17
|
|
|
9
18
|
|
|
10
19
|
@strawberry.type
class User(Node):
    # GraphQL representation of a user. The role and API keys are resolved
    # lazily by the field resolvers below.
    id_attr: NodeID[int]
    password_needs_reset: bool
    email: str
    username: str
    profile_picture_url: Optional[str]
    created_at: datetime
    user_role_id: Private[int]  # private: used by the `role` resolver, not exposed
    auth_method: AuthMethod

    @strawberry.field
    async def role(self, info: Info[Context, None]) -> UserRole:
        # Resolve the role through the user-roles data loader.
        orm_role = await info.context.data_loaders.user_roles.load(self.user_role_id)
        if orm_role is not None:
            return to_gql_user_role(orm_role)
        raise NotFound(f"User role with id {self.user_role_id} not found")

    @strawberry.field
    async def api_keys(self, info: Info[Context, None]) -> List[UserApiKey]:
        # Fetch every API key that belongs to this user.
        statement = select(models.ApiKey).where(models.ApiKey.user_id == self.id_attr)
        async with info.context.db() as session:
            orm_api_keys = await session.scalars(statement)
            return [to_gql_api_key(orm_api_key) for orm_api_key in orm_api_keys]
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def to_gql_user(user: models.User, api_keys: Optional[List[models.ApiKey]] = None) -> User:
    """
    Builds a GraphQL User from an ORM user.

    The `api_keys` argument is accepted but not used by this function.
    """
    auth_method = user.auth_method
    assert auth_method is not None
    return User(
        id_attr=user.id,
        password_needs_reset=user.reset_password,
        username=user.username,
        email=user.email,
        profile_picture_url=user.profile_picture_url,
        created_at=user.created_at,
        user_role_id=user.user_role_id,
        auth_method=AuthMethod(auth_method),
    )
|
|
@@ -1,11 +1,45 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING
|
|
2
|
+
|
|
1
3
|
import strawberry
|
|
2
4
|
from strawberry import Private
|
|
3
|
-
from strawberry.relay
|
|
5
|
+
from strawberry.relay import Node, NodeID
|
|
6
|
+
from strawberry.types import Info
|
|
7
|
+
from typing_extensions import Annotated
|
|
8
|
+
|
|
9
|
+
from phoenix.db.models import ApiKey as OrmApiKey
|
|
10
|
+
from phoenix.server.api.context import Context
|
|
11
|
+
from phoenix.server.api.exceptions import NotFound
|
|
4
12
|
|
|
5
13
|
from .ApiKey import ApiKey
|
|
6
14
|
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from .User import User
|
|
17
|
+
|
|
7
18
|
|
|
8
19
|
@strawberry.type
class UserApiKey(ApiKey, Node):
    # GraphQL representation of an API key that belongs to a user.
    id_attr: NodeID[int]
    user_id: Private[int]  # private: used by the `user` resolver, not exposed

    @strawberry.field
    async def user(self, info: Info[Context, None]) -> Annotated["User", strawberry.lazy(".User")]:
        # Resolve the owning user through the users data loader.
        orm_user = await info.context.data_loaders.users.load(self.user_id)
        if orm_user is None:
            raise NotFound(f"User with id {self.user_id} not found")
        # Imported at call time; at module level `.User` is only imported
        # under TYPE_CHECKING.
        from .User import to_gql_user

        return to_gql_user(orm_user)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def to_gql_api_key(api_key: OrmApiKey) -> UserApiKey:
    """
    Builds a GraphQL UserApiKey from an ORM API key.
    """
    field_values = {
        "id_attr": api_key.id,
        "user_id": api_key.user_id,
        "name": api_key.name,
        "description": api_key.description,
        "created_at": api_key.created_at,
        "expires_at": api_key.expires_at,
    }
    return UserApiKey(**field_values)
|
|
@@ -1,8 +1,15 @@
|
|
|
1
1
|
import strawberry
|
|
2
2
|
from strawberry.relay import Node, NodeID
|
|
3
3
|
|
|
4
|
+
from phoenix.db import models
|
|
5
|
+
|
|
4
6
|
|
|
5
7
|
@strawberry.type
class UserRole(Node):
    # GraphQL representation of a user role.
    id_attr: NodeID[int]  # database ID of the role, exposed as the node ID
    name: str  # name of the role
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def to_gql_user_role(role: models.UserRole) -> UserRole:
    """Build a GraphQL user role from an ORM user role."""
    gql_role = UserRole(
        id_attr=role.id,
        name=role.name,
    )
    return gql_role
|