arize-phoenix 4.36.0__py3-none-any.whl → 5.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (81)
  1. {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.1.0.dist-info}/METADATA +10 -12
  2. {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.1.0.dist-info}/RECORD +69 -60
  3. phoenix/__init__.py +86 -0
  4. phoenix/auth.py +275 -14
  5. phoenix/config.py +277 -25
  6. phoenix/db/enums.py +20 -0
  7. phoenix/db/facilitator.py +112 -0
  8. phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
  9. phoenix/db/models.py +145 -60
  10. phoenix/experiments/evaluators/code_evaluators.py +9 -3
  11. phoenix/experiments/functions.py +1 -4
  12. phoenix/server/api/README.md +28 -0
  13. phoenix/server/api/auth.py +32 -0
  14. phoenix/server/api/context.py +50 -2
  15. phoenix/server/api/dataloaders/__init__.py +4 -0
  16. phoenix/server/api/dataloaders/user_roles.py +30 -0
  17. phoenix/server/api/dataloaders/users.py +33 -0
  18. phoenix/server/api/exceptions.py +7 -0
  19. phoenix/server/api/mutations/__init__.py +0 -2
  20. phoenix/server/api/mutations/api_key_mutations.py +104 -86
  21. phoenix/server/api/mutations/dataset_mutations.py +8 -8
  22. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  23. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  24. phoenix/server/api/mutations/project_mutations.py +3 -3
  25. phoenix/server/api/mutations/span_annotations_mutations.py +4 -4
  26. phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
  27. phoenix/server/api/mutations/user_mutations.py +282 -42
  28. phoenix/server/api/openapi/schema.py +2 -2
  29. phoenix/server/api/queries.py +48 -39
  30. phoenix/server/api/routers/__init__.py +11 -0
  31. phoenix/server/api/routers/auth.py +284 -0
  32. phoenix/server/api/routers/embeddings.py +26 -0
  33. phoenix/server/api/routers/oauth2.py +456 -0
  34. phoenix/server/api/routers/v1/__init__.py +38 -16
  35. phoenix/server/api/types/ApiKey.py +11 -0
  36. phoenix/server/api/types/AuthMethod.py +9 -0
  37. phoenix/server/api/types/User.py +48 -4
  38. phoenix/server/api/types/UserApiKey.py +35 -1
  39. phoenix/server/api/types/UserRole.py +7 -0
  40. phoenix/server/app.py +103 -31
  41. phoenix/server/bearer_auth.py +161 -0
  42. phoenix/server/email/__init__.py +0 -0
  43. phoenix/server/email/sender.py +26 -0
  44. phoenix/server/email/templates/__init__.py +0 -0
  45. phoenix/server/email/templates/password_reset.html +19 -0
  46. phoenix/server/email/types.py +11 -0
  47. phoenix/server/grpc_server.py +6 -0
  48. phoenix/server/jwt_store.py +504 -0
  49. phoenix/server/main.py +40 -9
  50. phoenix/server/oauth2.py +51 -0
  51. phoenix/server/prometheus.py +20 -0
  52. phoenix/server/rate_limiters.py +191 -0
  53. phoenix/server/static/.vite/manifest.json +31 -31
  54. phoenix/server/static/assets/{components-Dte7_KRd.js → components-REunxTt6.js} +348 -286
  55. phoenix/server/static/assets/index-DAPJxlCw.js +101 -0
  56. phoenix/server/static/assets/{pages-CnTvEGEN.js → pages-1VrMk2pW.js} +559 -291
  57. phoenix/server/static/assets/{vendor-BC3OPQuM.js → vendor-B5IC0ivG.js} +5 -5
  58. phoenix/server/static/assets/{vendor-arizeai-NjB3cZzD.js → vendor-arizeai-aFbT4kl1.js} +2 -2
  59. phoenix/server/static/assets/{vendor-codemirror-gE_JCOgX.js → vendor-codemirror-BEGorXSV.js} +1 -1
  60. phoenix/server/static/assets/{vendor-recharts-BXLYwcXF.js → vendor-recharts-6nUU7gU_.js} +1 -1
  61. phoenix/server/templates/index.html +1 -0
  62. phoenix/server/types.py +157 -1
  63. phoenix/session/client.py +7 -2
  64. phoenix/trace/fixtures.py +24 -0
  65. phoenix/utilities/client.py +16 -0
  66. phoenix/version.py +1 -1
  67. phoenix/db/migrations/future_versions/README.md +0 -4
  68. phoenix/db/migrations/future_versions/cd164e83824f_users_and_tokens.py +0 -293
  69. phoenix/db/migrations/versions/.gitignore +0 -1
  70. phoenix/server/api/mutations/auth.py +0 -18
  71. phoenix/server/api/mutations/auth_mutations.py +0 -65
  72. phoenix/server/static/assets/index-fq1-hCK4.js +0 -100
  73. phoenix/trace/langchain/__init__.py +0 -3
  74. phoenix/trace/langchain/instrumentor.py +0 -34
  75. phoenix/trace/llama_index/__init__.py +0 -3
  76. phoenix/trace/llama_index/callback.py +0 -102
  77. phoenix/trace/openai/__init__.py +0 -3
  78. phoenix/trace/openai/instrumentor.py +0 -30
  79. {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.1.0.dist-info}/WHEEL +0 -0
  80. {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.1.0.dist-info}/licenses/IP_NOTICE +0 -0
  81. {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,456 @@
1
+ import re
2
+ from dataclasses import dataclass
3
+ from datetime import timedelta
4
+ from random import randrange
5
+ from typing import Any, Dict, Optional, Tuple, TypedDict
6
+ from urllib.parse import unquote, urlparse
7
+
8
+ from authlib.common.security import generate_token
9
+ from authlib.integrations.starlette_client import OAuthError
10
+ from authlib.jose import jwt
11
+ from authlib.jose.errors import JoseError
12
+ from fastapi import APIRouter, Cookie, Depends, Path, Query, Request
13
+ from sqlalchemy import Boolean, and_, case, cast, func, insert, or_, select, update
14
+ from sqlalchemy.ext.asyncio import AsyncSession
15
+ from sqlalchemy.orm import joinedload
16
+ from sqlean.dbapi2 import IntegrityError # type: ignore[import-untyped]
17
+ from starlette.datastructures import URL
18
+ from starlette.responses import RedirectResponse
19
+ from starlette.routing import Router
20
+ from starlette.status import HTTP_302_FOUND
21
+ from typing_extensions import Annotated, NotRequired, TypeGuard
22
+
23
+ from phoenix.auth import (
24
+ DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES,
25
+ PHOENIX_OAUTH2_NONCE_COOKIE_NAME,
26
+ PHOENIX_OAUTH2_STATE_COOKIE_NAME,
27
+ delete_oauth2_nonce_cookie,
28
+ delete_oauth2_state_cookie,
29
+ set_access_token_cookie,
30
+ set_oauth2_nonce_cookie,
31
+ set_oauth2_state_cookie,
32
+ set_refresh_token_cookie,
33
+ )
34
+ from phoenix.config import get_env_disable_rate_limit
35
+ from phoenix.db import models
36
+ from phoenix.db.enums import UserRole
37
+ from phoenix.server.bearer_auth import create_access_and_refresh_tokens
38
+ from phoenix.server.oauth2 import OAuth2Client
39
+ from phoenix.server.rate_limiters import (
40
+ ServerRateLimiter,
41
+ fastapi_ip_rate_limiter,
42
+ fastapi_route_rate_limiter,
43
+ )
44
+ from phoenix.server.types import TokenStore
45
+
46
# Full-match pattern for IDP names received as a path parameter. The explicit
# `^...$` anchors are required: pydantic v2 (used by FastAPI for
# `Path(pattern=...)`) checks `pattern` with unanchored search semantics, so
# without anchors any name containing a single lowercase alphanumeric or
# underscore character would be accepted.
_LOWERCASE_ALPHANUMS_AND_UNDERSCORES = r"^[a-z0-9_]+$"
47
+
48
# Rate limiter for the login route, built with `fastapi_ip_rate_limiter`
# (presumably keyed on the client IP -- confirm in phoenix.server.rate_limiters).
login_rate_limiter = fastapi_ip_rate_limiter(
    ServerRateLimiter(
        per_second_rate_limit=0.2,
        enforcement_window_seconds=30,
        partition_seconds=60,
        active_partitions=2,
    ),
)

# Rate limiter for the token-creation (IDP callback) route, built with
# `fastapi_route_rate_limiter`.
create_tokens_rate_limiter = fastapi_route_rate_limiter(
    ServerRateLimiter(
        per_second_rate_limit=0.5,
        enforcement_window_seconds=30,
        partition_seconds=60,
        active_partitions=2,
    )
)

# Every route in this module lives under `/oauth2` and is hidden from the
# OpenAPI schema.
router = APIRouter(prefix="/oauth2", include_in_schema=False)

# The limiters are attached as route dependencies unless rate limiting has been
# disabled through configuration.
if not get_env_disable_rate_limit():
    login_dependencies = [Depends(login_rate_limiter)]
    create_tokens_dependencies = [Depends(create_tokens_rate_limiter)]
else:
    login_dependencies = []
    create_tokens_dependencies = []
77
+
78
+
79
@router.post("/{idp_name}/login", dependencies=login_dependencies)
async def login(
    request: Request,
    idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)],
    return_url: Optional[str] = Query(default=None, alias="returnUrl"),
) -> RedirectResponse:
    """
    Starts the OAuth2 authorization code flow with the given IDP.

    Responds with a 302 redirect to the IDP's authorization URL. The `state`
    and `nonce` produced for this request are stored in cookies that expire
    after `DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES`, so the `create_tokens`
    callback can check them. An unknown IDP name is sent back to the login
    page with an error message.
    """
    secret = request.app.state.get_secret()
    oauth2_client = request.app.state.oauth2_clients.get_client(idp_name)
    if not isinstance(oauth2_client, OAuth2Client):
        return _redirect_to_login(error=f"Unknown IDP: {idp_name}.")
    origin_url = _get_origin_url(request)
    authorization_url_data = await oauth2_client.create_authorization_url(
        redirect_uri=_get_create_tokens_endpoint(
            request=request, origin_url=origin_url, idp_name=idp_name
        ),
        state=_generate_state_for_oauth2_authorization_code_flow(
            secret=secret, origin_url=origin_url, return_url=return_url
        ),
    )
    # Bind the values *before* asserting: an assignment expression inside an
    # `assert` is skipped entirely under `python -O`, which would leave these
    # names undefined.
    authorization_url = authorization_url_data.get("url")
    state = authorization_url_data.get("state")
    nonce = authorization_url_data.get("nonce")
    assert isinstance(authorization_url, str)
    assert isinstance(state, str)
    assert isinstance(nonce, str)
    cookie_max_age = timedelta(minutes=DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES)
    response = RedirectResponse(url=authorization_url, status_code=HTTP_302_FOUND)
    response = set_oauth2_state_cookie(response=response, state=state, max_age=cookie_max_age)
    response = set_oauth2_nonce_cookie(response=response, nonce=nonce, max_age=cookie_max_age)
    return response
