arize-phoenix 4.35.2__py3-none-any.whl → 5.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic; see the package's registry page for more details.

Files changed (104)
  1. {arize_phoenix-4.35.2.dist-info → arize_phoenix-5.0.0.dist-info}/METADATA +10 -12
  2. {arize_phoenix-4.35.2.dist-info → arize_phoenix-5.0.0.dist-info}/RECORD +92 -79
  3. phoenix/__init__.py +86 -0
  4. phoenix/auth.py +275 -14
  5. phoenix/config.py +369 -27
  6. phoenix/db/alembic.ini +0 -34
  7. phoenix/db/engines.py +27 -10
  8. phoenix/db/enums.py +20 -0
  9. phoenix/db/facilitator.py +112 -0
  10. phoenix/db/insertion/dataset.py +0 -1
  11. phoenix/db/insertion/types.py +1 -1
  12. phoenix/db/migrate.py +3 -3
  13. phoenix/db/migrations/env.py +0 -7
  14. phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
  15. phoenix/db/models.py +145 -60
  16. phoenix/experiments/evaluators/code_evaluators.py +9 -3
  17. phoenix/experiments/functions.py +1 -4
  18. phoenix/inferences/fixtures.py +0 -1
  19. phoenix/inferences/inferences.py +0 -1
  20. phoenix/logging/__init__.py +3 -0
  21. phoenix/logging/_config.py +90 -0
  22. phoenix/logging/_filter.py +6 -0
  23. phoenix/logging/_formatter.py +69 -0
  24. phoenix/metrics/__init__.py +0 -1
  25. phoenix/otel/settings.py +4 -4
  26. phoenix/server/api/README.md +28 -0
  27. phoenix/server/api/auth.py +32 -0
  28. phoenix/server/api/context.py +50 -2
  29. phoenix/server/api/dataloaders/__init__.py +4 -0
  30. phoenix/server/api/dataloaders/user_roles.py +30 -0
  31. phoenix/server/api/dataloaders/users.py +33 -0
  32. phoenix/server/api/exceptions.py +7 -0
  33. phoenix/server/api/mutations/__init__.py +0 -2
  34. phoenix/server/api/mutations/api_key_mutations.py +104 -86
  35. phoenix/server/api/mutations/dataset_mutations.py +8 -8
  36. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  37. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  38. phoenix/server/api/mutations/project_mutations.py +3 -3
  39. phoenix/server/api/mutations/span_annotations_mutations.py +4 -4
  40. phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
  41. phoenix/server/api/mutations/user_mutations.py +282 -42
  42. phoenix/server/api/openapi/schema.py +2 -2
  43. phoenix/server/api/queries.py +48 -39
  44. phoenix/server/api/routers/__init__.py +11 -0
  45. phoenix/server/api/routers/auth.py +284 -0
  46. phoenix/server/api/routers/embeddings.py +26 -0
  47. phoenix/server/api/routers/oauth2.py +456 -0
  48. phoenix/server/api/routers/v1/__init__.py +38 -16
  49. phoenix/server/api/routers/v1/datasets.py +0 -1
  50. phoenix/server/api/types/ApiKey.py +11 -0
  51. phoenix/server/api/types/AuthMethod.py +9 -0
  52. phoenix/server/api/types/User.py +48 -4
  53. phoenix/server/api/types/UserApiKey.py +35 -1
  54. phoenix/server/api/types/UserRole.py +7 -0
  55. phoenix/server/app.py +105 -34
  56. phoenix/server/bearer_auth.py +161 -0
  57. phoenix/server/email/__init__.py +0 -0
  58. phoenix/server/email/sender.py +26 -0
  59. phoenix/server/email/templates/__init__.py +0 -0
  60. phoenix/server/email/templates/password_reset.html +19 -0
  61. phoenix/server/email/types.py +11 -0
  62. phoenix/server/grpc_server.py +6 -0
  63. phoenix/server/jwt_store.py +504 -0
  64. phoenix/server/main.py +61 -30
  65. phoenix/server/oauth2.py +51 -0
  66. phoenix/server/prometheus.py +20 -0
  67. phoenix/server/rate_limiters.py +191 -0
  68. phoenix/server/static/.vite/manifest.json +31 -31
  69. phoenix/server/static/assets/{components-Dte7_KRd.js → components-REunxTt6.js} +348 -286
  70. phoenix/server/static/assets/index-DAPJxlCw.js +101 -0
  71. phoenix/server/static/assets/{pages-CnTvEGEN.js → pages-1VrMk2pW.js} +559 -291
  72. phoenix/server/static/assets/{vendor-BC3OPQuM.js → vendor-B5IC0ivG.js} +5 -5
  73. phoenix/server/static/assets/{vendor-arizeai-NjB3cZzD.js → vendor-arizeai-aFbT4kl1.js} +2 -2
  74. phoenix/server/static/assets/{vendor-codemirror-gE_JCOgX.js → vendor-codemirror-BEGorXSV.js} +1 -1
  75. phoenix/server/static/assets/{vendor-recharts-BXLYwcXF.js → vendor-recharts-6nUU7gU_.js} +1 -1
  76. phoenix/server/telemetry.py +2 -2
  77. phoenix/server/templates/index.html +1 -0
  78. phoenix/server/types.py +157 -1
  79. phoenix/services.py +0 -1
  80. phoenix/session/client.py +7 -3
  81. phoenix/session/evaluation.py +0 -1
  82. phoenix/session/session.py +0 -1
  83. phoenix/settings.py +9 -0
  84. phoenix/trace/exporter.py +0 -1
  85. phoenix/trace/fixtures.py +0 -2
  86. phoenix/utilities/client.py +16 -0
  87. phoenix/utilities/logging.py +9 -1
  88. phoenix/utilities/re.py +3 -3
  89. phoenix/version.py +1 -1
  90. phoenix/db/migrations/future_versions/README.md +0 -4
  91. phoenix/db/migrations/future_versions/cd164e83824f_users_and_tokens.py +0 -293
  92. phoenix/db/migrations/versions/.gitignore +0 -1
  93. phoenix/server/api/mutations/auth.py +0 -18
  94. phoenix/server/api/mutations/auth_mutations.py +0 -65
  95. phoenix/server/static/assets/index-fq1-hCK4.js +0 -100
  96. phoenix/trace/langchain/__init__.py +0 -3
  97. phoenix/trace/langchain/instrumentor.py +0 -35
  98. phoenix/trace/llama_index/__init__.py +0 -3
  99. phoenix/trace/llama_index/callback.py +0 -103
  100. phoenix/trace/openai/__init__.py +0 -3
  101. phoenix/trace/openai/instrumentor.py +0 -31
  102. {arize_phoenix-4.35.2.dist-info → arize_phoenix-5.0.0.dist-info}/WHEEL +0 -0
  103. {arize_phoenix-4.35.2.dist-info → arize_phoenix-5.0.0.dist-info}/licenses/IP_NOTICE +0 -0
  104. {arize_phoenix-4.35.2.dist-info → arize_phoenix-5.0.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/app.py CHANGED
@@ -20,21 +20,23 @@ from typing import (
20
20
  List,
21
21
  NamedTuple,
22
22
  Optional,
23
+ Sequence,
23
24
  Tuple,
25
+ TypedDict,
24
26
  Union,
25
27
  cast,
26
28
  )
27
29
 
28
30
  import strawberry
29
- from fastapi import APIRouter, FastAPI
31
+ from fastapi import APIRouter, Depends, FastAPI
30
32
  from fastapi.middleware.gzip import GZipMiddleware
31
- from fastapi.responses import FileResponse
32
33
  from fastapi.utils import is_body_allowed_for_status_code
33
34
  from sqlalchemy import select
34
35
  from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
35
36
  from starlette.datastructures import State as StarletteState
36
37
  from starlette.exceptions import HTTPException
37
38
  from starlette.middleware import Middleware
39
+ from starlette.middleware.authentication import AuthenticationMiddleware
38
40
  from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
39
41
  from starlette.requests import Request
40
42
  from starlette.responses import PlainTextResponse, Response
@@ -50,6 +52,7 @@ import phoenix.trace.v1 as pb
50
52
  from phoenix.config import (
51
53
  DEFAULT_PROJECT_NAME,
52
54
  SERVER_DIR,
55
+ OAuth2ClientConfig,
53
56
  get_env_host,
54
57
  get_env_port,
55
58
  server_instrumentation_is_enabled,
@@ -58,6 +61,7 @@ from phoenix.core.model_schema import Model
58
61
  from phoenix.db import models
59
62
  from phoenix.db.bulk_inserter import BulkInserter
60
63
  from phoenix.db.engines import create_engine
64
+ from phoenix.db.facilitator import Facilitator
61
65
  from phoenix.db.helpers import SupportedSQLDialect
62
66
  from phoenix.exceptions import PhoenixMigrationError
63
67
  from phoenix.pointcloud.umap_parameters import UMAPParameters
@@ -86,13 +90,24 @@ from phoenix.server.api.dataloaders import (
86
90
  SpanProjectsDataLoader,
87
91
  TokenCountDataLoader,
88
92
  TraceRowIdsDataLoader,
93
+ UserRolesDataLoader,
94
+ UsersDataLoader,
95
+ )
96
+ from phoenix.server.api.routers import (
97
+ auth_router,
98
+ create_embeddings_router,
99
+ create_v1_router,
100
+ oauth2_router,
89
101
  )
90
102
  from phoenix.server.api.routers.v1 import REST_API_VERSION
91
- from phoenix.server.api.routers.v1 import router as v1_router
92
103
  from phoenix.server.api.schema import schema
104
+ from phoenix.server.bearer_auth import BearerTokenAuthBackend, is_authenticated
93
105
  from phoenix.server.dml_event import DmlEvent
94
106
  from phoenix.server.dml_event_handler import DmlEventHandler
107
+ from phoenix.server.email.types import EmailSender
95
108
  from phoenix.server.grpc_server import GrpcServer
109
+ from phoenix.server.jwt_store import JwtStore
110
+ from phoenix.server.oauth2 import OAuth2Clients
96
111
  from phoenix.server.telemetry import initialize_opentelemetry_tracer_provider
97
112
  from phoenix.server.types import (
98
113
  CanGetLastUpdatedAt,
@@ -100,6 +115,7 @@ from phoenix.server.types import (
100
115
  DaemonTask,
101
116
  DbSessionFactory,
102
117
  LastUpdatedAt,
118
+ TokenStore,
103
119
  )
104
120
  from phoenix.trace.fixtures import (
105
121
  TracesFixture,
@@ -118,8 +134,6 @@ if TYPE_CHECKING:
118
134
  from opentelemetry.trace import TracerProvider
119
135
 
120
136
  logger = logging.getLogger(__name__)
121
- logger.setLevel(logging.INFO)
122
- logger.addHandler(logging.NullHandler())
123
137
 
124
138
  router = APIRouter(include_in_schema=False)
125
139
 
@@ -137,6 +151,11 @@ ProjectName: TypeAlias = str
137
151
  _Callback: TypeAlias = Callable[[], Union[None, Awaitable[None]]]
138
152
 
139
153
 
154
class OAuth2Idp(TypedDict):
    """
    One OAuth2 identity provider entry exposed to the UI template through
    `AppConfig.oauth2_idps`.
    """

    # Provider identifier, populated from `OAuth2ClientConfig.idp_name`.
    name: str
    # Human-readable provider label, populated from
    # `OAuth2ClientConfig.idp_display_name`. The camelCase key presumably
    # matches what the frontend expects — confirm before renaming.
    displayName: str
157
+
158
+
140
159
  class AppConfig(NamedTuple):
141
160
  has_inferences: bool
142
161
  """ Whether the model has inferences (e.g. a primary dataset) """
@@ -148,6 +167,7 @@ class AppConfig(NamedTuple):
148
167
  web_manifest_path: Path
149
168
  authentication_enabled: bool
150
169
  """ Whether authentication is enabled """
170
+ oauth2_idps: Sequence[OAuth2Idp]
151
171
 
152
172
 
153
173
  class Static(StaticFiles):
@@ -196,6 +216,7 @@ class Static(StaticFiles):
196
216
  "is_development": self._app_config.is_development,
197
217
  "manifest": self._web_manifest,
198
218
  "authentication_enabled": self._app_config.authentication_enabled,
219
+ "oauth2_idps": self._app_config.oauth2_idps,
199
220
  },
200
221
  )
201
222
  except Exception as e:
@@ -220,18 +241,6 @@ class HeadersMiddleware(BaseHTTPMiddleware):
220
241
  ProjectRowId: TypeAlias = int
221
242
 
222
243
 
223
- @router.get("/exports")
224
- async def download_exported_file(request: Request, filename: str) -> FileResponse:
225
- file = request.app.state.export_path / (filename + ".parquet")
226
- if not file.is_file():
227
- raise HTTPException(status_code=404)
228
- return FileResponse(
229
- path=file,
230
- filename=file.name,
231
- media_type="application/x-octet-stream",
232
- )
233
-
234
-
235
244
@router.get("/arize_phoenix_version")
async def version() -> PlainTextResponse:
    """Return the installed Phoenix package version as a plain-text response."""
    installed_version = phoenix.__version__
    return PlainTextResponse(f"{installed_version}")
@@ -240,13 +249,15 @@ async def version() -> PlainTextResponse:
240
249
  DB_MUTEX: Optional[asyncio.Lock] = None
241
250
 
242
251
 
243
- def _db(engine: AsyncEngine) -> Callable[[], AsyncContextManager[AsyncSession]]:
252
+ def _db(
253
+ engine: AsyncEngine, bypass_lock: bool = False
254
+ ) -> Callable[[], AsyncContextManager[AsyncSession]]:
244
255
  Session = async_sessionmaker(engine, expire_on_commit=False)
245
256
 
246
257
  @contextlib.asynccontextmanager
247
258
  async def factory() -> AsyncIterator[AsyncSession]:
248
259
  async with contextlib.AsyncExitStack() as stack:
249
- if DB_MUTEX:
260
+ if not bypass_lock and DB_MUTEX:
250
261
  await stack.enter_async_context(DB_MUTEX)
251
262
  yield await stack.enter_async_context(Session.begin())
252
263
 
@@ -285,9 +296,6 @@ class Scaffolder(DaemonTask):
285
296
  return
286
297
  await self.start()
287
298
 
288
- async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
289
- await self.stop()
290
-
291
299
  async def _run(self) -> None:
292
300
  """
293
301
  Main entry point for Scaffolder.
@@ -376,6 +384,7 @@ def _lifespan(
376
384
  db: DbSessionFactory,
377
385
  bulk_inserter: BulkInserter,
378
386
  dml_event_handler: DmlEventHandler,
387
+ token_store: Optional[TokenStore] = None,
379
388
  tracer_provider: Optional["TracerProvider"] = None,
380
389
  enable_prometheus: bool = False,
381
390
  startup_callbacks: Iterable[_Callback] = (),
@@ -402,6 +411,7 @@ def _lifespan(
402
411
  disabled=read_only,
403
412
  tracer_provider=tracer_provider,
404
413
  enable_prometheus=enable_prometheus,
414
+ token_store=token_store,
405
415
  )
406
416
  await stack.enter_async_context(grpc_server)
407
417
  await stack.enter_async_context(dml_event_handler)
@@ -412,6 +422,8 @@ def _lifespan(
412
422
  queue_evaluation=queue_evaluation,
413
423
  )
414
424
  await stack.enter_async_context(scaffolder)
425
+ if isinstance(token_store, AsyncContextManager):
426
+ await stack.enter_async_context(token_store)
415
427
  yield {
416
428
  "event_queue": dml_event_handler,
417
429
  "enqueue": enqueue,
@@ -438,11 +450,13 @@ def create_graphql_router(
438
450
  model: Model,
439
451
  export_path: Path,
440
452
  last_updated_at: CanGetLastUpdatedAt,
453
+ authentication_enabled: bool,
441
454
  corpus: Optional[Model] = None,
442
455
  cache_for_dataloaders: Optional[CacheForDataLoaders] = None,
443
456
  event_queue: CanPutItem[DmlEvent],
444
457
  read_only: bool = False,
445
458
  secret: Optional[str] = None,
459
+ token_store: Optional[TokenStore] = None,
446
460
  ) -> GraphQLRouter: # type: ignore[type-arg]
447
461
  """Creates the GraphQL router.
448
462
 
@@ -452,6 +466,7 @@ def create_graphql_router(
452
466
  model (Model): The Model representing inferences (legacy)
453
467
  export_path (Path): the file path to export data to for download (legacy)
454
468
  last_updated_at (CanGetLastUpdatedAt): How to get the last updated timestamp for updates.
469
+ authentication_enabled (bool): Whether authentication is enabled.
455
470
  event_queue (CanPutItem[DmlEvent]): The event queue for DML events.
456
471
  corpus (Optional[Model], optional): the corpus for UMAP projection. Defaults to None.
457
472
  cache_for_dataloaders (Optional[CacheForDataLoaders], optional): GraphQL data loaders.
@@ -523,18 +538,23 @@ def create_graphql_router(
523
538
  ),
524
539
  trace_row_ids=TraceRowIdsDataLoader(db),
525
540
  project_by_name=ProjectByNameDataLoader(db),
541
+ users=UsersDataLoader(db),
542
+ user_roles=UserRolesDataLoader(db),
526
543
  ),
527
544
  cache_for_dataloaders=cache_for_dataloaders,
528
545
  read_only=read_only,
546
+ auth_enabled=authentication_enabled,
529
547
  secret=secret,
548
+ token_store=token_store,
530
549
  )
531
550
 
532
551
  return GraphQLRouter(
533
552
  schema,
534
- graphiql=True,
553
+ graphql_ide="graphiql",
535
554
  context_getter=get_context,
536
555
  include_in_schema=False,
537
556
  prefix="/graphql",
557
+ dependencies=(Depends(is_authenticated),) if authentication_enabled else (),
538
558
  )
539
559
 
540
560
 
@@ -542,7 +562,7 @@ def create_engine_and_run_migrations(
542
562
  database_url: str,
543
563
  ) -> AsyncEngine:
544
564
  try:
545
- return create_engine(database_url)
565
+ return create_engine(connection_str=database_url, migrate=True, log_to_stdout=False)
546
566
  except PhoenixMigrationError as e:
547
567
  msg = (
548
568
  "\n\n⚠️⚠️ Phoenix failed to migrate the database to the latest version. ⚠️⚠️\n\n"
@@ -602,10 +622,19 @@ def create_app(
602
622
  startup_callbacks: Iterable[_Callback] = (),
603
623
  shutdown_callbacks: Iterable[_Callback] = (),
604
624
  secret: Optional[str] = None,
625
+ password_reset_token_expiry: Optional[timedelta] = None,
626
+ access_token_expiry: Optional[timedelta] = None,
627
+ refresh_token_expiry: Optional[timedelta] = None,
605
628
  scaffolder_config: Optional[ScaffolderConfig] = None,
629
+ email_sender: Optional[EmailSender] = None,
630
+ oauth2_client_configs: Optional[List[OAuth2ClientConfig]] = None,
631
+ bulk_inserter_factory: Optional[Callable[..., BulkInserter]] = None,
606
632
  ) -> FastAPI:
633
+ logger.info(f"Server umap params: {umap_params}")
634
+ bulk_inserter_factory = bulk_inserter_factory or BulkInserter
607
635
  startup_callbacks_list: List[_Callback] = list(startup_callbacks)
608
636
  shutdown_callbacks_list: List[_Callback] = list(shutdown_callbacks)
637
+ startup_callbacks_list.append(Facilitator(db=db))
609
638
  initial_batch_of_spans: Iterable[Tuple[Span, str]] = (
610
639
  ()
611
640
  if initial_spans is None
@@ -620,12 +649,22 @@ def create_app(
620
649
  )
621
650
  last_updated_at = LastUpdatedAt()
622
651
  middlewares: List[Middleware] = [Middleware(HeadersMiddleware)]
652
+ if authentication_enabled and secret:
653
+ token_store = JwtStore(db, secret)
654
+ middlewares.append(
655
+ Middleware(
656
+ AuthenticationMiddleware,
657
+ backend=BearerTokenAuthBackend(token_store),
658
+ )
659
+ )
660
+ else:
661
+ token_store = None
623
662
  dml_event_handler = DmlEventHandler(
624
663
  db=db,
625
664
  cache_for_dataloaders=cache_for_dataloaders,
626
665
  last_updated_at=last_updated_at,
627
666
  )
628
- bulk_inserter = BulkInserter(
667
+ bulk_inserter = bulk_inserter_factory(
629
668
  db,
630
669
  enable_prometheus=enable_prometheus,
631
670
  event_queue=dml_event_handler,
@@ -663,12 +702,14 @@ def create_app(
663
702
  ),
664
703
  model=model,
665
704
  corpus=corpus,
705
+ authentication_enabled=authentication_enabled,
666
706
  export_path=export_path,
667
707
  last_updated_at=last_updated_at,
668
708
  event_queue=dml_event_handler,
669
709
  cache_for_dataloaders=cache_for_dataloaders,
670
710
  read_only=read_only,
671
711
  secret=secret,
712
+ token_store=token_store,
672
713
  )
673
714
  if enable_prometheus:
674
715
  from phoenix.server.prometheus import PrometheusMiddleware
@@ -682,6 +723,7 @@ def create_app(
682
723
  read_only=read_only,
683
724
  bulk_inserter=bulk_inserter,
684
725
  dml_event_handler=dml_event_handler,
726
+ token_store=token_store,
685
727
  tracer_provider=tracer_provider,
686
728
  enable_prometheus=enable_prometheus,
687
729
  shutdown_callbacks=shutdown_callbacks_list,
@@ -695,14 +737,20 @@ def create_app(
695
737
  "defaultModelsExpandDepth": -1, # hides the schema section in the Swagger UI
696
738
  },
697
739
  )
698
- app.state.read_only = read_only
699
- app.state.export_path = export_path
700
- app.include_router(v1_router)
740
+ app.include_router(create_v1_router(authentication_enabled))
741
+ app.include_router(create_embeddings_router(authentication_enabled))
701
742
  app.include_router(router)
702
743
  app.include_router(graphql_router)
744
+ if authentication_enabled:
745
+ app.include_router(auth_router)
746
+ app.include_router(oauth2_router)
703
747
  app.add_middleware(GZipMiddleware)
704
748
  web_manifest_path = SERVER_DIR / "static" / ".vite" / "manifest.json"
705
749
  if serve_ui and web_manifest_path.is_file():
750
+ oauth2_idps = [
751
+ OAuth2Idp(name=config.idp_name, displayName=config.idp_display_name)
752
+ for config in oauth2_client_configs or []
753
+ ]
706
754
  app.mount(
707
755
  "/",
708
756
  app=Static(
@@ -716,11 +764,21 @@ def create_app(
716
764
  is_development=dev,
717
765
  authentication_enabled=authentication_enabled,
718
766
  web_manifest_path=web_manifest_path,
767
+ oauth2_idps=oauth2_idps,
719
768
  ),
720
769
  ),
721
770
  name="static",
722
771
  )
723
- app = _update_app_state(app, db=db, secret=secret)
772
+ app.state.read_only = read_only
773
+ app.state.export_path = export_path
774
+ app.state.password_reset_token_expiry = password_reset_token_expiry
775
+ app.state.access_token_expiry = access_token_expiry
776
+ app.state.refresh_token_expiry = refresh_token_expiry
777
+ app.state.oauth2_clients = OAuth2Clients.from_configs(oauth2_client_configs or [])
778
+ app.state.db = db
779
+ app.state.email_sender = email_sender
780
+ app = _add_get_secret_method(app=app, secret=secret)
781
+ app = _add_get_token_store_method(app=app, token_store=token_store)
724
782
  if tracer_provider:
725
783
  from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
726
784
 
@@ -730,13 +788,10 @@ def create_app(
730
788
  return app
731
789
 
732
790
 
733
- def _update_app_state(app: FastAPI, /, *, db: DbSessionFactory, secret: Optional[str]) -> FastAPI:
791
+ def _add_get_secret_method(*, app: FastAPI, secret: Optional[str]) -> FastAPI:
734
792
  """
735
- Dynamically updates the app's `state` to include useful fields and methods
736
- (at the time of this writing, FastAPI does not support setting this state
737
- during the creation of the app).
793
+ Dynamically adds a `get_secret` method to the app's `state`.
738
794
  """
739
- app.state.db = db
740
795
  app.state._secret = secret
741
796
 
742
797
  def get_secret(self: StarletteState) -> str:
@@ -747,3 +802,19 @@ def _update_app_state(app: FastAPI, /, *, db: DbSessionFactory, secret: Optional
747
802
 
748
803
  app.state.get_secret = MethodType(get_secret, app.state)
749
804
  return app
805
+
806
+
807
def _add_get_token_store_method(*, app: FastAPI, token_store: Optional[JwtStore]) -> FastAPI:
    """
    Attach a `get_token_store` method to `app.state`.

    The store (possibly None) is kept on `app.state._token_store`. The attached
    method returns it and raises ValueError when no store was configured.

    Returns:
        The same `app` instance, for chaining.
    """
    app.state._token_store = token_store

    def get_token_store(self: StarletteState) -> JwtStore:
        store = self._token_store
        if store is None:
            raise ValueError("token store is not set on the app")
        # Narrow the type for static checkers; a non-None value is expected
        # to be a JwtStore given this function's signature.
        assert isinstance(store, JwtStore)
        return store

    # Bind to the state instance so callers can use `app.state.get_token_store()`.
    app.state.get_token_store = MethodType(get_token_store, app.state)
    return app
@@ -0,0 +1,161 @@
1
+ from abc import ABC
2
+ from datetime import datetime, timedelta, timezone
3
+ from functools import cached_property
4
+ from typing import (
5
+ Any,
6
+ Awaitable,
7
+ Callable,
8
+ Optional,
9
+ Tuple,
10
+ )
11
+
12
+ import grpc
13
+ from fastapi import HTTPException, Request
14
+ from grpc_interceptor import AsyncServerInterceptor
15
+ from grpc_interceptor.exceptions import Unauthenticated
16
+ from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser
17
+ from starlette.requests import HTTPConnection
18
+ from starlette.status import HTTP_401_UNAUTHORIZED
19
+
20
+ from phoenix.auth import (
21
+ PHOENIX_ACCESS_TOKEN_COOKIE_NAME,
22
+ CanReadToken,
23
+ ClaimSetStatus,
24
+ Token,
25
+ )
26
+ from phoenix.db import enums
27
+ from phoenix.db.enums import UserRole
28
+ from phoenix.db.models import User as OrmUser
29
+ from phoenix.server.types import (
30
+ AccessToken,
31
+ AccessTokenAttributes,
32
+ AccessTokenClaims,
33
+ ApiKeyClaims,
34
+ RefreshToken,
35
+ RefreshTokenAttributes,
36
+ RefreshTokenClaims,
37
+ TokenStore,
38
+ UserClaimSet,
39
+ UserId,
40
+ )
41
+
42
+
43
class HasTokenStore(ABC):
    """
    Base class that holds a token reader for subclasses that decode tokens.

    Subclasses access the reader through `self._token_store`.
    """

    def __init__(self, token_store: CanReadToken) -> None:
        # Cooperative initialization so the other bases in a subclass's MRO
        # (e.g. an authentication backend or a gRPC interceptor) run their
        # own __init__ as well.
        super().__init__()
        self._token_store = token_store
47
+
48
+
49
class BearerTokenAuthBackend(HasTokenStore, AuthenticationBackend):
    """
    Starlette authentication backend that resolves a `PhoenixUser` from a token.

    The token comes from the ``Authorization: Bearer <token>`` header when that
    header is present. Otherwise it comes from the Phoenix access-token cookie.
    """

    async def authenticate(
        self,
        conn: HTTPConnection,
    ) -> Optional[Tuple[AuthCredentials, BaseUser]]:
        """
        Return credentials and a `PhoenixUser` for a recognized user token.

        Returns None, leaving the connection unauthenticated, when any of these hold:
        - no token is supplied;
        - the Authorization header uses a scheme other than Bearer (the cookie is
          then not consulted);
        - the token does not decode to API-key or access-token claims whose
          subject is a user.

        The claim status (valid/expired) is not checked here. Downstream checks
        such as `is_authenticated` handle it.
        """
        if header := conn.headers.get("Authorization"):
            scheme, _, token = header.partition(" ")
            # RFC 6750 allows one or more spaces between the scheme and the
            # token; strip the extras so they are not passed to the decoder.
            token = token.strip()
            if scheme.lower() != "bearer" or not token:
                return None
        elif access_token := conn.cookies.get(PHOENIX_ACCESS_TOKEN_COOKIE_NAME):
            token = access_token
        else:
            return None
        claims = await self._token_store.read(Token(token))
        if not (isinstance(claims, UserClaimSet) and isinstance(claims.subject, UserId)):
            return None
        # Only API keys and access tokens authenticate requests.
        if not isinstance(claims, (ApiKeyClaims, AccessTokenClaims)):
            return None
        return AuthCredentials(), PhoenixUser(claims.subject, claims)
68
+
69
+
70
class PhoenixUser(BaseUser):
    """
    Starlette user object for a request authenticated with a Phoenix token.

    Args:
        user_id: ID of the user the token was issued to.
        claims: Decoded claims of the presented token; they must carry attributes.

    Raises:
        ValueError: If `claims` has no attributes.
    """

    def __init__(self, user_id: UserId, claims: UserClaimSet) -> None:
        self._user_id = user_id
        self.claims = claims
        # Explicit check instead of `assert`, which is removed under `python -O`
        # and would then fail later with an unhelpful AttributeError.
        if not claims.attributes:
            raise ValueError("token claims are missing attributes")
        # Admin rights are granted only when the claims are currently VALID.
        self._is_admin = (
            claims.status is ClaimSetStatus.VALID
            and claims.attributes.user_role == enums.UserRole.ADMIN
        )

    @cached_property
    def is_admin(self) -> bool:
        return self._is_admin

    @cached_property
    def identity(self) -> UserId:
        return self._user_id

    @cached_property
    def is_authenticated(self) -> bool:
        return True
91
+
92
+
93
class ApiKeyInterceptor(HasTokenStore, AsyncServerInterceptor):
    """
    Async gRPC server interceptor that requires a bearer token on every call.

    Only the first ``authorization`` metadata entry is inspected. The call
    proceeds only when that entry uses the Bearer scheme and its token decodes
    to VALID API-key or access-token claims for a user. In every other case
    `Unauthenticated` is raised.
    """

    async def intercept(
        self,
        method: Callable[[Any, grpc.ServicerContext], Awaitable[Any]],
        request_or_iterator: Any,
        context: grpc.ServicerContext,
        method_name: str,
    ) -> Any:
        for datum in context.invocation_metadata():
            if datum.key.lower() == "authorization":
                scheme, _, token = datum.value.partition(" ")
                # RFC 6750 allows one or more spaces after the scheme; strip
                # the extras so they are not passed to the decoder.
                token = token.strip()
                if scheme.lower() != "bearer" or not token:
                    break
                claims = await self._token_store.read(Token(token))
                if not (isinstance(claims, UserClaimSet) and isinstance(claims.subject, UserId)):
                    break
                if not isinstance(claims, (ApiKeyClaims, AccessTokenClaims)):
                    raise Unauthenticated(details="Invalid token")
                if claims.status is ClaimSetStatus.EXPIRED:
                    raise Unauthenticated(details="Expired token")
                if claims.status is ClaimSetStatus.VALID:
                    return await method(request_or_iterator, context)
                # Any status other than EXPIRED or VALID is rejected.
                raise Unauthenticated()
        # No usable authorization metadata was supplied.
        raise Unauthenticated()
117
+
118
+
119
async def is_authenticated(request: Request) -> None:
    """
    FastAPI dependency that rejects requests lacking a currently valid token.

    Raises:
        HTTPException: 401 with detail "Expired token" when the token has
            expired, or "Invalid token" when there is no Phoenix user or the
            claims are otherwise not valid.
    """
    user = request.user
    if not isinstance(user, PhoenixUser):
        raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token")
    status = user.claims.status
    if status is ClaimSetStatus.EXPIRED:
        detail = "Expired token"
    elif status is not ClaimSetStatus.VALID:
        detail = "Invalid token"
    else:
        return
    raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=detail)
130
+
131
+
132
async def create_access_and_refresh_tokens(
    *,
    token_store: TokenStore,
    user: OrmUser,
    access_token_expiry: timedelta,
    refresh_token_expiry: timedelta,
) -> Tuple[AccessToken, RefreshToken]:
    """
    Issue a refresh token and a linked access token for `user`.

    Both tokens share one UTC issue time and carry the user's current role.
    The refresh token is created first so that its ID can be embedded in the
    access token's attributes.

    Returns:
        A tuple of (access token, refresh token).
    """
    now = datetime.now(timezone.utc)
    subject = UserId(user.id)
    role = UserRole(user.role.name)

    refresh_token, refresh_token_id = await token_store.create_refresh_token(
        RefreshTokenClaims(
            subject=subject,
            issued_at=now,
            expiration_time=now + refresh_token_expiry,
            attributes=RefreshTokenAttributes(user_role=role),
        )
    )
    access_token, _ = await token_store.create_access_token(
        AccessTokenClaims(
            subject=subject,
            issued_at=now,
            expiration_time=now + access_token_expiry,
            attributes=AccessTokenAttributes(
                user_role=role,
                refresh_token_id=refresh_token_id,
            ),
        )
    )
    return access_token, refresh_token
File without changes
@@ -0,0 +1,26 @@
1
+ from pathlib import Path
2
+
3
+ from fastapi_mail import ConnectionConfig, FastMail, MessageSchema
4
+
5
# Directory next to this module that holds the email templates
# (e.g. password_reset.html).
EMAIL_TEMPLATE_FOLDER = Path(__file__).parent.joinpath("templates")
6
+
7
+
8
class FastMailSender:
    """
    Send Phoenix emails through `fastapi_mail.FastMail`.

    Args:
        conf: fastapi_mail connection settings. Templates are resolved by
            fastapi_mail from the folder configured here, presumably
            EMAIL_TEMPLATE_FOLDER — confirm at the construction site.
    """

    def __init__(self, conf: ConnectionConfig) -> None:
        self._fm = FastMail(conf)

    async def send_password_reset_email(
        self,
        email: str,
        reset_url: str,
    ) -> None:
        """
        Send the HTML password-reset email to `email`.

        The "password_reset.html" template is rendered with `reset_url`.
        """
        await self._fm.send_message(
            MessageSchema(
                subject="[Phoenix] Password Reset Request",
                recipients=[email],
                template_body={"reset_url": reset_url},
                subtype="html",
            ),
            template_name="password_reset.html",
        )
File without changes
@@ -0,0 +1,19 @@
1
+ <!DOCTYPE html>
2
+ <html>
3
+ <head>
4
+ <meta charset="UTF-8" />
5
+ <title>Password Reset</title>
6
+ </head>
7
+ <body>
8
+ <p>Hello.</p>
9
+ <p>
10
+ You have requested a password reset. Please click on the link below to
11
+ reset your password:
12
+ </p>
13
+ <p>
14
+ <a id="reset-url" href="{{ reset_url }}">Reset Password</a
15
+ >
16
+ </p>
17
+ <p>If you did not make this request, please contact your administrator.</p>
18
+ </body>
19
+ </html>
@@ -0,0 +1,11 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Protocol
4
+
5
+
6
class EmailSender(Protocol):
    """Structural interface for components that can send Phoenix emails."""

    async def send_password_reset_email(
        self,
        email: str,
        reset_url: str,
    ) -> None:
        """Send a password-reset message containing `reset_url` to `email`."""
        ...
@@ -12,7 +12,9 @@ from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import (
12
12
  )
13
13
  from typing_extensions import TypeAlias
14
14
 
15
+ from phoenix.auth import CanReadToken
15
16
  from phoenix.config import get_env_grpc_port
17
+ from phoenix.server.bearer_auth import ApiKeyInterceptor
16
18
  from phoenix.trace.otel import decode_otlp_span
17
19
  from phoenix.trace.schemas import Span
18
20
  from phoenix.utilities.project import get_project_name
@@ -52,17 +54,21 @@ class GrpcServer:
52
54
  tracer_provider: Optional["TracerProvider"] = None,
53
55
  enable_prometheus: bool = False,
54
56
  disabled: bool = False,
57
+ token_store: Optional[CanReadToken] = None,
55
58
  ) -> None:
56
59
  self._callback = callback
57
60
  self._server: Optional[Server] = None
58
61
  self._tracer_provider = tracer_provider
59
62
  self._enable_prometheus = enable_prometheus
60
63
  self._disabled = disabled
64
+ self._token_store = token_store
61
65
 
62
66
  async def __aenter__(self) -> None:
63
67
  if self._disabled:
64
68
  return
65
69
  interceptors: List[ServerInterceptor] = []
70
+ if self._token_store:
71
+ interceptors.append(ApiKeyInterceptor(self._token_store))
66
72
  if self._enable_prometheus:
67
73
  ...
68
74
  # TODO: convert to async interceptor