arize-phoenix 11.4.0__py3-none-any.whl → 11.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (42)
  1. {arize_phoenix-11.4.0.dist-info → arize_phoenix-11.6.0.dist-info}/METADATA +2 -2
  2. {arize_phoenix-11.4.0.dist-info → arize_phoenix-11.6.0.dist-info}/RECORD +42 -40
  3. phoenix/config.py +51 -2
  4. phoenix/server/api/auth.py +1 -1
  5. phoenix/server/api/queries.py +52 -44
  6. phoenix/server/api/routers/v1/annotation_configs.py +4 -1
  7. phoenix/server/api/routers/v1/datasets.py +3 -1
  8. phoenix/server/api/routers/v1/evaluations.py +3 -1
  9. phoenix/server/api/routers/v1/experiment_runs.py +3 -1
  10. phoenix/server/api/routers/v1/experiments.py +3 -1
  11. phoenix/server/api/routers/v1/projects.py +4 -1
  12. phoenix/server/api/routers/v1/prompts.py +4 -1
  13. phoenix/server/api/routers/v1/spans.py +6 -3
  14. phoenix/server/api/routers/v1/traces.py +4 -1
  15. phoenix/server/api/routers/v1/users.py +2 -2
  16. phoenix/server/api/types/Span.py +0 -99
  17. phoenix/server/app.py +47 -12
  18. phoenix/server/authorization.py +9 -0
  19. phoenix/server/bearer_auth.py +18 -15
  20. phoenix/server/cost_tracking/cost_model_lookup.py +1 -1
  21. phoenix/server/cost_tracking/model_cost_manifest.json +107 -107
  22. phoenix/server/daemons/db_disk_usage_monitor.py +215 -0
  23. phoenix/server/email/sender.py +25 -0
  24. phoenix/server/email/templates/db_disk_usage_notification.html +16 -0
  25. phoenix/server/email/types.py +11 -0
  26. phoenix/server/grpc_server.py +3 -3
  27. phoenix/server/prometheus.py +22 -0
  28. phoenix/server/static/.vite/manifest.json +44 -44
  29. phoenix/server/static/assets/{components-CVcMbu2U.js → components-mOUBHJ12.js} +297 -298
  30. phoenix/server/static/assets/{index-Dz7I-Hpn.js → index-CQ_A6K_M.js} +2 -2
  31. phoenix/server/static/assets/{pages-QK2o2V7x.js → pages-CCsLkNZY.js} +517 -498
  32. phoenix/server/static/assets/{vendor-pg5m6BWE.js → vendor-DRWIRkSJ.js} +1 -1
  33. phoenix/server/static/assets/{vendor-arizeai-BwMsgSAG.js → vendor-arizeai-DUhQaeau.js} +1 -1
  34. phoenix/server/static/assets/{vendor-codemirror-BwSDEu2g.js → vendor-codemirror-D_6Q6Auv.js} +1 -1
  35. phoenix/server/static/assets/{vendor-recharts-SW3HwAtG.js → vendor-recharts-BNBwj7vz.js} +1 -1
  36. phoenix/server/static/assets/{vendor-shiki-BsdYoDvs.js → vendor-shiki-k1qj_XjP.js} +1 -1
  37. phoenix/server/types.py +11 -2
  38. phoenix/version.py +1 -1
  39. {arize_phoenix-11.4.0.dist-info → arize_phoenix-11.6.0.dist-info}/WHEEL +0 -0
  40. {arize_phoenix-11.4.0.dist-info → arize_phoenix-11.6.0.dist-info}/entry_points.txt +0 -0
  41. {arize_phoenix-11.4.0.dist-info → arize_phoenix-11.6.0.dist-info}/licenses/IP_NOTICE +0 -0
  42. {arize_phoenix-11.4.0.dist-info → arize_phoenix-11.6.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/routers/v1/prompts.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import logging
2
2
  from typing import Any, Optional, Union
3
3
 
4
- from fastapi import APIRouter, HTTPException, Path, Query
4
+ from fastapi import APIRouter, Depends, HTTPException, Path, Query
5
5
  from pydantic import ValidationError, model_validator
6
6
  from sqlalchemy import select
7
7
  from sqlalchemy.sql import Select
@@ -33,6 +33,7 @@ from phoenix.server.api.types.node import from_global_id_with_expected_type
33
33
  from phoenix.server.api.types.Prompt import Prompt as PromptNodeType
34
34
  from phoenix.server.api.types.PromptVersion import PromptVersion as PromptVersionNodeType
35
35
  from phoenix.server.api.types.PromptVersionTag import PromptVersionTag as PromptVersionTagNodeType
36
+ from phoenix.server.authorization import is_not_locked
36
37
  from phoenix.server.bearer_auth import PhoenixUser
37
38
 
38
39
  logger = logging.getLogger(__name__)
@@ -393,6 +394,7 @@ async def get_prompt_version_by_latest(
393
394
 
394
395
  @router.post(
395
396
  "/prompts",
397
+ dependencies=[Depends(is_not_locked)],
396
398
  operation_id="postPromptVersion",
397
399
  summary="Create a new prompt",
398
400
  description="Create a new prompt and its initial version. A prompt can have multiple versions.",
@@ -602,6 +604,7 @@ async def list_prompt_version_tags(
602
604
 
603
605
  @router.post(
604
606
  "/prompt_versions/{prompt_version_id}/tags",
607
+ dependencies=[Depends(is_not_locked)],
605
608
  operation_id="createPromptVersionTag",
606
609
  summary="Add tag to prompt version",
607
610
  description="Add a new tag to a specific prompt version. Tags help identify and categorize "
phoenix/server/api/routers/v1/spans.py CHANGED
@@ -8,7 +8,7 @@ from secrets import token_urlsafe
8
8
  from typing import Annotated, Any, Literal, Optional, Union
9
9
 
10
10
  import pandas as pd
11
- from fastapi import APIRouter, Header, HTTPException, Path, Query
11
+ from fastapi import APIRouter, Depends, Header, HTTPException, Path, Query
12
12
  from pydantic import BaseModel, Field
13
13
  from sqlalchemy import select
14
14
  from starlette.requests import Request
@@ -28,6 +28,7 @@ from phoenix.db.helpers import SupportedSQLDialect
28
28
  from phoenix.db.insertion.helpers import as_kv, insert_on_conflict
29
29
  from phoenix.db.insertion.types import Precursors
30
30
  from phoenix.server.api.routers.utils import df_to_bytes
31
+ from phoenix.server.authorization import is_not_locked
31
32
  from phoenix.server.bearer_auth import PhoenixUser
32
33
  from phoenix.server.dml_event import SpanAnnotationInsertEvent
33
34
  from phoenix.trace.attributes import flatten
@@ -601,7 +602,7 @@ async def span_search_otlpv1(
601
602
  models.Trace.trace_id,
602
603
  )
603
604
  .join(models.Trace, onclause=models.Trace.id == models.Span.trace_rowid)
604
- .join(models.Project, onclause=models.Project.id == project_id)
605
+ .where(models.Trace.project_rowid == project_id)
605
606
  .order_by(*order_by)
606
607
  )
607
608
 
@@ -736,7 +737,7 @@ async def span_search(
736
737
  models.Trace.trace_id,
737
738
  )
738
739
  .join(models.Trace, onclause=models.Trace.id == models.Span.trace_rowid)
739
- .join(models.Project, onclause=models.Project.id == project_id)
740
+ .where(models.Trace.project_rowid == project_id)
740
741
  .order_by(*order_by)
741
742
  )
742
743
 
@@ -907,6 +908,7 @@ class AnnotateSpansResponseBody(ResponseBody[list[InsertedSpanAnnotation]]):
907
908
 
908
909
  @router.post(
909
910
  "/span_annotations",
911
+ dependencies=[Depends(is_not_locked)],
910
912
  operation_id="annotateSpans",
911
913
  summary="Create span annotations",
912
914
  responses=add_errors_to_responses(
@@ -990,6 +992,7 @@ class CreateSpansResponseBody(V1RoutesBaseModel):
990
992
 
991
993
  @router.post(
992
994
  "/projects/{project_identifier}/spans",
995
+ dependencies=[Depends(is_not_locked)],
993
996
  operation_id="createSpans",
994
997
  summary="Create spans",
995
998
  description=(
phoenix/server/api/routers/v1/traces.py CHANGED
@@ -2,7 +2,7 @@ import gzip
2
2
  import zlib
3
3
  from typing import Any, Literal, Optional
4
4
 
5
- from fastapi import APIRouter, BackgroundTasks, Header, HTTPException, Query
5
+ from fastapi import APIRouter, BackgroundTasks, Depends, Header, HTTPException, Query
6
6
  from google.protobuf.message import DecodeError
7
7
  from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
8
8
  ExportTraceServiceRequest,
@@ -24,6 +24,7 @@ from strawberry.relay import GlobalID
24
24
  from phoenix.db import models
25
25
  from phoenix.db.insertion.helpers import as_kv
26
26
  from phoenix.db.insertion.types import Precursors
27
+ from phoenix.server.authorization import is_not_locked
27
28
  from phoenix.server.bearer_auth import PhoenixUser
28
29
  from phoenix.server.dml_event import TraceAnnotationInsertEvent
29
30
  from phoenix.trace.otel import decode_otlp_span
@@ -37,6 +38,7 @@ router = APIRouter(tags=["traces"])
37
38
 
38
39
  @router.post(
39
40
  "/traces",
41
+ dependencies=[Depends(is_not_locked)],
40
42
  operation_id="addTraces",
41
43
  summary="Send traces",
42
44
  responses=add_errors_to_responses(
@@ -160,6 +162,7 @@ class AnnotateTracesResponseBody(ResponseBody[list[InsertedTraceAnnotation]]):
160
162
 
161
163
  @router.post(
162
164
  "/trace_annotations",
165
+ dependencies=[Depends(is_not_locked)],
163
166
  operation_id="annotateTraces",
164
167
  summary="Create trace annotations",
165
168
  responses=add_errors_to_responses(
phoenix/server/api/routers/v1/users.py CHANGED
@@ -43,7 +43,7 @@ from phoenix.server.api.routers.v1.utils import (
43
43
  add_errors_to_responses,
44
44
  )
45
45
  from phoenix.server.api.types.node import from_global_id_with_expected_type
46
- from phoenix.server.authorization import require_admin
46
+ from phoenix.server.authorization import is_not_locked, require_admin
47
47
 
48
48
  logger = logging.getLogger(__name__)
49
49
 
@@ -194,7 +194,7 @@ async def list_users(
194
194
  HTTP_422_UNPROCESSABLE_ENTITY,
195
195
  ]
196
196
  ),
197
- dependencies=[Depends(require_admin)],
197
+ dependencies=[Depends(require_admin), Depends(is_not_locked)],
198
198
  response_model_by_alias=True,
199
199
  response_model_exclude_unset=True,
200
200
  response_model_exclude_defaults=True,
phoenix/server/api/types/Span.py CHANGED
@@ -19,7 +19,6 @@ from typing_extensions import Annotated, TypeAlias
19
19
  import phoenix.trace.schemas as trace_schema
20
20
  from phoenix.db import models
21
21
  from phoenix.server.api.context import Context
22
- from phoenix.server.api.dataloaders import types as dataloader_types
23
22
  from phoenix.server.api.helpers.dataset_helpers import (
24
23
  get_dataset_example_input,
25
24
  get_dataset_example_output,
@@ -829,104 +828,6 @@ class Span(Node):
829
828
  for entry in entries
830
829
  ]
831
830
 
832
- @strawberry.field
833
- async def cumulative_cost_summary(self, info: Info[Context, None]) -> Optional[SpanCostSummary]:
834
- max_depth = 0
835
- descendant_rowids = await info.context.data_loaders.span_descendants.load(
836
- (self.span_rowid, max_depth)
837
- )
838
- span_costs = await info.context.data_loaders.span_cost_by_span.load_many(
839
- (self.span_rowid, *descendant_rowids)
840
- )
841
- total_cost: Optional[float] = None
842
- total_tokens: Optional[float] = None
843
- prompt_cost: Optional[float] = None
844
- prompt_tokens: Optional[float] = None
845
- completion_cost: Optional[float] = None
846
- completion_tokens: Optional[float] = None
847
- for span_cost in span_costs:
848
- if span_cost is None:
849
- continue
850
- if span_cost.total_cost is not None:
851
- total_cost = (total_cost or 0) + span_cost.total_cost
852
- if span_cost.total_tokens is not None:
853
- total_tokens = (total_tokens or 0) + span_cost.total_tokens
854
- if span_cost.prompt_cost is not None:
855
- prompt_cost = (prompt_cost or 0) + span_cost.prompt_cost
856
- if span_cost.prompt_tokens is not None:
857
- prompt_tokens = (prompt_tokens or 0) + span_cost.prompt_tokens
858
- if span_cost.completion_cost is not None:
859
- completion_cost = (completion_cost or 0) + span_cost.completion_cost
860
- if span_cost.completion_tokens is not None:
861
- completion_tokens = (completion_tokens or 0) + span_cost.completion_tokens
862
- return SpanCostSummary(
863
- prompt=CostBreakdown(
864
- tokens=prompt_tokens,
865
- cost=prompt_cost,
866
- ),
867
- completion=CostBreakdown(
868
- tokens=completion_tokens,
869
- cost=completion_cost,
870
- ),
871
- total=CostBreakdown(
872
- tokens=total_tokens,
873
- cost=total_cost,
874
- ),
875
- )
876
-
877
- @strawberry.field
878
- async def cumulative_cost_detail_summary_entries(
879
- self, info: Info[Context, None]
880
- ) -> list[SpanCostDetailSummaryEntry]:
881
- max_depth = 0
882
- descendant_rowids = await info.context.data_loaders.span_descendants.load(
883
- (self.span_rowid, max_depth)
884
- )
885
- entry_lists = (
886
- await info.context.data_loaders.span_cost_detail_summary_entries_by_span.load_many(
887
- (self.span_rowid, *descendant_rowids)
888
- )
889
- )
890
-
891
- TokenType: TypeAlias = str
892
- IsPrompt: TypeAlias = bool
893
- grouped_entries: dict[
894
- IsPrompt, dict[TokenType, list[dataloader_types.SpanCostDetailSummaryEntry]]
895
- ] = {}
896
-
897
- for entries in entry_lists:
898
- for entry in entries:
899
- is_prompt = entry.is_prompt
900
- token_type = entry.token_type
901
- if is_prompt not in grouped_entries:
902
- grouped_entries[is_prompt] = {}
903
- if token_type not in grouped_entries[is_prompt]:
904
- grouped_entries[is_prompt][token_type] = []
905
- grouped_entries[is_prompt][token_type].append(entry)
906
-
907
- result: list[SpanCostDetailSummaryEntry] = []
908
- for is_prompt in (True, False):
909
- entries_by_token_type = grouped_entries[is_prompt]
910
- for token_type, entries in sorted(entries_by_token_type.items()):
911
- cost: Optional[float] = None
912
- tokens: Optional[float] = None
913
- for entry in entries:
914
- if entry.value.cost is not None:
915
- cost = (cost or 0) + entry.value.cost
916
- if entry.value.tokens is not None:
917
- tokens = (tokens or 0) + entry.value.tokens
918
- result.append(
919
- SpanCostDetailSummaryEntry(
920
- token_type=token_type,
921
- is_prompt=is_prompt,
922
- value=CostBreakdown(
923
- tokens=tokens,
924
- cost=cost,
925
- ),
926
- )
927
- )
928
- return result
929
-
930
831
 
931
832
  def _hide_embedding_vectors(attributes: Mapping[str, Any]) -> Mapping[str, Any]:
932
833
  if not (
phoenix/server/app.py CHANGED
@@ -16,17 +16,20 @@ from typing import (
16
16
  Any,
17
17
  NamedTuple,
18
18
  Optional,
19
+ Protocol,
19
20
  TypedDict,
20
21
  Union,
21
22
  cast,
22
23
  )
23
24
  from urllib.parse import urlparse
24
25
 
26
+ import grpc
25
27
  import strawberry
26
28
  from fastapi import APIRouter, Depends, FastAPI
27
29
  from fastapi.middleware.cors import CORSMiddleware
28
30
  from fastapi.utils import is_body_allowed_for_status_code
29
31
  from grpc.aio import ServerInterceptor
32
+ from grpc_interceptor import AsyncServerInterceptor
30
33
  from sqlalchemy import select
31
34
  from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
32
35
  from starlette.datastructures import URL, Secret
@@ -44,7 +47,7 @@ from starlette.types import Scope, StatefulLifespan
44
47
  from strawberry.extensions import SchemaExtension
45
48
  from strawberry.fastapi import GraphQLRouter
46
49
  from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL
47
- from typing_extensions import TypeAlias
50
+ from typing_extensions import TypeAlias, override
48
51
 
49
52
  import phoenix.trace.v1 as pb
50
53
  from phoenix.config import (
@@ -134,6 +137,7 @@ from phoenix.server.api.routers import (
134
137
  from phoenix.server.api.routers.v1 import REST_API_VERSION
135
138
  from phoenix.server.api.schema import build_graphql_schema
136
139
  from phoenix.server.bearer_auth import BearerTokenAuthBackend, is_authenticated
140
+ from phoenix.server.daemons.db_disk_usage_monitor import DbDiskUsageMonitor
137
141
  from phoenix.server.daemons.generative_model_store import GenerativeModelStore
138
142
  from phoenix.server.daemons.span_cost_calculator import SpanCostCalculator
139
143
  from phoenix.server.dml_event import DmlEvent
@@ -382,19 +386,16 @@ async def version() -> PlainTextResponse:
382
386
  return PlainTextResponse(f"{phoenix_version}")
383
387
 
384
388
 
385
- DB_MUTEX: Optional[asyncio.Lock] = None
386
-
387
-
388
389
  def _db(
389
- engine: AsyncEngine, bypass_lock: bool = False
390
- ) -> Callable[[], AbstractAsyncContextManager[AsyncSession]]:
390
+ engine: AsyncEngine,
391
+ ) -> Callable[[Optional[asyncio.Lock]], AbstractAsyncContextManager[AsyncSession]]:
391
392
  Session = async_sessionmaker(engine, expire_on_commit=False)
392
393
 
393
394
  @contextlib.asynccontextmanager
394
- async def factory() -> AsyncIterator[AsyncSession]:
395
+ async def factory(lock: Optional[asyncio.Lock] = None) -> AsyncIterator[AsyncSession]:
395
396
  async with contextlib.AsyncExitStack() as stack:
396
- if not bypass_lock and DB_MUTEX:
397
- await stack.enter_async_context(DB_MUTEX)
397
+ if lock:
398
+ await stack.enter_async_context(lock)
398
399
  yield await stack.enter_async_context(Session.begin())
399
400
 
400
401
  return factory
@@ -523,6 +524,7 @@ def _lifespan(
523
524
  trace_data_sweeper: Optional[TraceDataSweeper],
524
525
  span_cost_calculator: SpanCostCalculator,
525
526
  generative_model_store: GenerativeModelStore,
527
+ db_disk_usage_monitor: DbDiskUsageMonitor,
526
528
  token_store: Optional[TokenStore] = None,
527
529
  tracer_provider: Optional["TracerProvider"] = None,
528
530
  enable_prometheus: bool = False,
@@ -530,14 +532,14 @@ def _lifespan(
530
532
  shutdown_callbacks: Iterable[_Callback] = (),
531
533
  read_only: bool = False,
532
534
  scaffolder_config: Optional[ScaffolderConfig] = None,
535
+ grpc_interceptors: Iterable[AsyncServerInterceptor] = (),
533
536
  ) -> StatefulLifespan[FastAPI]:
534
537
  @contextlib.asynccontextmanager
535
538
  async def lifespan(_: FastAPI) -> AsyncIterator[dict[str, Any]]:
536
539
  for callback in startup_callbacks:
537
540
  if isinstance((res := callback()), Awaitable):
538
541
  await res
539
- global DB_MUTEX
540
- DB_MUTEX = asyncio.Lock() if db.dialect is SupportedSQLDialect.SQLITE else None
542
+ db.lock = asyncio.Lock() if db.dialect is SupportedSQLDialect.SQLITE else None
541
543
  async with AsyncExitStack() as stack:
542
544
  (
543
545
  enqueue,
@@ -551,7 +553,7 @@ def _lifespan(
551
553
  tracer_provider=tracer_provider,
552
554
  enable_prometheus=enable_prometheus,
553
555
  token_store=token_store,
554
- interceptors=user_grpc_interceptors(),
556
+ interceptors=user_grpc_interceptors() + list(grpc_interceptors),
555
557
  )
556
558
  await stack.enter_async_context(grpc_server)
557
559
  await stack.enter_async_context(dml_event_handler)
@@ -559,6 +561,7 @@ def _lifespan(
559
561
  await stack.enter_async_context(trace_data_sweeper)
560
562
  await stack.enter_async_context(span_cost_calculator)
561
563
  await stack.enter_async_context(generative_model_store)
564
+ await stack.enter_async_context(db_disk_usage_monitor)
562
565
  if scaffolder_config:
563
566
  scaffolder = Scaffolder(
564
567
  config=scaffolder_config,
@@ -826,6 +829,34 @@ async def plain_text_http_exception_handler(request: Request, exc: HTTPException
826
829
  return PlainTextResponse(str(exc.detail), status_code=exc.status_code, headers=headers)
827
830
 
828
831
 
832
+ class _HasDbStatus(Protocol):
833
+ @property
834
+ def should_not_insert_or_update(self) -> bool: ...
835
+
836
+
837
+ class DbDiskUsageInterceptor(AsyncServerInterceptor):
838
+ def __init__(self, db: _HasDbStatus) -> None:
839
+ self._db = db
840
+
841
+ @override
842
+ async def intercept(
843
+ self,
844
+ method: Callable[[Any, grpc.aio.ServicerContext], Awaitable[Any]],
845
+ request_or_iterator: Any,
846
+ context: grpc.aio.ServicerContext,
847
+ method_name: str,
848
+ ) -> Any:
849
+ if (
850
+ method_name.endswith("trace.v1.TraceService/Export")
851
+ and self._db.should_not_insert_or_update
852
+ ):
853
+ await context.abort(
854
+ grpc.StatusCode.RESOURCE_EXHAUSTED,
855
+ "Database disk usage threshold exceeded",
856
+ )
857
+ return await method(request_or_iterator, context)
858
+
859
+
829
860
  def create_app(
830
861
  db: DbSessionFactory,
831
862
  export_path: Path,
@@ -971,6 +1002,8 @@ def create_app(
971
1002
  from phoenix.server.prometheus import PrometheusMiddleware
972
1003
 
973
1004
  middlewares.append(Middleware(PrometheusMiddleware))
1005
+ grpc_interceptors: list[AsyncServerInterceptor] = []
1006
+ grpc_interceptors.append(DbDiskUsageInterceptor(db))
974
1007
  app = FastAPI(
975
1008
  title="Arize-Phoenix REST API",
976
1009
  version=REST_API_VERSION,
@@ -982,6 +1015,8 @@ def create_app(
982
1015
  trace_data_sweeper=trace_data_sweeper,
983
1016
  span_cost_calculator=span_cost_calculator,
984
1017
  generative_model_store=generative_model_store,
1018
+ db_disk_usage_monitor=DbDiskUsageMonitor(db, email_sender),
1019
+ grpc_interceptors=grpc_interceptors,
985
1020
  token_store=token_store,
986
1021
  tracer_provider=tracer_provider,
987
1022
  enable_prometheus=enable_prometheus,
phoenix/server/authorization.py CHANGED
@@ -51,3 +51,12 @@ def require_admin(request: Request) -> None:
51
51
  status_code=fastapi_status.HTTP_403_FORBIDDEN,
52
52
  detail="Only admin or system users can perform this action.",
53
53
  )
54
+
55
+
56
+ def is_not_locked(request: Request) -> None:
57
+ if request.app.state.db.should_not_insert_or_update:
58
+ raise HTTPException(
59
+ status_code=fastapi_status.HTTP_507_INSUFFICIENT_STORAGE,
60
+ detail="Operations that insert or update database "
61
+ "records are currently not allowed.",
62
+ )
phoenix/server/bearer_auth.py CHANGED
@@ -7,10 +7,10 @@ from typing import Any, Optional, cast
7
7
  import grpc
8
8
  from fastapi import HTTPException, Request, WebSocket, WebSocketException
9
9
  from grpc_interceptor import AsyncServerInterceptor
10
- from grpc_interceptor.exceptions import Unauthenticated
11
10
  from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser
12
11
  from starlette.requests import HTTPConnection
13
12
  from starlette.status import HTTP_401_UNAUTHORIZED
13
+ from typing_extensions import override
14
14
 
15
15
  from phoenix import config
16
16
  from phoenix.auth import (
@@ -100,16 +100,19 @@ class PhoenixSystemUser(PhoenixUser):
100
100
 
101
101
 
102
102
  class ApiKeyInterceptor(HasTokenStore, AsyncServerInterceptor):
103
+ @override
103
104
  async def intercept(
104
105
  self,
105
- method: Callable[[Any, grpc.ServicerContext], Awaitable[Any]],
106
+ method: Callable[[Any, grpc.aio.ServicerContext], Awaitable[Any]],
106
107
  request_or_iterator: Any,
107
- context: grpc.ServicerContext,
108
+ context: grpc.aio.ServicerContext,
108
109
  method_name: str,
109
110
  ) -> Any:
110
- for datum in context.invocation_metadata():
111
- if datum.key.lower() == "authorization":
112
- scheme, _, token = datum.value.partition(" ")
111
+ for key, value in context.invocation_metadata() or ():
112
+ if key.lower() == "authorization":
113
+ if isinstance(value, bytes):
114
+ value = value.decode("utf-8")
115
+ scheme, _, token = value.partition(" ")
113
116
  if scheme.lower() != "bearer" or not token:
114
117
  break
115
118
  if (
@@ -119,16 +122,16 @@ class ApiKeyInterceptor(HasTokenStore, AsyncServerInterceptor):
119
122
  ):
120
123
  return await method(request_or_iterator, context)
121
124
  claims = await self._token_store.read(Token(token))
122
- if not (isinstance(claims, UserClaimSet) and isinstance(claims.subject, UserId)):
125
+ if (
126
+ not (
127
+ isinstance(claims, (ApiKeyClaims, AccessTokenClaims))
128
+ and isinstance(claims.subject, UserId)
129
+ )
130
+ or claims.status is not ClaimSetStatus.VALID
131
+ ):
123
132
  break
124
- if not isinstance(claims, (ApiKeyClaims, AccessTokenClaims)):
125
- raise Unauthenticated(details="Invalid token")
126
- if claims.status is ClaimSetStatus.EXPIRED:
127
- raise Unauthenticated(details="Expired token")
128
- if claims.status is ClaimSetStatus.VALID:
129
- return await method(request_or_iterator, context)
130
- raise Unauthenticated()
131
- raise Unauthenticated()
133
+ return await method(request_or_iterator, context)
134
+ await context.abort(grpc.StatusCode.UNAUTHENTICATED)
132
135
 
133
136
 
134
137
  async def is_authenticated(
phoenix/server/cost_tracking/cost_model_lookup.py CHANGED
@@ -109,7 +109,7 @@ class CostModelLookup:
109
109
  model
110
110
  for model in self._models
111
111
  if (not model.start_time or model.start_time <= start_time)
112
- and model.name_pattern.match(model_name)
112
+ and model.name_pattern.search(model_name)
113
113
  ]
114
114
  if not candidates:
115
115
  return None