114
+
115
+
116
@router.get("/{idp_name}/tokens", dependencies=create_tokens_dependencies)
async def create_tokens(
    request: Request,
    idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)],
    state: str = Query(),
    authorization_code: str = Query(alias="code"),
    stored_state: str = Cookie(alias=PHOENIX_OAUTH2_STATE_COOKIE_NAME),
    stored_nonce: str = Cookie(alias=PHOENIX_OAUTH2_NONCE_COOKIE_NAME),
) -> RedirectResponse:
    """
    Handles the IDP callback of the OAuth2 authorization code flow.

    Checks `state` against its cookie and verifies its signature, then:
    exchanges the authorization code for tokens, parses the OpenID Connect ID
    token (checking the stored nonce), and creates or updates the matching
    user. Finally it redirects to the return URL (or "/") with Phoenix access
    and refresh token cookies set.

    Expected failures are redirected to the login page with an error message:
    bad or mismatched state, unsafe return URL, unknown IDP, OAuth errors,
    invalid token response or ID token, and email conflicts.
    """
    secret = request.app.state.get_secret()
    # The constant is a template with an `{idp_name}` placeholder.
    invalid_state_message = _INVALID_OAUTH2_STATE_MESSAGE.format(idp_name=idp_name)
    if state != stored_state:
        return _redirect_to_login(error=invalid_state_message)
    try:
        payload = _parse_state_payload(secret=secret, state=state)
    except (JoseError, ValueError):
        # JoseError: malformed JWT or bad signature. ValueError: a correctly
        # signed JWT whose payload does not have the expected shape.
        return _redirect_to_login(error=invalid_state_message)
    if (return_url := payload.get("return_url")) is not None and not _is_relative_url(
        unquote(return_url)
    ):
        return _redirect_to_login(error="Attempting login with unsafe return URL.")
    # Bind before asserting so the names exist even when asserts are stripped (-O).
    access_token_expiry = request.app.state.access_token_expiry
    refresh_token_expiry = request.app.state.refresh_token_expiry
    assert isinstance(access_token_expiry, timedelta)
    assert isinstance(refresh_token_expiry, timedelta)
    token_store: TokenStore = request.app.state.get_token_store()
    oauth2_client = request.app.state.oauth2_clients.get_client(idp_name)
    if not isinstance(oauth2_client, OAuth2Client):
        return _redirect_to_login(error=f"Unknown IDP: {idp_name}.")
    try:
        token_data = await oauth2_client.fetch_access_token(
            state=state,
            code=authorization_code,
            redirect_uri=_get_create_tokens_endpoint(
                request=request, origin_url=payload["origin_url"], idp_name=idp_name
            ),
        )
    except OAuthError as error:
        return _redirect_to_login(error=str(error))
    try:
        _validate_token_data(token_data)
    except ValueError as error:
        return _redirect_to_login(error=str(error))
    if "id_token" not in token_data:
        return _redirect_to_login(
            error=f"OAuth2 IDP {idp_name} does not appear to support OpenID Connect."
        )
    try:
        # Raises a JoseError subclass on e.g. a nonce mismatch or a bad signature.
        id_token_claims = await oauth2_client.parse_id_token(token_data, nonce=stored_nonce)
    except JoseError:
        return _redirect_to_login(error=f"Received invalid ID token from OAuth2 IDP {idp_name}.")
    try:
        user_info = _parse_user_info(id_token_claims)
    except ValueError as error:
        return _redirect_to_login(error=str(error))
    try:
        async with request.app.state.db() as session:
            user = await _ensure_user_exists_and_is_up_to_date(
                session,
                oauth2_client_id=str(oauth2_client.client_id),
                user_info=user_info,
            )
    except EmailAlreadyInUse as error:
        return _redirect_to_login(error=str(error))
    access_token, refresh_token = await create_access_and_refresh_tokens(
        user=user,
        token_store=token_store,
        access_token_expiry=access_token_expiry,
        refresh_token_expiry=refresh_token_expiry,
    )
    response = RedirectResponse(url=return_url or "/", status_code=HTTP_302_FOUND)
    response = set_access_token_cookie(
        response=response, access_token=access_token, max_age=access_token_expiry
    )
    response = set_refresh_token_cookie(
        response=response, refresh_token=refresh_token, max_age=refresh_token_expiry
    )
    response = delete_oauth2_state_cookie(response)
    response = delete_oauth2_nonce_cookie(response)
    return response
