arize-phoenix 9.6.1__py3-none-any.whl → 10.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. See this release's page on the package registry for more details.

Files changed (33)
  1. {arize_phoenix-9.6.1.dist-info → arize_phoenix-10.0.1.dist-info}/METADATA +1 -1
  2. {arize_phoenix-9.6.1.dist-info → arize_phoenix-10.0.1.dist-info}/RECORD +33 -32
  3. phoenix/auth.py +6 -4
  4. phoenix/config.py +150 -25
  5. phoenix/db/README.md +38 -1
  6. phoenix/db/enums.py +1 -2
  7. phoenix/db/facilitator.py +20 -11
  8. phoenix/db/migrations/versions/6a88424799fe_update_users_with_auth_method.py +179 -0
  9. phoenix/db/models.py +66 -37
  10. phoenix/server/api/context.py +5 -4
  11. phoenix/server/api/mutations/user_mutations.py +58 -26
  12. phoenix/server/api/routers/auth.py +16 -4
  13. phoenix/server/api/routers/oauth2.py +196 -15
  14. phoenix/server/app.py +36 -10
  15. phoenix/server/bearer_auth.py +5 -7
  16. phoenix/server/jwt_store.py +5 -4
  17. phoenix/server/main.py +11 -4
  18. phoenix/server/oauth2.py +47 -3
  19. phoenix/server/static/.vite/manifest.json +44 -44
  20. phoenix/server/static/assets/{components-CDvTuTqd.js → components-D6QBwbkV.js} +274 -241
  21. phoenix/server/static/assets/{index-DpcxdHu4.js → index-D4fytZZJ.js} +11 -11
  22. phoenix/server/static/assets/{pages-Bcs41-Zv.js → pages-Bg98duFI.js} +350 -381
  23. phoenix/server/static/assets/{vendor-arizeai-BhbMHqQs.js → vendor-arizeai-Dy-0mSNw.js} +1 -1
  24. phoenix/server/static/assets/{vendor-codemirror-CeLHFooz.js → vendor-codemirror-DBtifKNr.js} +1 -1
  25. phoenix/server/static/assets/{vendor-CToBXdDM.js → vendor-oB4u9zuV.js} +11 -11
  26. phoenix/server/static/assets/{vendor-recharts-PlWJHgM9.js → vendor-recharts-D-T4KPz2.js} +1 -1
  27. phoenix/server/static/assets/{vendor-shiki-CPwL2jwA.js → vendor-shiki-BMn4O_9F.js} +1 -1
  28. phoenix/server/templates/index.html +1 -0
  29. phoenix/version.py +1 -1
  30. {arize_phoenix-9.6.1.dist-info → arize_phoenix-10.0.1.dist-info}/WHEEL +0 -0
  31. {arize_phoenix-9.6.1.dist-info → arize_phoenix-10.0.1.dist-info}/entry_points.txt +0 -0
  32. {arize_phoenix-9.6.1.dist-info → arize_phoenix-10.0.1.dist-info}/licenses/IP_NOTICE +0 -0
  33. {arize_phoenix-9.6.1.dist-info → arize_phoenix-10.0.1.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/routers/oauth2.py CHANGED
@@ -15,7 +15,7 @@ from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
15
15
  from sqlalchemy.ext.asyncio import AsyncSession
16
16
  from sqlalchemy.orm import joinedload
17
17
  from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
18
- from starlette.datastructures import URL, URLPath
18
+ from starlette.datastructures import URL, Secret, URLPath
19
19
  from starlette.responses import RedirectResponse
20
20
  from starlette.routing import Router
21
21
  from starlette.status import HTTP_302_FOUND
@@ -32,7 +32,10 @@ from phoenix.auth import (
32
32
  set_oauth2_state_cookie,
33
33
  set_refresh_token_cookie,
34
34
  )
35
- from phoenix.config import get_env_disable_rate_limit
35
+ from phoenix.config import (
36
+ get_env_disable_basic_auth,
37
+ get_env_disable_rate_limit,
38
+ )
36
39
  from phoenix.db import models
37
40
  from phoenix.db.enums import UserRole
38
41
  from phoenix.server.bearer_auth import create_access_and_refresh_tokens
@@ -77,6 +80,7 @@ else:
77
80
  create_tokens_dependencies = []
78
81
 
79
82
 
83
+ @router.get("/{idp_name}/login", dependencies=login_dependencies)
80
84
  @router.post("/{idp_name}/login", dependencies=login_dependencies)
81
85
  async def login(
82
86
  request: Request,
@@ -169,12 +173,13 @@ async def create_tokens(
169
173
  user_info = _parse_user_info(user_info)
170
174
  try:
171
175
  async with request.app.state.db() as session:
172
- user = await _ensure_user_exists_and_is_up_to_date(
176
+ user = await _process_oauth2_user(
173
177
  session,
174
178
  oauth2_client_id=str(oauth2_client.client_id),
175
179
  user_info=user_info,
180
+ allow_sign_up=oauth2_client.allow_sign_up,
176
181
  )
177
- except EmailAlreadyInUse as error:
182
+ except (EmailAlreadyInUse, SignInNotAllowed) as error:
178
183
  return _redirect_to_login(request=request, error=str(error))
179
184
  access_token, refresh_token = await create_access_and_refresh_tokens(
180
185
  user=user,
@@ -198,12 +203,24 @@ async def create_tokens(
198
203
  return response
199
204
 
200
205
 
201
- @dataclass
206
+ @dataclass(frozen=True)
202
207
  class UserInfo:
203
208
  idp_user_id: str
204
209
  email: str
205
- username: Optional[str]
206
- profile_picture_url: Optional[str]
210
+ username: Optional[str] = None
211
+ profile_picture_url: Optional[str] = None
212
+
213
+ def __post_init__(self) -> None:
214
+ if not (idp_user_id := (self.idp_user_id or "").strip()):
215
+ raise ValueError("idp_user_id cannot be empty")
216
+ object.__setattr__(self, "idp_user_id", idp_user_id)
217
+ if not (email := (self.email or "").strip()):
218
+ raise ValueError("email cannot be empty")
219
+ object.__setattr__(self, "email", email)
220
+ if username := (self.username or "").strip():
221
+ object.__setattr__(self, "username", username)
222
+ if profile_picture_url := (self.profile_picture_url or "").strip():
223
+ object.__setattr__(self, "profile_picture_url", profile_picture_url)
207
224
 
208
225
 
209
226
  def _validate_token_data(token_data: dict[str, Any]) -> None:
@@ -235,9 +252,157 @@ def _parse_user_info(user_info: dict[str, Any]) -> UserInfo:
235
252
  )
236
253
 
237
254
 
238
- async def _ensure_user_exists_and_is_up_to_date(
239
- session: AsyncSession, /, *, oauth2_client_id: str, user_info: UserInfo
255
+ async def _process_oauth2_user(
256
+ session: AsyncSession,
257
+ /,
258
+ *,
259
+ oauth2_client_id: str,
260
+ user_info: UserInfo,
261
+ allow_sign_up: bool,
262
+ ) -> models.User:
263
+ """
264
+ Processes an OAuth2 user, either signing in an existing user or creating/updating one.
265
+
266
+ This function handles two main scenarios based on the allow_sign_up parameter:
267
+ 1. When sign-up is not allowed (allow_sign_up=False):
268
+ - Checks if the user exists and can sign in with the given OAuth2 credentials
269
+ - Updates placeholder OAuth2 credentials if needed (e.g., temporary IDs)
270
+ - If the user doesn't exist or has a password set, raises SignInNotAllowed
271
+ 2. When sign-up is allowed (allow_sign_up=True):
272
+ - Finds the user by OAuth2 credentials (client_id and user_id)
273
+ - Creates a new user if one doesn't exist, with default member role
274
+ - Updates the user's email if it has changed
275
+ - Handles username conflicts by adding a random suffix if needed
276
+
277
+ The allow_sign_up parameter is typically controlled by the PHOENIX_OAUTH2_{IDP_NAME}_ALLOW_SIGN_UP
278
+ environment variable for the specific identity provider.
279
+
280
+ Args:
281
+ session: The database session
282
+ oauth2_client_id: The ID of the OAuth2 client
283
+ user_info: User information from the OAuth2 provider
284
+ allow_sign_up: Whether to allow creating new users
285
+
286
+ Returns:
287
+ The user object
288
+
289
+ Raises:
290
+ SignInNotAllowed: When sign-in is not allowed for the user (user doesn't exist or has a password)
291
+ EmailAlreadyInUse: When the email is already in use by another account
292
+ """ # noqa: E501
293
+ if not allow_sign_up:
294
+ return await _get_existing_oauth2_user(
295
+ session,
296
+ oauth2_client_id=oauth2_client_id,
297
+ user_info=user_info,
298
+ )
299
+ return await _create_or_update_user(
300
+ session,
301
+ oauth2_client_id=oauth2_client_id,
302
+ user_info=user_info,
303
+ )
304
+
305
+
306
+ async def _get_existing_oauth2_user(
307
+ session: AsyncSession,
308
+ /,
309
+ *,
310
+ oauth2_client_id: str,
311
+ user_info: UserInfo,
240
312
  ) -> models.User:
313
+ """Signs in an existing user with OAuth2 credentials.
314
+
315
+ This function handles OAuth2 authentication for existing users. It follows a two-step process:
316
+
317
+ 1. First Attempt: Find user by OAuth2 credentials
318
+ - Searches for a user with matching oauth2_client_id and oauth2_user_id
319
+ - If found, updates email if it has changed from IDP info
320
+ - If not found, proceeds to step 2
321
+
322
+ 2. Second Attempt: Find user by email
323
+ - Searches for a user with matching email
324
+ - Verifies the user is an OAuth2 user (no password set)
325
+ - Handles OAuth2 credential updates in three cases:
326
+ a) Different OAuth2 client: Updates both client and user IDs
327
+ b) Same client but missing user ID: Sets the user ID
328
+ c) Same client but different user ID: Rejects sign-in
329
+
330
+ Profile Updates:
331
+ - Email: Updated if different from IDP info
332
+ - Profile Picture: Updated if provided in user_info
333
+ - Username: Never updated (remains unchanged)
334
+ - OAuth2 Credentials: Updated based on the three cases above
335
+
336
+ Args:
337
+ session: The database session
338
+ oauth2_client_id: The ID of the OAuth2 client
339
+ user_info: User information from the OAuth2 provider
340
+
341
+ Returns:
342
+ The signed-in user
343
+
344
+ Raises:
345
+ ValueError: If required fields (email, oauth2_user_id, oauth2_client_id) are empty
346
+ SignInNotAllowed: When sign-in is not allowed because:
347
+ - User doesn't exist
348
+ - User has a password set
349
+ - User has mismatched OAuth2 credentials
350
+ """ # noqa: E501
351
+ if not (email := (user_info.email or "").strip()):
352
+ raise ValueError("Email is required.")
353
+ if not (oauth2_user_id := (user_info.idp_user_id or "").strip()):
354
+ raise ValueError("OAuth2 user ID is required.")
355
+ if not (oauth2_client_id := (oauth2_client_id or "").strip()):
356
+ raise ValueError("OAuth2 client ID is required.")
357
+ profile_picture_url = (user_info.profile_picture_url or "").strip()
358
+ stmt = select(models.User).options(joinedload(models.User.role))
359
+ if user := await session.scalar(
360
+ stmt.filter_by(oauth2_client_id=oauth2_client_id, oauth2_user_id=oauth2_user_id)
361
+ ):
362
+ if email and email != user.email:
363
+ user.email = email
364
+ else:
365
+ user = await session.scalar(stmt.filter_by(email=email))
366
+ if user is None or not isinstance(user, models.OAuth2User):
367
+ raise SignInNotAllowed("Sign in is not allowed.")
368
+ # Case 1: Different OAuth2 client - update both client and user IDs
369
+ if oauth2_client_id != user.oauth2_client_id:
370
+ user.oauth2_client_id = oauth2_client_id
371
+ user.oauth2_user_id = oauth2_user_id
372
+ # Case 2: Same client but missing user ID - set the user ID
373
+ elif not user.oauth2_user_id:
374
+ user.oauth2_user_id = oauth2_user_id
375
+ # Case 3: Same client but different user ID - reject sign-in
376
+ elif oauth2_user_id != user.oauth2_user_id:
377
+ raise SignInNotAllowed("Sign in is not allowed.")
378
+ if profile_picture_url != user.profile_picture_url:
379
+ user.profile_picture_url = profile_picture_url
380
+ if user in session.dirty:
381
+ await session.flush()
382
+ return user
383
+
384
+
385
+ async def _create_or_update_user(
386
+ session: AsyncSession,
387
+ /,
388
+ *,
389
+ oauth2_client_id: str,
390
+ user_info: UserInfo,
391
+ ) -> models.User:
392
+ """
393
+ Creates a new user or updates an existing one with OAuth2 credentials.
394
+
395
+ Args:
396
+ session: The database session
397
+ oauth2_client_id: The ID of the OAuth2 client
398
+ user_info: User information from the OAuth2 provider
399
+
400
+ Returns:
401
+ The created or updated user
402
+
403
+ Raises:
404
+ EmailAlreadyInUse: When the email is already in use by another account
405
+ """
241
406
  user = await _get_user(
242
407
  session,
243
408
  oauth2_client_id=oauth2_client_id,
@@ -251,7 +416,11 @@ async def _ensure_user_exists_and_is_up_to_date(
251
416
 
252
417
 
253
418
  async def _get_user(
254
- session: AsyncSession, /, *, oauth2_client_id: str, idp_user_id: str
419
+ session: AsyncSession,
420
+ /,
421
+ *,
422
+ oauth2_client_id: str,
423
+ idp_user_id: str,
255
424
  ) -> Optional[models.User]:
256
425
  """
257
426
  Retrieves the user uniquely identified by the given OAuth2 client ID and IDP
@@ -303,6 +472,7 @@ async def _create_user(
303
472
  email=email,
304
473
  profile_picture_url=user_info.profile_picture_url,
305
474
  reset_password=False,
475
+ auth_method="OAUTH2",
306
476
  )
307
477
  )
308
478
  assert isinstance(user_id, int)
@@ -366,11 +536,22 @@ class EmailAlreadyInUse(Exception):
366
536
  pass
367
537
 
368
538
 
539
+ class SignInNotAllowed(Exception):
540
+ pass
541
+
542
+
543
+ class NotInvited(Exception):
544
+ pass
545
+
546
+
369
547
  def _redirect_to_login(*, request: Request, error: str) -> RedirectResponse:
370
548
  """
371
549
  Creates a RedirectResponse to the login page to display an error message.
372
550
  """
373
- login_path = _prepend_root_path_if_exists(request=request, path="/login")
551
+ # TODO: this needs some cleanup
552
+ login_path = _prepend_root_path_if_exists(
553
+ request=request, path="/login" if not get_env_disable_basic_auth() else "/logout"
554
+ )
374
555
  url = URL(login_path).include_query_params(error=error)
375
556
  response = RedirectResponse(url=url)
376
557
  response = delete_oauth2_state_cookie(response)
@@ -416,7 +597,7 @@ def _get_create_tokens_endpoint(*, request: Request, origin_url: str, idp_name:
416
597
 
417
598
 
418
599
  def _generate_state_for_oauth2_authorization_code_flow(
419
- *, secret: str, origin_url: str, return_url: Optional[str]
600
+ *, secret: Secret, origin_url: str, return_url: Optional[str]
420
601
  ) -> str:
421
602
  """
422
603
  Generates a JWT whose payload contains both an OAuth2 state (generated using
@@ -432,7 +613,7 @@ def _generate_state_for_oauth2_authorization_code_flow(
432
613
  )
433
614
  if return_url is not None:
434
615
  payload["return_url"] = return_url
435
- jwt_bytes: bytes = jwt.encode(header=header, payload=payload, key=secret)
616
+ jwt_bytes: bytes = jwt.encode(header=header, payload=payload, key=str(secret))
436
617
  return jwt_bytes.decode()
437
618
 
438
619
 
@@ -446,11 +627,11 @@ class _OAuth2StatePayload(TypedDict):
446
627
  return_url: NotRequired[str]
447
628
 
448
629
 
449
- def _parse_state_payload(*, secret: str, state: str) -> _OAuth2StatePayload:
630
+ def _parse_state_payload(*, secret: Secret, state: str) -> _OAuth2StatePayload:
450
631
  """
451
632
  Validates the JWT signature and parses the return URL from the OAuth2 state.
452
633
  """
453
- payload = jwt.decode(s=state, key=secret)
634
+ payload = jwt.decode(s=state, key=str(secret))
454
635
  if _is_oauth2_state_payload(payload):
455
636
  return payload
456
637
  raise ValueError("Invalid OAuth2 state payload.")
phoenix/server/app.py CHANGED
@@ -29,13 +29,14 @@ from fastapi.utils import is_body_allowed_for_status_code
29
29
  from grpc.aio import ServerInterceptor
30
30
  from sqlalchemy import select
31
31
  from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
32
+ from starlette.datastructures import URL, Secret
32
33
  from starlette.datastructures import State as StarletteState
33
34
  from starlette.exceptions import HTTPException
34
35
  from starlette.middleware import Middleware
35
36
  from starlette.middleware.authentication import AuthenticationMiddleware
36
37
  from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
37
38
  from starlette.requests import Request
38
- from starlette.responses import JSONResponse, PlainTextResponse, Response
39
+ from starlette.responses import JSONResponse, PlainTextResponse, RedirectResponse, Response
39
40
  from starlette.staticfiles import StaticFiles
40
41
  from starlette.status import HTTP_401_UNAUTHORIZED
41
42
  from starlette.templating import Jinja2Templates
@@ -56,6 +57,7 @@ from phoenix.config import (
56
57
  get_env_gql_extension_paths,
57
58
  get_env_grpc_interceptor_paths,
58
59
  get_env_host,
60
+ get_env_host_root_path,
59
61
  get_env_port,
60
62
  server_instrumentation_is_enabled,
61
63
  verify_server_environment_variables,
@@ -210,6 +212,8 @@ class AppConfig(NamedTuple):
210
212
  authentication_enabled: bool
211
213
  """ Whether authentication is enabled """
212
214
  oauth2_idps: Sequence[OAuth2Idp]
215
+ basic_auth_disabled: bool = False
216
+ auto_login_idp_name: Optional[str] = None
213
217
 
214
218
 
215
219
  class Static(StaticFiles):
@@ -235,15 +239,29 @@ class Static(StaticFiles):
235
239
  return basename[:-1] if basename.endswith("/") else basename
236
240
 
237
241
  async def get_response(self, path: str, scope: Scope) -> Response:
238
- response = None
242
+ # Redirect to the oauth2 login page if basic auth is disabled and auto_login is enabled
243
+ # TODO: this needs to be refactored to be cleaner
244
+ if (
245
+ path == "login"
246
+ and self._app_config.basic_auth_disabled
247
+ and self._app_config.auto_login_idp_name
248
+ ):
249
+ request = Request(scope)
250
+ url = URL(
251
+ str(
252
+ Path(get_env_host_root_path())
253
+ / f"oauth2/{self._app_config.auto_login_idp_name}/login"
254
+ )
255
+ )
256
+ url = url.include_query_params(**request.query_params)
257
+ return RedirectResponse(url=url)
239
258
  try:
240
259
  response = await super().get_response(path, scope)
241
260
  except HTTPException as e:
242
261
  if e.status_code != 404:
243
262
  raise e
244
- # Fallback to to the index.html
263
+ # Fallback to the index.html
245
264
  request = Request(scope)
246
-
247
265
  response = templates.TemplateResponse(
248
266
  "index.html",
249
267
  context={
@@ -259,6 +277,8 @@ class Static(StaticFiles):
259
277
  "manifest": self._web_manifest,
260
278
  "authentication_enabled": self._app_config.authentication_enabled,
261
279
  "oauth2_idps": self._app_config.oauth2_idps,
280
+ "basic_auth_disabled": self._app_config.basic_auth_disabled,
281
+ "auto_login_idp_name": self._app_config.auto_login_idp_name,
262
282
  },
263
283
  )
264
284
  except Exception as e:
@@ -564,7 +584,7 @@ def create_graphql_router(
564
584
  cache_for_dataloaders: Optional[CacheForDataLoaders] = None,
565
585
  event_queue: CanPutItem[DmlEvent],
566
586
  read_only: bool = False,
567
- secret: Optional[str] = None,
587
+ secret: Optional[Secret] = None,
568
588
  token_store: Optional[TokenStore] = None,
569
589
  email_sender: Optional[EmailSender] = None,
570
590
  ) -> GraphQLRouter[Context, None]:
@@ -581,7 +601,7 @@ def create_graphql_router(
581
601
  corpus (Optional[Model], optional): the corpus for UMAP projection. Defaults to None.
582
602
  cache_for_dataloaders (Optional[CacheForDataLoaders], optional): GraphQL data loaders.
583
603
  read_only (bool, optional): Marks the app as read-only. Defaults to False.
584
- secret (Optional[str], optional): The application secret for auth. Defaults to None.
604
+ secret (Optional[Secret], optional): The application secret for auth. Defaults to None.
585
605
  token_store (Optional[TokenStore], optional): The token store for auth. Defaults to None.
586
606
  email_sender (Optional[EmailSender], optional): The email sender. Defaults to None.
587
607
 
@@ -762,13 +782,14 @@ def create_app(
762
782
  serve_ui: bool = True,
763
783
  startup_callbacks: Iterable[_Callback] = (),
764
784
  shutdown_callbacks: Iterable[_Callback] = (),
765
- secret: Optional[str] = None,
785
+ secret: Optional[Secret] = None,
766
786
  password_reset_token_expiry: Optional[timedelta] = None,
767
787
  access_token_expiry: Optional[timedelta] = None,
768
788
  refresh_token_expiry: Optional[timedelta] = None,
769
789
  scaffolder_config: Optional[ScaffolderConfig] = None,
770
790
  email_sender: Optional[EmailSender] = None,
771
791
  oauth2_client_configs: Optional[list[OAuth2ClientConfig]] = None,
792
+ basic_auth_disabled: bool = False,
772
793
  bulk_inserter_factory: Optional[Callable[..., BulkInserter]] = None,
773
794
  allowed_origins: Optional[list[str]] = None,
774
795
  ) -> FastAPI:
@@ -924,6 +945,9 @@ def create_app(
924
945
  OAuth2Idp(name=config.idp_name, displayName=config.idp_display_name)
925
946
  for config in oauth2_client_configs or []
926
947
  ]
948
+ auto_login_idp_name = next(
949
+ (config.idp_name for config in (oauth2_client_configs or []) if config.auto_login), None
950
+ )
927
951
  app.mount(
928
952
  "/",
929
953
  app=Static(
@@ -938,6 +962,8 @@ def create_app(
938
962
  authentication_enabled=authentication_enabled,
939
963
  web_manifest_path=web_manifest_path,
940
964
  oauth2_idps=oauth2_idps,
965
+ basic_auth_disabled=basic_auth_disabled,
966
+ auto_login_idp_name=auto_login_idp_name,
941
967
  ),
942
968
  ),
943
969
  name="static",
@@ -970,16 +996,16 @@ def create_app(
970
996
  return app
971
997
 
972
998
 
973
- def _add_get_secret_method(*, app: FastAPI, secret: Optional[str]) -> FastAPI:
999
+ def _add_get_secret_method(*, app: FastAPI, secret: Optional[Secret]) -> FastAPI:
974
1000
  """
975
1001
  Dynamically adds a `get_secret` method to the app's `state`.
976
1002
  """
977
1003
  app.state._secret = secret
978
1004
 
979
- def get_secret(self: StarletteState) -> str:
1005
+ def get_secret(self: StarletteState) -> Secret:
980
1006
  if (secret := self._secret) is None:
981
1007
  raise ValueError("app secret is not set")
982
- assert isinstance(secret, str)
1008
+ assert isinstance(secret, Secret)
983
1009
  return secret
984
1010
 
985
1011
  app.state.get_secret = MethodType(get_secret, app.state)
phoenix/server/bearer_auth.py CHANGED
@@ -20,9 +20,7 @@ from phoenix.auth import (
20
20
  Token,
21
21
  )
22
22
  from phoenix.config import get_env_phoenix_admin_secret
23
- from phoenix.db import enums
24
- from phoenix.db.enums import UserRole
25
- from phoenix.db.models import User as OrmUser
23
+ from phoenix.db import enums, models
26
24
  from phoenix.server.types import (
27
25
  AccessToken,
28
26
  AccessTokenAttributes,
@@ -54,7 +52,7 @@ class BearerTokenAuthBackend(HasTokenStore, AuthenticationBackend):
54
52
  return None
55
53
  if (
56
54
  (admin_secret := get_env_phoenix_admin_secret())
57
- and token == admin_secret
55
+ and token == str(admin_secret)
58
56
  and config.SYSTEM_USER_ID is not None
59
57
  ):
60
58
  return AuthCredentials(), PhoenixSystemUser(UserId(config.SYSTEM_USER_ID))
@@ -117,7 +115,7 @@ class ApiKeyInterceptor(HasTokenStore, AsyncServerInterceptor):
117
115
  break
118
116
  if (
119
117
  (admin_secret := get_env_phoenix_admin_secret())
120
- and token == admin_secret
118
+ and token == str(admin_secret)
121
119
  and config.SYSTEM_USER_ID is not None
122
120
  ):
123
121
  return await method(request_or_iterator, context)
@@ -159,13 +157,13 @@ async def is_authenticated(
159
157
  async def create_access_and_refresh_tokens(
160
158
  *,
161
159
  token_store: TokenStore,
162
- user: OrmUser,
160
+ user: models.User,
163
161
  access_token_expiry: timedelta,
164
162
  refresh_token_expiry: timedelta,
165
163
  ) -> tuple[AccessToken, RefreshToken]:
166
164
  issued_at = datetime.now(timezone.utc)
167
165
  user_id = UserId(user.id)
168
- user_role = UserRole(user.role.name)
166
+ user_role = enums.UserRole(user.role.name)
169
167
  refresh_token_claims = RefreshTokenClaims(
170
168
  subject=user_id,
171
169
  issued_at=issued_at,
phoenix/server/jwt_store.py CHANGED
@@ -11,6 +11,7 @@ from typing import Any, Generic, Optional, TypeVar
11
11
  from authlib.jose import jwt
12
12
  from authlib.jose.errors import JoseError
13
13
  from sqlalchemy import Select, delete, select
14
+ from starlette.datastructures import Secret
14
15
 
15
16
  from phoenix.auth import (
16
17
  JWT_ALGORITHM,
@@ -50,7 +51,7 @@ class JwtStore:
50
51
  def __init__(
51
52
  self,
52
53
  db: DbSessionFactory,
53
- secret: str,
54
+ secret: Secret,
54
55
  algorithm: str = JWT_ALGORITHM,
55
56
  sleep_seconds: int = 10,
56
57
  **kwargs: Any,
@@ -79,7 +80,7 @@ class JwtStore:
79
80
  try:
80
81
  payload = jwt.decode(
81
82
  s=token,
82
- key=self._secret,
83
+ key=str(self._secret),
83
84
  )
84
85
  except JoseError:
85
86
  return None
@@ -231,7 +232,7 @@ class _Store(DaemonTask, Generic[_ClaimSetT, _TokenT, _TokenIdT, _RecordT], ABC)
231
232
  def __init__(
232
233
  self,
233
234
  db: DbSessionFactory,
234
- secret: str,
235
+ secret: Secret,
235
236
  algorithm: str = JWT_ALGORITHM,
236
237
  sleep_seconds: int = 10,
237
238
  **kwargs: Any,
@@ -247,7 +248,7 @@ class _Store(DaemonTask, Generic[_ClaimSetT, _TokenT, _TokenIdT, _RecordT], ABC)
247
248
  def _encode(self, claim: ClaimSet) -> str:
248
249
  payload: dict[str, Any] = dict(jti=claim.token_id)
249
250
  header = {"alg": self._algorithm}
250
- jwt_bytes: bytes = jwt.encode(header=header, payload=payload, key=self._secret)
251
+ jwt_bytes: bytes = jwt.encode(header=header, payload=payload, key=str(self._secret))
251
252
  return jwt_bytes.decode()
252
253
 
253
254
  async def get(self, token_id: _TokenIdT) -> Optional[_ClaimSetT]:
phoenix/server/main.py CHANGED
@@ -102,7 +102,11 @@ _WELCOME_MESSAGE = Environment(loader=BaseLoader()).from_string("""
102
102
  |
103
103
  | 🚀 Phoenix Server 🚀
104
104
  | Phoenix UI: {{ ui_path }}
105
+ |
105
106
  | Authentication: {{ auth_enabled }}
107
+ {%- if basic_auth_disabled %}
108
+ | Basic Auth: Disabled
109
+ {%- endif %}
106
110
  {%- if auth_enabled_for_http or auth_enabled_for_grpc %}
107
111
  {%- if tls_enabled_for_http %}
108
112
  | TLS: Enabled for HTTP
@@ -332,7 +336,7 @@ def main() -> None:
332
336
  reference_inferences,
333
337
  )
334
338
 
335
- authentication_enabled, secret = get_env_auth_settings()
339
+ auth_settings = get_env_auth_settings()
336
340
 
337
341
  fixture_spans: list[Span] = []
338
342
  fixture_evals: list[pb.Evaluation] = []
@@ -390,12 +394,14 @@ def main() -> None:
390
394
  http_path=urljoin(root_path, "v1/traces"),
391
395
  storage=get_printable_db_url(db_connection_str),
392
396
  schema=get_env_database_schema(),
393
- auth_enabled=authentication_enabled,
397
+ auth_enabled=auth_settings.enable_auth,
398
+ disable_basic_auth=auth_settings.disable_basic_auth,
394
399
  tls_enabled_for_http=tls_enabled_for_http,
395
400
  tls_enabled_for_grpc=tls_enabled_for_grpc,
396
401
  tls_verify_client=tls_verify_client,
397
402
  allowed_origins=allowed_origins,
398
403
  )
404
+
399
405
  if sys.platform.startswith("win"):
400
406
  msg = codecs.encode(msg, "ascii", errors="ignore").decode("ascii").strip()
401
407
  scaffolder_config = ScaffolderConfig(
@@ -424,7 +430,8 @@ def main() -> None:
424
430
  db=factory,
425
431
  export_path=export_path,
426
432
  model=model,
427
- authentication_enabled=authentication_enabled,
433
+ authentication_enabled=auth_settings.enable_auth,
434
+ basic_auth_disabled=auth_settings.disable_basic_auth,
428
435
  umap_params=umap_params,
429
436
  corpus=corpus_model,
430
437
  debug=args.debug,
@@ -436,7 +443,7 @@ def main() -> None:
436
443
  initial_evaluations=fixture_evals,
437
444
  startup_callbacks=[lambda: print(msg)],
438
445
  shutdown_callbacks=instrumentation_cleanups,
439
- secret=secret,
446
+ secret=auth_settings.phoenix_secret,
440
447
  password_reset_token_expiry=get_env_password_reset_token_expiry(),
441
448
  access_token_expiry=get_env_access_token_expiry(),
442
449
  refresh_token_expiry=get_env_refresh_token_expiry(),
phoenix/server/oauth2.py CHANGED
@@ -1,5 +1,5 @@
1
1
  from collections.abc import Iterable
2
- from typing import Any
2
+ from typing import Any, Iterator, Optional
3
3
 
4
4
  from authlib.integrations.base_client import BaseApp
5
5
  from authlib.integrations.base_client.async_app import AsyncOAuth2Mixin
@@ -19,24 +19,68 @@ class OAuth2Client(AsyncOAuth2Mixin, AsyncOpenIDMixin, BaseApp): # type:ignore[
19
19
 
20
20
  client_cls = AsyncHttpxOAuth2Client
21
21
 
22
- def __init__(self, *args: Any, **kwargs: Any) -> None:
22
+ def __init__(
23
+ self,
24
+ *args: Any,
25
+ display_name: str,
26
+ allow_sign_up: bool,
27
+ auto_login: bool,
28
+ **kwargs: Any,
29
+ ) -> None:
30
+ self._display_name = display_name
31
+ self._allow_sign_up = allow_sign_up
32
+ self._auto_login = auto_login
23
33
  super().__init__(framework=None, *args, **kwargs)
34
+ self._allow_sign_up = allow_sign_up
35
+
36
+ @property
37
+ def allow_sign_up(self) -> bool:
38
+ return self._allow_sign_up
39
+
40
+ @property
41
+ def auto_login(self) -> bool:
42
+ return self._auto_login
43
+
44
+ @property
45
+ def display_name(self) -> str:
46
+ return self._display_name
24
47
 
25
48
 
26
49
  class OAuth2Clients:
27
50
  def __init__(self) -> None:
28
51
  self._clients: dict[str, OAuth2Client] = {}
52
+ self._auto_login_client: Optional[OAuth2Client] = None
53
+
54
+ def __bool__(self) -> bool:
55
+ return bool(self._clients)
56
+
57
+ def __len__(self) -> int:
58
+ return len(self._clients)
59
+
60
+ def __iter__(self) -> Iterator[OAuth2Client]:
61
+ return iter(self._clients.values())
62
+
63
+ @property
64
+ def auto_login_client(self) -> Optional[OAuth2Client]:
65
+ return self._auto_login_client
29
66
 
30
67
  def add_client(self, config: OAuth2ClientConfig) -> None:
31
68
  if (idp_name := config.idp_name) in self._clients:
32
69
  raise ValueError(f"oauth client already registered: {idp_name}")
33
70
  client = OAuth2Client(
71
+ name=config.idp_name,
34
72
  client_id=config.client_id,
35
73
  client_secret=config.client_secret,
36
74
  server_metadata_url=config.oidc_config_url,
37
75
  client_kwargs={"scope": "openid email profile"},
76
+ display_name=config.idp_display_name,
77
+ allow_sign_up=config.allow_sign_up,
78
+ auto_login=config.auto_login,
38
79
  )
39
- assert isinstance(client, OAuth2Client)
80
+ if config.auto_login:
81
+ if self._auto_login_client:
82
+ raise ValueError("only one auto-login client is allowed")
83
+ self._auto_login_client = client
40
84
  self._clients[config.idp_name] = client
41
85
 
42
86
  def get_client(self, idp_name: str) -> OAuth2Client: