arize-phoenix 4.36.0__py3-none-any.whl → 5.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (80) hide show
  1. {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.0.0.dist-info}/METADATA +10 -12
  2. {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.0.0.dist-info}/RECORD +68 -59
  3. phoenix/__init__.py +86 -0
  4. phoenix/auth.py +275 -14
  5. phoenix/config.py +277 -25
  6. phoenix/db/enums.py +20 -0
  7. phoenix/db/facilitator.py +112 -0
  8. phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
  9. phoenix/db/models.py +145 -60
  10. phoenix/experiments/evaluators/code_evaluators.py +9 -3
  11. phoenix/experiments/functions.py +1 -4
  12. phoenix/server/api/README.md +28 -0
  13. phoenix/server/api/auth.py +32 -0
  14. phoenix/server/api/context.py +50 -2
  15. phoenix/server/api/dataloaders/__init__.py +4 -0
  16. phoenix/server/api/dataloaders/user_roles.py +30 -0
  17. phoenix/server/api/dataloaders/users.py +33 -0
  18. phoenix/server/api/exceptions.py +7 -0
  19. phoenix/server/api/mutations/__init__.py +0 -2
  20. phoenix/server/api/mutations/api_key_mutations.py +104 -86
  21. phoenix/server/api/mutations/dataset_mutations.py +8 -8
  22. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  23. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  24. phoenix/server/api/mutations/project_mutations.py +3 -3
  25. phoenix/server/api/mutations/span_annotations_mutations.py +4 -4
  26. phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
  27. phoenix/server/api/mutations/user_mutations.py +282 -42
  28. phoenix/server/api/openapi/schema.py +2 -2
  29. phoenix/server/api/queries.py +48 -39
  30. phoenix/server/api/routers/__init__.py +11 -0
  31. phoenix/server/api/routers/auth.py +284 -0
  32. phoenix/server/api/routers/embeddings.py +26 -0
  33. phoenix/server/api/routers/oauth2.py +456 -0
  34. phoenix/server/api/routers/v1/__init__.py +38 -16
  35. phoenix/server/api/types/ApiKey.py +11 -0
  36. phoenix/server/api/types/AuthMethod.py +9 -0
  37. phoenix/server/api/types/User.py +48 -4
  38. phoenix/server/api/types/UserApiKey.py +35 -1
  39. phoenix/server/api/types/UserRole.py +7 -0
  40. phoenix/server/app.py +103 -31
  41. phoenix/server/bearer_auth.py +161 -0
  42. phoenix/server/email/__init__.py +0 -0
  43. phoenix/server/email/sender.py +26 -0
  44. phoenix/server/email/templates/__init__.py +0 -0
  45. phoenix/server/email/templates/password_reset.html +19 -0
  46. phoenix/server/email/types.py +11 -0
  47. phoenix/server/grpc_server.py +6 -0
  48. phoenix/server/jwt_store.py +504 -0
  49. phoenix/server/main.py +40 -9
  50. phoenix/server/oauth2.py +51 -0
  51. phoenix/server/prometheus.py +20 -0
  52. phoenix/server/rate_limiters.py +191 -0
  53. phoenix/server/static/.vite/manifest.json +31 -31
  54. phoenix/server/static/assets/{components-Dte7_KRd.js → components-REunxTt6.js} +348 -286
  55. phoenix/server/static/assets/index-DAPJxlCw.js +101 -0
  56. phoenix/server/static/assets/{pages-CnTvEGEN.js → pages-1VrMk2pW.js} +559 -291
  57. phoenix/server/static/assets/{vendor-BC3OPQuM.js → vendor-B5IC0ivG.js} +5 -5
  58. phoenix/server/static/assets/{vendor-arizeai-NjB3cZzD.js → vendor-arizeai-aFbT4kl1.js} +2 -2
  59. phoenix/server/static/assets/{vendor-codemirror-gE_JCOgX.js → vendor-codemirror-BEGorXSV.js} +1 -1
  60. phoenix/server/static/assets/{vendor-recharts-BXLYwcXF.js → vendor-recharts-6nUU7gU_.js} +1 -1
  61. phoenix/server/templates/index.html +1 -0
  62. phoenix/server/types.py +157 -1
  63. phoenix/session/client.py +7 -2
  64. phoenix/utilities/client.py +16 -0
  65. phoenix/version.py +1 -1
  66. phoenix/db/migrations/future_versions/README.md +0 -4
  67. phoenix/db/migrations/future_versions/cd164e83824f_users_and_tokens.py +0 -293
  68. phoenix/db/migrations/versions/.gitignore +0 -1
  69. phoenix/server/api/mutations/auth.py +0 -18
  70. phoenix/server/api/mutations/auth_mutations.py +0 -65
  71. phoenix/server/static/assets/index-fq1-hCK4.js +0 -100
  72. phoenix/trace/langchain/__init__.py +0 -3
  73. phoenix/trace/langchain/instrumentor.py +0 -34
  74. phoenix/trace/llama_index/__init__.py +0 -3
  75. phoenix/trace/llama_index/callback.py +0 -102
  76. phoenix/trace/openai/__init__.py +0 -3
  77. phoenix/trace/openai/instrumentor.py +0 -30
  78. {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.0.0.dist-info}/WHEEL +0 -0
  79. {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.0.0.dist-info}/licenses/IP_NOTICE +0 -0
  80. {arize_phoenix-4.36.0.dist-info → arize_phoenix-5.0.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/app.py CHANGED
@@ -20,21 +20,23 @@ from typing import (
20
20
  List,
21
21
  NamedTuple,
22
22
  Optional,
23
+ Sequence,
23
24
  Tuple,
25
+ TypedDict,
24
26
  Union,
25
27
  cast,
26
28
  )
27
29
 
28
30
  import strawberry
29
- from fastapi import APIRouter, FastAPI
31
+ from fastapi import APIRouter, Depends, FastAPI
30
32
  from fastapi.middleware.gzip import GZipMiddleware
31
- from fastapi.responses import FileResponse
32
33
  from fastapi.utils import is_body_allowed_for_status_code
33
34
  from sqlalchemy import select
34
35
  from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
35
36
  from starlette.datastructures import State as StarletteState
36
37
  from starlette.exceptions import HTTPException
37
38
  from starlette.middleware import Middleware
39
+ from starlette.middleware.authentication import AuthenticationMiddleware
38
40
  from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
39
41
  from starlette.requests import Request
40
42
  from starlette.responses import PlainTextResponse, Response
@@ -50,6 +52,7 @@ import phoenix.trace.v1 as pb
50
52
  from phoenix.config import (
51
53
  DEFAULT_PROJECT_NAME,
52
54
  SERVER_DIR,
55
+ OAuth2ClientConfig,
53
56
  get_env_host,
54
57
  get_env_port,
55
58
  server_instrumentation_is_enabled,
@@ -58,6 +61,7 @@ from phoenix.core.model_schema import Model
58
61
  from phoenix.db import models
59
62
  from phoenix.db.bulk_inserter import BulkInserter
60
63
  from phoenix.db.engines import create_engine
64
+ from phoenix.db.facilitator import Facilitator
61
65
  from phoenix.db.helpers import SupportedSQLDialect
62
66
  from phoenix.exceptions import PhoenixMigrationError
63
67
  from phoenix.pointcloud.umap_parameters import UMAPParameters
@@ -86,13 +90,24 @@ from phoenix.server.api.dataloaders import (
86
90
  SpanProjectsDataLoader,
87
91
  TokenCountDataLoader,
88
92
  TraceRowIdsDataLoader,
93
+ UserRolesDataLoader,
94
+ UsersDataLoader,
95
+ )
96
+ from phoenix.server.api.routers import (
97
+ auth_router,
98
+ create_embeddings_router,
99
+ create_v1_router,
100
+ oauth2_router,
89
101
  )
90
102
  from phoenix.server.api.routers.v1 import REST_API_VERSION
91
- from phoenix.server.api.routers.v1 import router as v1_router
92
103
  from phoenix.server.api.schema import schema
104
+ from phoenix.server.bearer_auth import BearerTokenAuthBackend, is_authenticated
93
105
  from phoenix.server.dml_event import DmlEvent
94
106
  from phoenix.server.dml_event_handler import DmlEventHandler
107
+ from phoenix.server.email.types import EmailSender
95
108
  from phoenix.server.grpc_server import GrpcServer
109
+ from phoenix.server.jwt_store import JwtStore
110
+ from phoenix.server.oauth2 import OAuth2Clients
96
111
  from phoenix.server.telemetry import initialize_opentelemetry_tracer_provider
97
112
  from phoenix.server.types import (
98
113
  CanGetLastUpdatedAt,
@@ -100,6 +115,7 @@ from phoenix.server.types import (
100
115
  DaemonTask,
101
116
  DbSessionFactory,
102
117
  LastUpdatedAt,
118
+ TokenStore,
103
119
  )
104
120
  from phoenix.trace.fixtures import (
105
121
  TracesFixture,
@@ -135,6 +151,11 @@ ProjectName: TypeAlias = str
135
151
  _Callback: TypeAlias = Callable[[], Union[None, Awaitable[None]]]
136
152
 
137
153
 
154
+ class OAuth2Idp(TypedDict):
155
+ name: str
156
+ displayName: str
157
+
158
+
138
159
  class AppConfig(NamedTuple):
139
160
  has_inferences: bool
140
161
  """ Whether the model has inferences (e.g. a primary dataset) """
@@ -146,6 +167,7 @@ class AppConfig(NamedTuple):
146
167
  web_manifest_path: Path
147
168
  authentication_enabled: bool
148
169
  """ Whether authentication is enabled """
170
+ oauth2_idps: Sequence[OAuth2Idp]
149
171
 
150
172
 
151
173
  class Static(StaticFiles):
@@ -194,6 +216,7 @@ class Static(StaticFiles):
194
216
  "is_development": self._app_config.is_development,
195
217
  "manifest": self._web_manifest,
196
218
  "authentication_enabled": self._app_config.authentication_enabled,
219
+ "oauth2_idps": self._app_config.oauth2_idps,
197
220
  },
198
221
  )
199
222
  except Exception as e:
@@ -218,18 +241,6 @@ class HeadersMiddleware(BaseHTTPMiddleware):
218
241
  ProjectRowId: TypeAlias = int
219
242
 
220
243
 
221
- @router.get("/exports")
222
- async def download_exported_file(request: Request, filename: str) -> FileResponse:
223
- file = request.app.state.export_path / (filename + ".parquet")
224
- if not file.is_file():
225
- raise HTTPException(status_code=404)
226
- return FileResponse(
227
- path=file,
228
- filename=file.name,
229
- media_type="application/x-octet-stream",
230
- )
231
-
232
-
233
244
  @router.get("/arize_phoenix_version")
234
245
  async def version() -> PlainTextResponse:
235
246
  return PlainTextResponse(f"{phoenix.__version__}")
@@ -238,13 +249,15 @@ async def version() -> PlainTextResponse:
238
249
  DB_MUTEX: Optional[asyncio.Lock] = None
239
250
 
240
251
 
241
- def _db(engine: AsyncEngine) -> Callable[[], AsyncContextManager[AsyncSession]]:
252
+ def _db(
253
+ engine: AsyncEngine, bypass_lock: bool = False
254
+ ) -> Callable[[], AsyncContextManager[AsyncSession]]:
242
255
  Session = async_sessionmaker(engine, expire_on_commit=False)
243
256
 
244
257
  @contextlib.asynccontextmanager
245
258
  async def factory() -> AsyncIterator[AsyncSession]:
246
259
  async with contextlib.AsyncExitStack() as stack:
247
- if DB_MUTEX:
260
+ if not bypass_lock and DB_MUTEX:
248
261
  await stack.enter_async_context(DB_MUTEX)
249
262
  yield await stack.enter_async_context(Session.begin())
250
263
 
@@ -283,9 +296,6 @@ class Scaffolder(DaemonTask):
283
296
  return
284
297
  await self.start()
285
298
 
286
- async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
287
- await self.stop()
288
-
289
299
  async def _run(self) -> None:
290
300
  """
291
301
  Main entry point for Scaffolder.
@@ -374,6 +384,7 @@ def _lifespan(
374
384
  db: DbSessionFactory,
375
385
  bulk_inserter: BulkInserter,
376
386
  dml_event_handler: DmlEventHandler,
387
+ token_store: Optional[TokenStore] = None,
377
388
  tracer_provider: Optional["TracerProvider"] = None,
378
389
  enable_prometheus: bool = False,
379
390
  startup_callbacks: Iterable[_Callback] = (),
@@ -400,6 +411,7 @@ def _lifespan(
400
411
  disabled=read_only,
401
412
  tracer_provider=tracer_provider,
402
413
  enable_prometheus=enable_prometheus,
414
+ token_store=token_store,
403
415
  )
404
416
  await stack.enter_async_context(grpc_server)
405
417
  await stack.enter_async_context(dml_event_handler)
@@ -410,6 +422,8 @@ def _lifespan(
410
422
  queue_evaluation=queue_evaluation,
411
423
  )
412
424
  await stack.enter_async_context(scaffolder)
425
+ if isinstance(token_store, AsyncContextManager):
426
+ await stack.enter_async_context(token_store)
413
427
  yield {
414
428
  "event_queue": dml_event_handler,
415
429
  "enqueue": enqueue,
@@ -436,11 +450,13 @@ def create_graphql_router(
436
450
  model: Model,
437
451
  export_path: Path,
438
452
  last_updated_at: CanGetLastUpdatedAt,
453
+ authentication_enabled: bool,
439
454
  corpus: Optional[Model] = None,
440
455
  cache_for_dataloaders: Optional[CacheForDataLoaders] = None,
441
456
  event_queue: CanPutItem[DmlEvent],
442
457
  read_only: bool = False,
443
458
  secret: Optional[str] = None,
459
+ token_store: Optional[TokenStore] = None,
444
460
  ) -> GraphQLRouter: # type: ignore[type-arg]
445
461
  """Creates the GraphQL router.
446
462
 
@@ -450,6 +466,7 @@ def create_graphql_router(
450
466
  model (Model): The Model representing inferences (legacy)
451
467
  export_path (Path): the file path to export data to for download (legacy)
452
468
  last_updated_at (CanGetLastUpdatedAt): How to get the last updated timestamp for updates.
469
+ authentication_enabled (bool): Whether authentication is enabled.
453
470
  event_queue (CanPutItem[DmlEvent]): The event queue for DML events.
454
471
  corpus (Optional[Model], optional): the corpus for UMAP projection. Defaults to None.
455
472
  cache_for_dataloaders (Optional[CacheForDataLoaders], optional): GraphQL data loaders.
@@ -521,18 +538,23 @@ def create_graphql_router(
521
538
  ),
522
539
  trace_row_ids=TraceRowIdsDataLoader(db),
523
540
  project_by_name=ProjectByNameDataLoader(db),
541
+ users=UsersDataLoader(db),
542
+ user_roles=UserRolesDataLoader(db),
524
543
  ),
525
544
  cache_for_dataloaders=cache_for_dataloaders,
526
545
  read_only=read_only,
546
+ auth_enabled=authentication_enabled,
527
547
  secret=secret,
548
+ token_store=token_store,
528
549
  )
529
550
 
530
551
  return GraphQLRouter(
531
552
  schema,
532
- graphiql=True,
553
+ graphql_ide="graphiql",
533
554
  context_getter=get_context,
534
555
  include_in_schema=False,
535
556
  prefix="/graphql",
557
+ dependencies=(Depends(is_authenticated),) if authentication_enabled else (),
536
558
  )
537
559
 
538
560
 
@@ -600,11 +622,19 @@ def create_app(
600
622
  startup_callbacks: Iterable[_Callback] = (),
601
623
  shutdown_callbacks: Iterable[_Callback] = (),
602
624
  secret: Optional[str] = None,
625
+ password_reset_token_expiry: Optional[timedelta] = None,
626
+ access_token_expiry: Optional[timedelta] = None,
627
+ refresh_token_expiry: Optional[timedelta] = None,
603
628
  scaffolder_config: Optional[ScaffolderConfig] = None,
629
+ email_sender: Optional[EmailSender] = None,
630
+ oauth2_client_configs: Optional[List[OAuth2ClientConfig]] = None,
631
+ bulk_inserter_factory: Optional[Callable[..., BulkInserter]] = None,
604
632
  ) -> FastAPI:
605
633
  logger.info(f"Server umap params: {umap_params}")
634
+ bulk_inserter_factory = bulk_inserter_factory or BulkInserter
606
635
  startup_callbacks_list: List[_Callback] = list(startup_callbacks)
607
636
  shutdown_callbacks_list: List[_Callback] = list(shutdown_callbacks)
637
+ startup_callbacks_list.append(Facilitator(db=db))
608
638
  initial_batch_of_spans: Iterable[Tuple[Span, str]] = (
609
639
  ()
610
640
  if initial_spans is None
@@ -619,12 +649,22 @@ def create_app(
619
649
  )
620
650
  last_updated_at = LastUpdatedAt()
621
651
  middlewares: List[Middleware] = [Middleware(HeadersMiddleware)]
652
+ if authentication_enabled and secret:
653
+ token_store = JwtStore(db, secret)
654
+ middlewares.append(
655
+ Middleware(
656
+ AuthenticationMiddleware,
657
+ backend=BearerTokenAuthBackend(token_store),
658
+ )
659
+ )
660
+ else:
661
+ token_store = None
622
662
  dml_event_handler = DmlEventHandler(
623
663
  db=db,
624
664
  cache_for_dataloaders=cache_for_dataloaders,
625
665
  last_updated_at=last_updated_at,
626
666
  )
627
- bulk_inserter = BulkInserter(
667
+ bulk_inserter = bulk_inserter_factory(
628
668
  db,
629
669
  enable_prometheus=enable_prometheus,
630
670
  event_queue=dml_event_handler,
@@ -662,12 +702,14 @@ def create_app(
662
702
  ),
663
703
  model=model,
664
704
  corpus=corpus,
705
+ authentication_enabled=authentication_enabled,
665
706
  export_path=export_path,
666
707
  last_updated_at=last_updated_at,
667
708
  event_queue=dml_event_handler,
668
709
  cache_for_dataloaders=cache_for_dataloaders,
669
710
  read_only=read_only,
670
711
  secret=secret,
712
+ token_store=token_store,
671
713
  )
672
714
  if enable_prometheus:
673
715
  from phoenix.server.prometheus import PrometheusMiddleware
@@ -681,6 +723,7 @@ def create_app(
681
723
  read_only=read_only,
682
724
  bulk_inserter=bulk_inserter,
683
725
  dml_event_handler=dml_event_handler,
726
+ token_store=token_store,
684
727
  tracer_provider=tracer_provider,
685
728
  enable_prometheus=enable_prometheus,
686
729
  shutdown_callbacks=shutdown_callbacks_list,
@@ -694,14 +737,20 @@ def create_app(
694
737
  "defaultModelsExpandDepth": -1, # hides the schema section in the Swagger UI
695
738
  },
696
739
  )
697
- app.state.read_only = read_only
698
- app.state.export_path = export_path
699
- app.include_router(v1_router)
740
+ app.include_router(create_v1_router(authentication_enabled))
741
+ app.include_router(create_embeddings_router(authentication_enabled))
700
742
  app.include_router(router)
701
743
  app.include_router(graphql_router)
744
+ if authentication_enabled:
745
+ app.include_router(auth_router)
746
+ app.include_router(oauth2_router)
702
747
  app.add_middleware(GZipMiddleware)
703
748
  web_manifest_path = SERVER_DIR / "static" / ".vite" / "manifest.json"
704
749
  if serve_ui and web_manifest_path.is_file():
750
+ oauth2_idps = [
751
+ OAuth2Idp(name=config.idp_name, displayName=config.idp_display_name)
752
+ for config in oauth2_client_configs or []
753
+ ]
705
754
  app.mount(
706
755
  "/",
707
756
  app=Static(
@@ -715,11 +764,21 @@ def create_app(
715
764
  is_development=dev,
716
765
  authentication_enabled=authentication_enabled,
717
766
  web_manifest_path=web_manifest_path,
767
+ oauth2_idps=oauth2_idps,
718
768
  ),
719
769
  ),
720
770
  name="static",
721
771
  )
722
- app = _update_app_state(app, db=db, secret=secret)
772
+ app.state.read_only = read_only
773
+ app.state.export_path = export_path
774
+ app.state.password_reset_token_expiry = password_reset_token_expiry
775
+ app.state.access_token_expiry = access_token_expiry
776
+ app.state.refresh_token_expiry = refresh_token_expiry
777
+ app.state.oauth2_clients = OAuth2Clients.from_configs(oauth2_client_configs or [])
778
+ app.state.db = db
779
+ app.state.email_sender = email_sender
780
+ app = _add_get_secret_method(app=app, secret=secret)
781
+ app = _add_get_token_store_method(app=app, token_store=token_store)
723
782
  if tracer_provider:
724
783
  from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
725
784
 
@@ -729,13 +788,10 @@ def create_app(
729
788
  return app
730
789
 
731
790
 
732
- def _update_app_state(app: FastAPI, /, *, db: DbSessionFactory, secret: Optional[str]) -> FastAPI:
791
+ def _add_get_secret_method(*, app: FastAPI, secret: Optional[str]) -> FastAPI:
733
792
  """
734
- Dynamically updates the app's `state` to include useful fields and methods
735
- (at the time of this writing, FastAPI does not support setting this state
736
- during the creation of the app).
793
+ Dynamically adds a `get_secret` method to the app's `state`.
737
794
  """
738
- app.state.db = db
739
795
  app.state._secret = secret
740
796
 
741
797
  def get_secret(self: StarletteState) -> str:
@@ -746,3 +802,19 @@ def _update_app_state(app: FastAPI, /, *, db: DbSessionFactory, secret: Optional
746
802
 
747
803
  app.state.get_secret = MethodType(get_secret, app.state)
748
804
  return app
805
+
806
+
807
+ def _add_get_token_store_method(*, app: FastAPI, token_store: Optional[JwtStore]) -> FastAPI:
808
+ """
809
+ Dynamically adds a `get_token_store` method to the app's `state`.
810
+ """
811
+ app.state._token_store = token_store
812
+
813
+ def get_token_store(self: StarletteState) -> JwtStore:
814
+ if (token_store := self._token_store) is None:
815
+ raise ValueError("token store is not set on the app")
816
+ assert isinstance(token_store, JwtStore)
817
+ return token_store
818
+
819
+ app.state.get_token_store = MethodType(get_token_store, app.state)
820
+ return app
@@ -0,0 +1,161 @@
1
+ from abc import ABC
2
+ from datetime import datetime, timedelta, timezone
3
+ from functools import cached_property
4
+ from typing import (
5
+ Any,
6
+ Awaitable,
7
+ Callable,
8
+ Optional,
9
+ Tuple,
10
+ )
11
+
12
+ import grpc
13
+ from fastapi import HTTPException, Request
14
+ from grpc_interceptor import AsyncServerInterceptor
15
+ from grpc_interceptor.exceptions import Unauthenticated
16
+ from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser
17
+ from starlette.requests import HTTPConnection
18
+ from starlette.status import HTTP_401_UNAUTHORIZED
19
+
20
+ from phoenix.auth import (
21
+ PHOENIX_ACCESS_TOKEN_COOKIE_NAME,
22
+ CanReadToken,
23
+ ClaimSetStatus,
24
+ Token,
25
+ )
26
+ from phoenix.db import enums
27
+ from phoenix.db.enums import UserRole
28
+ from phoenix.db.models import User as OrmUser
29
+ from phoenix.server.types import (
30
+ AccessToken,
31
+ AccessTokenAttributes,
32
+ AccessTokenClaims,
33
+ ApiKeyClaims,
34
+ RefreshToken,
35
+ RefreshTokenAttributes,
36
+ RefreshTokenClaims,
37
+ TokenStore,
38
+ UserClaimSet,
39
+ UserId,
40
+ )
41
+
42
+
43
+ class HasTokenStore(ABC):
44
+ def __init__(self, token_store: CanReadToken) -> None:
45
+ super().__init__()
46
+ self._token_store = token_store
47
+
48
+
49
+ class BearerTokenAuthBackend(HasTokenStore, AuthenticationBackend):
50
+ async def authenticate(
51
+ self,
52
+ conn: HTTPConnection,
53
+ ) -> Optional[Tuple[AuthCredentials, BaseUser]]:
54
+ if header := conn.headers.get("Authorization"):
55
+ scheme, _, token = header.partition(" ")
56
+ if scheme.lower() != "bearer" or not token:
57
+ return None
58
+ elif access_token := conn.cookies.get(PHOENIX_ACCESS_TOKEN_COOKIE_NAME):
59
+ token = access_token
60
+ else:
61
+ return None
62
+ claims = await self._token_store.read(Token(token))
63
+ if not (isinstance(claims, UserClaimSet) and isinstance(claims.subject, UserId)):
64
+ return None
65
+ if not isinstance(claims, (ApiKeyClaims, AccessTokenClaims)):
66
+ return None
67
+ return AuthCredentials(), PhoenixUser(claims.subject, claims)
68
+
69
+
70
+ class PhoenixUser(BaseUser):
71
+ def __init__(self, user_id: UserId, claims: UserClaimSet) -> None:
72
+ self._user_id = user_id
73
+ self.claims = claims
74
+ assert claims.attributes
75
+ self._is_admin = (
76
+ claims.status is ClaimSetStatus.VALID
77
+ and claims.attributes.user_role == enums.UserRole.ADMIN
78
+ )
79
+
80
+ @cached_property
81
+ def is_admin(self) -> bool:
82
+ return self._is_admin
83
+
84
+ @cached_property
85
+ def identity(self) -> UserId:
86
+ return self._user_id
87
+
88
+ @cached_property
89
+ def is_authenticated(self) -> bool:
90
+ return True
91
+
92
+
93
+ class ApiKeyInterceptor(HasTokenStore, AsyncServerInterceptor):
94
+ async def intercept(
95
+ self,
96
+ method: Callable[[Any, grpc.ServicerContext], Awaitable[Any]],
97
+ request_or_iterator: Any,
98
+ context: grpc.ServicerContext,
99
+ method_name: str,
100
+ ) -> Any:
101
+ for datum in context.invocation_metadata():
102
+ if datum.key.lower() == "authorization":
103
+ scheme, _, token = datum.value.partition(" ")
104
+ if scheme.lower() != "bearer" or not token:
105
+ break
106
+ claims = await self._token_store.read(Token(token))
107
+ if not (isinstance(claims, UserClaimSet) and isinstance(claims.subject, UserId)):
108
+ break
109
+ if not isinstance(claims, (ApiKeyClaims, AccessTokenClaims)):
110
+ raise Unauthenticated(details="Invalid token")
111
+ if claims.status is ClaimSetStatus.EXPIRED:
112
+ raise Unauthenticated(details="Expired token")
113
+ if claims.status is ClaimSetStatus.VALID:
114
+ return await method(request_or_iterator, context)
115
+ raise Unauthenticated()
116
+ raise Unauthenticated()
117
+
118
+
119
+ async def is_authenticated(request: Request) -> None:
120
+ """
121
+ Raises a 401 if the request is not authenticated.
122
+ """
123
+ if not isinstance((user := request.user), PhoenixUser):
124
+ raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token")
125
+ claims = user.claims
126
+ if claims.status is ClaimSetStatus.EXPIRED:
127
+ raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Expired token")
128
+ if claims.status is not ClaimSetStatus.VALID:
129
+ raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token")
130
+
131
+
132
+ async def create_access_and_refresh_tokens(
133
+ *,
134
+ token_store: TokenStore,
135
+ user: OrmUser,
136
+ access_token_expiry: timedelta,
137
+ refresh_token_expiry: timedelta,
138
+ ) -> Tuple[AccessToken, RefreshToken]:
139
+ issued_at = datetime.now(timezone.utc)
140
+ user_id = UserId(user.id)
141
+ user_role = UserRole(user.role.name)
142
+ refresh_token_claims = RefreshTokenClaims(
143
+ subject=user_id,
144
+ issued_at=issued_at,
145
+ expiration_time=issued_at + refresh_token_expiry,
146
+ attributes=RefreshTokenAttributes(
147
+ user_role=user_role,
148
+ ),
149
+ )
150
+ refresh_token, refresh_token_id = await token_store.create_refresh_token(refresh_token_claims)
151
+ access_token_claims = AccessTokenClaims(
152
+ subject=user_id,
153
+ issued_at=issued_at,
154
+ expiration_time=issued_at + access_token_expiry,
155
+ attributes=AccessTokenAttributes(
156
+ user_role=user_role,
157
+ refresh_token_id=refresh_token_id,
158
+ ),
159
+ )
160
+ access_token, _ = await token_store.create_access_token(access_token_claims)
161
+ return access_token, refresh_token
File without changes
@@ -0,0 +1,26 @@
1
+ from pathlib import Path
2
+
3
+ from fastapi_mail import ConnectionConfig, FastMail, MessageSchema
4
+
5
+ EMAIL_TEMPLATE_FOLDER = Path(__file__).parent / "templates"
6
+
7
+
8
+ class FastMailSender:
9
+ def __init__(self, conf: ConnectionConfig) -> None:
10
+ self._fm = FastMail(conf)
11
+
12
+ async def send_password_reset_email(
13
+ self,
14
+ email: str,
15
+ reset_url: str,
16
+ ) -> None:
17
+ message = MessageSchema(
18
+ subject="[Phoenix] Password Reset Request",
19
+ recipients=[email],
20
+ template_body=dict(reset_url=reset_url),
21
+ subtype="html",
22
+ )
23
+ await self._fm.send_message(
24
+ message,
25
+ template_name="password_reset.html",
26
+ )
File without changes
@@ -0,0 +1,19 @@
1
+ <!DOCTYPE html>
2
+ <html>
3
+ <head>
4
+ <meta charset="UTF-8" />
5
+ <title>Password Reset</title>
6
+ </head>
7
+ <body>
8
+ <p>Hello.</p>
9
+ <p>
10
+ You have requested a password reset. Please click on the link below to
11
+ reset your password:
12
+ </p>
13
+ <p>
14
+ <a id="reset-url" href="{{ reset_url }}">Reset Password</a
15
+ >
16
+ </p>
17
+ <p>If you did not make this request, please contact your administrator.</p>
18
+ </body>
19
+ </html>
@@ -0,0 +1,11 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Protocol
4
+
5
+
6
+ class EmailSender(Protocol):
7
+ async def send_password_reset_email(
8
+ self,
9
+ email: str,
10
+ reset_url: str,
11
+ ) -> None: ...
@@ -12,7 +12,9 @@ from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import (
12
12
  )
13
13
  from typing_extensions import TypeAlias
14
14
 
15
+ from phoenix.auth import CanReadToken
15
16
  from phoenix.config import get_env_grpc_port
17
+ from phoenix.server.bearer_auth import ApiKeyInterceptor
16
18
  from phoenix.trace.otel import decode_otlp_span
17
19
  from phoenix.trace.schemas import Span
18
20
  from phoenix.utilities.project import get_project_name
@@ -52,17 +54,21 @@ class GrpcServer:
52
54
  tracer_provider: Optional["TracerProvider"] = None,
53
55
  enable_prometheus: bool = False,
54
56
  disabled: bool = False,
57
+ token_store: Optional[CanReadToken] = None,
55
58
  ) -> None:
56
59
  self._callback = callback
57
60
  self._server: Optional[Server] = None
58
61
  self._tracer_provider = tracer_provider
59
62
  self._enable_prometheus = enable_prometheus
60
63
  self._disabled = disabled
64
+ self._token_store = token_store
61
65
 
62
66
  async def __aenter__(self) -> None:
63
67
  if self._disabled:
64
68
  return
65
69
  interceptors: List[ServerInterceptor] = []
70
+ if self._token_store:
71
+ interceptors.append(ApiKeyInterceptor(self._token_store))
66
72
  if self._enable_prometheus:
67
73
  ...
68
74
  # TODO: convert to async interceptor