185
+
186
+
187
@dataclass
class UserInfo:
    """
    Identity attributes extracted from an IDP's ID token claims.
    """

    # Stringified `sub` claim.
    idp_user_id: str
    # `email` claim.
    email: str
    # `name` claim, when present.
    username: Optional[str]
    # `picture` claim, when present.
    profile_picture_url: Optional[str]
193
+
194
+
195
def _validate_token_data(token_data: Dict[str, Any]) -> None:
    """
    Performs basic validations on the token data returned by the IDP.

    Raises:
        ValueError: if `access_token` is missing or not a string, or if
            `token_type` is missing, not a string, or not "bearer"
            (case-insensitive).
    """
    # Explicit raises rather than `assert`: this is external input, and asserts
    # are stripped when Python runs with -O.
    if not isinstance(token_data.get("access_token"), str):
        raise ValueError("OAuth2 token response is missing a valid access token.")
    token_type = token_data.get("token_type")
    if not isinstance(token_type, str):
        raise ValueError("OAuth2 token response is missing a valid token type.")
    if token_type.lower() != "bearer":
        raise ValueError(f"Unsupported OAuth2 token type: {token_type}.")
202
+
203
+
204
def _parse_user_info(user_info: Dict[str, Any]) -> UserInfo:
    """
    Parses user info from the IDP's ID token claims.

    Raises:
        ValueError: if `sub` is missing or not a str/int, `email` is missing
            or not a string, or `name` / `picture` are present but not strings.
    """
    # Explicit raises rather than `assert`: the claims come from an external
    # IDP, and asserts are stripped when Python runs with -O.
    subject = user_info.get("sub")
    if not isinstance(subject, (str, int)):
        raise ValueError("ID token is missing a valid `sub` claim.")
    email = user_info.get("email")
    if not isinstance(email, str):
        raise ValueError("ID token is missing a valid `email` claim.")
    username = user_info.get("name")
    if username is not None and not isinstance(username, str):
        raise ValueError("ID token has an invalid `name` claim.")
    profile_picture_url = user_info.get("picture")
    if profile_picture_url is not None and not isinstance(profile_picture_url, str):
        raise ValueError("ID token has an invalid `picture` claim.")
    return UserInfo(
        idp_user_id=str(subject),
        email=email,
        username=username,
        profile_picture_url=profile_picture_url,
    )
222
+
223
+
224
async def _ensure_user_exists_and_is_up_to_date(
    session: AsyncSession, /, *, oauth2_client_id: str, user_info: UserInfo
) -> models.User:
    """
    Returns the user linked to this IDP identity.

    Creates the user on first login. If the IDP reports a different email
    than the one stored, the stored email is updated. Other attributes are
    left untouched.

    Raises:
        EmailAlreadyInUse: propagated from the create or update helpers.
    """
    existing_user = await _get_user(
        session,
        oauth2_client_id=oauth2_client_id,
        idp_user_id=user_info.idp_user_id,
    )
    if existing_user is None:
        return await _create_user(session, oauth2_client_id=oauth2_client_id, user_info=user_info)
    if existing_user.email == user_info.email:
        return existing_user
    return await _update_user_email(session, user_id=existing_user.id, email=user_info.email)
237
+
238
+
239
async def _get_user(
    session: AsyncSession, /, *, oauth2_client_id: str, idp_user_id: str
) -> Optional[models.User]:
    """
    Looks up the user identified by the (OAuth2 client ID, IDP user ID) pair,
    with its role eagerly loaded; returns None when there is no such user.
    """
    # Successive `.where()` calls are AND-ed together.
    statement = (
        select(models.User)
        .where(models.User.oauth2_client_id == oauth2_client_id)
        .where(models.User.oauth2_user_id == idp_user_id)
        .options(joinedload(models.User.role))
    )
    return await session.scalar(statement)
257
+
258
+
259
async def _create_user(
    session: AsyncSession,
    /,
    *,
    oauth2_client_id: str,
    user_info: UserInfo,
) -> models.User:
    """
    Inserts a new MEMBER user built from the IDP user info and returns it
    with its role eagerly loaded.

    A username that is already taken gets a random numeric suffix instead of
    failing the login.

    Raises:
        EmailAlreadyInUse: if another user already has the same email.
    """
    email = user_info.email
    username = user_info.username
    email_taken, username_taken = await _email_and_username_exist(
        session, email=email, username=username
    )
    if email_taken:
        raise EmailAlreadyInUse(f"An account for {email} is already in use.")
    # Resolved inside the INSERT statement via a scalar subquery.
    member_role_id = (
        select(models.UserRole.id)
        .where(models.UserRole.name == UserRole.MEMBER.value)
        .scalar_subquery()
    )
    new_user_id = await session.scalar(
        insert(models.User)
        .values(
            user_role_id=member_role_id,
            oauth2_client_id=oauth2_client_id,
            oauth2_user_id=user_info.idp_user_id,
            username=_with_random_suffix(username) if username and username_taken else username,
            email=email,
            profile_picture_url=user_info.profile_picture_url,
            reset_password=False,
        )
        .returning(models.User.id)
    )
    assert isinstance(new_user_id, int)
    # Re-select so the returned user has its role eagerly loaded.
    created_user = await session.scalar(
        select(models.User)
        .where(models.User.id == new_user_id)
        .options(joinedload(models.User.role))
    )
    assert isinstance(created_user, models.User)
    return created_user
300
+
301
+
302
async def _update_user_email(session: AsyncSession, /, *, user_id: int, email: str) -> models.User:
    """
    Updates an existing user's email and returns the user with its role
    eagerly loaded.

    Raises:
        EmailAlreadyInUse: if the database rejects the update with an
            integrity error (presumably the email uniqueness constraint).
    """
    # SQLAlchemy re-raises driver errors wrapped in `sqlalchemy.exc.IntegrityError`,
    # so catching only the driver-level (sqlean) class never fires. Catch both.
    from sqlalchemy.exc import IntegrityError as SQLAlchemyIntegrityError

    try:
        # Loader options only matter when an UPDATE returns ORM entities, which
        # this one does not; the role is loaded by the re-select below.
        await session.execute(
            update(models.User).where(models.User.id == user_id).values(email=email)
        )
    except (SQLAlchemyIntegrityError, IntegrityError) as error:
        raise EmailAlreadyInUse(f"An account for {email} is already in use.") from error
    user = await session.scalar(
        select(models.User).where(models.User.id == user_id).options(joinedload(models.User.role))
    )  # query user again for joined load
    assert isinstance(user, models.User)
    return user
320
+
321
+
322
async def _email_and_username_exist(
    session: AsyncSession, /, *, email: str, username: Optional[str]
) -> Tuple[bool, bool]:
    """
    Reports, using a single query, whether any user already has `email` and
    whether any user already has `username`.

    Returns:
        `(email_exists, username_exists)`.
    """

    def _matches_any_row(condition: Any) -> Any:
        # MAX over a 1/0 CASE is 1 iff some filtered row satisfies `condition`;
        # COALESCE turns the NULL produced by an empty row set into 0.
        return cast(func.coalesce(func.max(case((condition, 1), else_=0)), 0), Boolean)

    email_matches = models.User.email == email
    username_matches = models.User.username == username
    query = select(
        _matches_any_row(email_matches).label("email_exists"),
        _matches_any_row(username_matches).label("username_exists"),
    ).where(or_(email_matches, username_matches))
    rows = (await session.execute(query)).all()
    [(email_exists, username_exists)] = rows
    return email_exists, username_exists
349
+
350
+
351
class EmailAlreadyInUse(Exception):
    """
    Raised when an OAuth2 login would create a user, or change a user's
    email, to an address another account already uses.
    """
353
+
354
+
355
def _redirect_to_login(*, error: str) -> RedirectResponse:
    """
    Creates a redirect to the login page that displays `error` and clears
    the OAuth2 state and nonce cookies.

    Uses 302, like the module's other redirects, instead of Starlette's
    default 307. `login` is a POST route, and a 307 would make the browser
    re-submit that POST to `/login`.
    """
    url = URL("/login").include_query_params(error=error)
    response = RedirectResponse(url=url, status_code=HTTP_302_FOUND)
    response = delete_oauth2_state_cookie(response)
    response = delete_oauth2_nonce_cookie(response)
    return response
364
+
365
+
366
def _get_create_tokens_endpoint(*, request: Request, origin_url: str, idp_name: str) -> str:
    """
    Builds the absolute URL of the `create_tokens` route for `idp_name`,
    rooted at `origin_url`. Used as the OAuth2 `redirect_uri`.
    """
    app_router: Router = request.scope["router"]
    tokens_path = app_router.url_path_for(create_tokens.__name__, idp_name=idp_name)
    absolute_url = tokens_path.make_absolute_url(base_url=origin_url)
    return str(absolute_url)
373
+
374
+
375
def _generate_state_for_oauth2_authorization_code_flow(
    *, secret: str, origin_url: str, return_url: Optional[str]
) -> str:
    """
    Produces the OAuth2 `state` value as a JWT signed with `secret`.

    The payload holds a random token (from authlib's `generate_token`), the
    origin URL and, when provided, the return URL. The callback can then
    recover them from the `state` query parameter without server-side storage.
    """
    state_payload = _OAuth2StatePayload(random=generate_token(), origin_url=origin_url)
    if return_url is not None:
        state_payload["return_url"] = return_url
    encoded: bytes = jwt.encode(
        header={"alg": _JWT_ALGORITHM}, payload=state_payload, key=secret
    )
    return encoded.decode()
394
+
395
+
396
class _OAuth2StatePayload(TypedDict):
    """
    Payload of the signed JWT carried in the OAuth2 `state` parameter.
    """

    # Random token from `authlib.common.security.generate_token`.
    random: str
    # Scheme and host the login started from; used to rebuild `redirect_uri`.
    origin_url: str
    # Where to send the user after login, when one was requested.
    return_url: NotRequired[str]
404
+
405
+
406
def _parse_state_payload(*, secret: str, state: str) -> _OAuth2StatePayload:
    """
    Verifies the signature of the OAuth2 state JWT and returns its payload.

    Only `_JWT_ALGORITHM`, the algorithm the state is signed with, is
    accepted. A token declaring any other `alg` is rejected by authlib with
    a JoseError, instead of being verified with an unexpected algorithm and
    key type.

    Raises:
        JoseError: if the JWT is malformed, uses another algorithm, or has an
            invalid signature.
        ValueError: if the payload does not have the expected shape.
    """
    from authlib.jose import JsonWebToken

    payload = JsonWebToken([_JWT_ALGORITHM]).decode(s=state, key=secret)
    if _is_oauth2_state_payload(payload):
        return payload
    raise ValueError("Invalid OAuth2 state payload.")
414
+
415
+
416
def _is_relative_url(url: str) -> bool:
    """
    Returns True when `url` matches `_RELATIVE_URL_PATTERN`.
    """
    return _RELATIVE_URL_PATTERN.match(url) is not None
421
+
422
+
423
def _with_random_suffix(string: str) -> str:
    """
    Returns `string` followed by a hyphen and a random integer in
    [10000, 100000), i.e. a five-digit number.
    """
    return "-".join((string, str(randrange(10_000, 100_000))))
428
+
429
+
430
def _get_origin_url(request: Request) -> str:
    """
    Infers the origin URL (scheme and host) from the request.

    Uses the `Referer` header when it holds an absolute URL. Otherwise
    (header absent, empty, relative, or garbled) falls back to the request's
    base URL.
    """
    referer = request.headers.get("referer")
    if referer:
        parsed_url = urlparse(referer)
        # Without both parts the result would be a broken origin such as "://".
        if parsed_url.scheme and parsed_url.netloc:
            return f"{parsed_url.scheme}://{parsed_url.netloc}"
    return str(request.base_url)
438
+
439
+
440
def _is_oauth2_state_payload(maybe_state_payload: Any) -> TypeGuard[_OAuth2StatePayload]:
    """
    Determines whether the given object is an OAuth2 state payload.

    That is a dict with the required `random` and `origin_url` keys,
    optionally `return_url`, no other keys, and only string values.
    """
    if not isinstance(maybe_state_payload, dict):
        return False
    keys = set(maybe_state_payload.keys())
    if not {"random", "origin_url"} <= keys <= {"random", "origin_url", "return_url"}:
        return False
    # The TypeGuard promises `str` fields, so the values must be checked too.
    return all(isinstance(value, str) for value in maybe_state_payload.values())
450
+
451
+
452
# Algorithm used to sign the OAuth2 state JWT.
_JWT_ALGORITHM = "HS256"

# Error-message template; contains an `{idp_name}` placeholder.
_INVALID_OAUTH2_STATE_MESSAGE = (
    "Received invalid state parameter during OAuth2 authorization code flow for IDP {idp_name}."
)

# Matches "/" on its own, or a string whose second character is a word
# character. This rejects e.g. protocol-relative "//host" URLs.
_RELATIVE_URL_PATTERN = re.compile(r"^/($|\w)")
@@ -1,6 +1,9 @@
1
1
  from fastapi import APIRouter, Depends, HTTPException, Request
2
+ from fastapi.security import APIKeyHeader
2
3
  from starlette.status import HTTP_403_FORBIDDEN
3
4
 
5
+ from phoenix.server.bearer_auth import is_authenticated
6
+
4
7
  from .datasets import router as datasets_router
5
8
  from .evaluations import router as evaluations_router
6
9
  from .experiment_evaluations import router as experiment_evaluations_router
@@ -24,19 +27,38 @@ async def prevent_access_in_read_only_mode(request: Request) -> None:
24
27
  )
25
28
 
26
29
 
27
- router = APIRouter(
28
- prefix="/v1",
29
- dependencies=[Depends(prevent_access_in_read_only_mode)],
30
- responses=add_errors_to_responses(
31
- [
32
- HTTP_403_FORBIDDEN # adds a 403 response to each route in the generated OpenAPI schema
33
- ]
34
- ),
35
- )
36
- router.include_router(datasets_router)
37
- router.include_router(experiments_router)
38
- router.include_router(experiment_runs_router)
39
- router.include_router(experiment_evaluations_router)
40
- router.include_router(traces_router)
41
- router.include_router(spans_router)
42
- router.include_router(evaluations_router)
30
def create_v1_router(authentication_enabled: bool) -> APIRouter:
    """
    Builds the `/v1` REST API router.

    Every route depends on `prevent_access_in_read_only_mode`. When
    `authentication_enabled` is true, the router additionally declares an
    `Authorization` API-key header (with scheme name "Bearer" in the OpenAPI
    schema) and runs `is_authenticated` as a dependency.
    """
    dependencies = [Depends(prevent_access_in_read_only_mode)]
    if authentication_enabled:
        bearer_header = APIKeyHeader(
            name="Authorization",
            scheme_name="Bearer",
            auto_error=False,
            description="Enter `Bearer` followed by a space and then the token.",
        )
        dependencies += [Depends(bearer_header), Depends(is_authenticated)]

    router = APIRouter(
        prefix="/v1",
        dependencies=dependencies,
        # adds a 403 response to routes in the generated OpenAPI schema
        responses=add_errors_to_responses([HTTP_403_FORBIDDEN]),
    )
    for sub_router in (
        datasets_router,
        experiments_router,
        experiment_runs_router,
        experiment_evaluations_router,
        traces_router,
        spans_router,
        evaluations_router,
    ):
        router.include_router(sub_router)
    return router
@@ -3,6 +3,8 @@ from typing import Optional
3
3
 
4
4
  import strawberry
5
5
 
6
+ from phoenix.db.models import ApiKey as ORMApiKey
7
+
6
8
 
7
9
  @strawberry.interface
8
10
  class ApiKey:
@@ -14,3 +16,12 @@ class ApiKey:
14
16
  expires_at: Optional[datetime] = strawberry.field(
15
17
  description="The date and time the API key will expire."
16
18
  )
19
+
20
+
21
def to_gql_api_key(api_key: ORMApiKey) -> ApiKey:
    """
    Converts an ORM API key into the GraphQL `ApiKey` type, copying its name,
    description, creation time and expiry.
    """
    # NOTE(review): UserApiKey.py defines a converter with the same name that
    # returns `UserApiKey`; confirm which one callers are meant to use.
    name, description = api_key.name, api_key.description
    created_at, expires_at = api_key.created_at, api_key.expires_at
    return ApiKey(
        name=name,
        description=description,
        created_at=created_at,
        expires_at=expires_at,
    )
@@ -0,0 +1,9 @@
1
+ from enum import Enum
2
+
3
+ import strawberry
4
+
5
+
6
@strawberry.enum
class AuthMethod(Enum):
    # Exposed as `User.auth_method`; each member's value equals its name.
    LOCAL = "LOCAL"
    OAUTH2 = "OAUTH2"
@@ -1,16 +1,60 @@
1
1
  from datetime import datetime
2
- from typing import Optional
2
+ from typing import List, Optional
3
3
 
4
4
  import strawberry
5
+ from sqlalchemy import select
6
+ from strawberry import Private
5
7
  from strawberry.relay import Node, NodeID
8
+ from strawberry.types import Info
6
9
 
7
- from .UserRole import UserRole
10
+ from phoenix.db import models
11
+ from phoenix.server.api.context import Context
12
+ from phoenix.server.api.exceptions import NotFound
13
+ from phoenix.server.api.types.AuthMethod import AuthMethod
14
+ from phoenix.server.api.types.UserApiKey import UserApiKey, to_gql_api_key
15
+
16
+ from .UserRole import UserRole, to_gql_user_role
8
17
 
9
18
 
10
19
@strawberry.type
class User(Node):
    # Relay node id: the ORM user's integer id.
    id_attr: NodeID[int]
    # Populated from the ORM user's `reset_password` flag.
    password_needs_reset: bool
    email: str
    username: str
    profile_picture_url: Optional[str]
    created_at: datetime
    # Private (not exposed in the schema); consumed by the `role` resolver.
    user_role_id: Private[int]
    auth_method: AuthMethod

    @strawberry.field
    async def role(self, info: Info[Context, None]) -> UserRole:
        # Loaded through the `user_roles` data loader.
        orm_role = await info.context.data_loaders.user_roles.load(self.user_role_id)
        if orm_role is not None:
            return to_gql_user_role(orm_role)
        raise NotFound(f"User role with id {self.user_role_id} not found")

    @strawberry.field
    async def api_keys(self, info: Info[Context, None]) -> List[UserApiKey]:
        # Direct query (not a data loader): one query per resolved user.
        keys_query = select(models.ApiKey).where(models.ApiKey.user_id == self.id_attr)
        async with info.context.db() as session:
            orm_api_keys = await session.scalars(keys_query)
            return [to_gql_api_key(orm_api_key) for orm_api_key in orm_api_keys]
44
+
45
+
46
def to_gql_user(user: models.User, api_keys: Optional[List[models.ApiKey]] = None) -> User:
    """
    Converts an ORM user to a GraphQL user.

    `api_keys` is accepted but not used; the `User.api_keys` field queries the
    keys on demand.
    """
    auth_method = user.auth_method
    assert auth_method is not None
    return User(
        id_attr=user.id,
        password_needs_reset=user.reset_password,
        username=user.username,
        email=user.email,
        profile_picture_url=user.profile_picture_url,
        created_at=user.created_at,
        user_role_id=user.user_role_id,
        auth_method=AuthMethod(auth_method),
    )
@@ -1,11 +1,45 @@
1
+ from typing import TYPE_CHECKING
2
+
1
3
  import strawberry
2
4
  from strawberry import Private
3
- from strawberry.relay.types import Node, NodeID
5
+ from strawberry.relay import Node, NodeID
6
+ from strawberry.types import Info
7
+ from typing_extensions import Annotated
8
+
9
+ from phoenix.db.models import ApiKey as OrmApiKey
10
+ from phoenix.server.api.context import Context
11
+ from phoenix.server.api.exceptions import NotFound
4
12
 
5
13
  from .ApiKey import ApiKey
6
14
 
15
+ if TYPE_CHECKING:
16
+ from .User import User
17
+
7
18
 
8
19
@strawberry.type
class UserApiKey(ApiKey, Node):
    # Relay node id: the ORM API key's integer id.
    id_attr: NodeID[int]
    # Owning user's id; Private, so it is not exposed and only feeds `user`.
    user_id: Private[int]

    @strawberry.field
    async def user(self, info: Info[Context, None]) -> Annotated["User", strawberry.lazy(".User")]:
        # Deferred import: `.User` imports this module, so importing it at the
        # top level would be circular.
        from .User import to_gql_user

        orm_user = await info.context.data_loaders.users.load(self.user_id)
        if orm_user is None:
            raise NotFound(f"User with id {self.user_id} not found")
        return to_gql_user(orm_user)
32
+
33
+
34
def to_gql_api_key(api_key: OrmApiKey) -> UserApiKey:
    """
    Converts an ORM API key into a GraphQL `UserApiKey`.

    The owner's id is carried along as a private field so that `user` can be
    resolved on demand.
    """
    common_fields = {
        "name": api_key.name,
        "description": api_key.description,
        "created_at": api_key.created_at,
        "expires_at": api_key.expires_at,
    }
    return UserApiKey(id_attr=api_key.id, user_id=api_key.user_id, **common_fields)
@@ -1,8 +1,15 @@
1
1
  import strawberry
2
2
  from strawberry.relay import Node, NodeID
3
3
 
4
+ from phoenix.db import models
5
+
4
6
 
5
7
@strawberry.type
class UserRole(Node):
    # Relay node id: the ORM role's integer id.
    id_attr: NodeID[int]
    # Role name as stored on the ORM role.
    name: str
11
+
12
+
13
def to_gql_user_role(role: models.UserRole) -> UserRole:
    """
    Converts an ORM user role to a GraphQL user role, carrying over its id
    and name.
    """
    role_id, role_name = role.id, role.name
    return UserRole(id_attr=role_id, name=role_name)